tokio_rustls/common/mod.rs

1mod handshake;
2
3pub(crate) use handshake::{IoSession, MidHandshake};
4use rustls::Session;
5use std::io::{self, IoSlice, Read, Write};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
/// Tracks which halves of a TLS connection are still open.
#[derive(Debug)]
pub enum TlsState {
    /// Early-data (0-RTT) phase. NOTE(review): the fields are opaque to this
    /// type — presumably a write position and the buffered early-data bytes;
    /// confirm against the code that constructs this variant.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both the read and the write half are open.
    Stream,
    /// The read half is closed; writing is still permitted.
    ReadShutdown,
    /// The write half is closed; reading is still permitted.
    WriteShutdown,
    /// Both halves are closed.
    FullyShutdown,
}

impl TlsState {
    /// Marks the read half as closed.
    ///
    /// If the write half was already closed the result is `FullyShutdown`,
    /// otherwise it is `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Marks the write half as closed.
    ///
    /// If the read half was already closed the result is `FullyShutdown`,
    /// otherwise it is `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Returns `true` unless the write half has been closed.
    #[inline]
    pub fn writeable(&self) -> bool {
        let write_closed = matches!(self, TlsState::WriteShutdown | TlsState::FullyShutdown);
        !write_closed
    }

    /// Returns `true` unless the read half has been closed.
    #[inline]
    pub fn readable(&self) -> bool {
        let read_closed = matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown);
        !read_closed
    }

    /// Returns `true` while in the early-data phase.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Early data is compiled out, so this is always `false`.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
59
/// A borrowed pairing of a transport (`io`) and a rustls session, used to
/// move TLS records between the two.
pub struct Stream<'a, IO, S> {
    /// The underlying byte transport.
    pub io: &'a mut IO,
    /// The TLS session state machine.
    pub session: &'a mut S,
    /// `true` once the transport's read side is known to be at end-of-file:
    /// set after a zero-length read, or up front by the caller via `set_eof`.
    pub eof: bool,
}
65
66impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
67    pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
68        Stream {
69            io,
70            session,
71            // The state so far is only used to detect EOF, so either Stream
72            // or EarlyData state should both be all right.
73            eof: false,
74        }
75    }
76
77    pub fn set_eof(mut self, eof: bool) -> Self {
78        self.eof = eof;
79        self
80    }
81
82    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
83        Pin::new(self)
84    }
85
86    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
87        struct Reader<'a, 'b, T> {
88            io: &'a mut T,
89            cx: &'a mut Context<'b>,
90        }
91
92        impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
93            #[inline]
94            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
95                let mut buf = ReadBuf::new(buf);
96                match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
97                    Poll::Ready(Ok(())) => Ok(buf.filled().len()),
98                    Poll::Ready(Err(err)) => Err(err),
99                    Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
100                }
101            }
102        }
103
104        let mut reader = Reader { io: self.io, cx };
105
106        let n = match self.session.read_tls(&mut reader) {
107            Ok(n) => n,
108            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
109            Err(err) => return Poll::Ready(Err(err)),
110        };
111
112        self.session.process_new_packets().map_err(|err| {
113            // In case we have an alert to send describing this error,
114            // try a last-gasp write -- but don't predate the primary
115            // error.
116            let _ = self.write_io(cx);
117
118            io::Error::new(io::ErrorKind::InvalidData, err)
119        })?;
120
121        Poll::Ready(Ok(n))
122    }
123
124    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
125        struct Writer<'a, 'b, T> {
126            io: &'a mut T,
127            cx: &'a mut Context<'b>,
128        }
129
130        impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
131            #[inline]
132            fn poll_with<U>(
133                &mut self,
134                f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
135            ) -> io::Result<U> {
136                match f(Pin::new(&mut self.io), self.cx) {
137                    Poll::Ready(result) => result,
138                    Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
139                }
140            }
141        }
142
143        impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
144            #[inline]
145            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
146                self.poll_with(|io, cx| io.poll_write(cx, buf))
147            }
148
149            #[inline]
150            fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
151                self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
152            }
153
154            fn flush(&mut self) -> io::Result<()> {
155                self.poll_with(|io, cx| io.poll_flush(cx))
156            }
157        }
158
159        let mut writer = Writer { io: self.io, cx };
160
161        match self.session.write_tls(&mut writer) {
162            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
163            result => Poll::Ready(result),
164        }
165    }
166
167    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
168        let mut wrlen = 0;
169        let mut rdlen = 0;
170
171        loop {
172            let mut write_would_block = false;
173            let mut read_would_block = false;
174
175            while self.session.wants_write() {
176                match self.write_io(cx) {
177                    Poll::Ready(Ok(n)) => wrlen += n,
178                    Poll::Pending => {
179                        write_would_block = true;
180                        break;
181                    }
182                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
183                }
184            }
185
186            while !self.eof && self.session.wants_read() {
187                match self.read_io(cx) {
188                    Poll::Ready(Ok(0)) => self.eof = true,
189                    Poll::Ready(Ok(n)) => rdlen += n,
190                    Poll::Pending => {
191                        read_would_block = true;
192                        break;
193                    }
194                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
195                }
196            }
197
198            return match (self.eof, self.session.is_handshaking()) {
199                (true, true) => {
200                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
201                    Poll::Ready(Err(err))
202                }
203                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
204                (_, true) if write_would_block || read_would_block => {
205                    if rdlen != 0 || wrlen != 0 {
206                        Poll::Ready(Ok((rdlen, wrlen)))
207                    } else {
208                        Poll::Pending
209                    }
210                }
211                (..) => continue,
212            };
213        }
214    }
215}
216
217impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
218    fn poll_read(
219        mut self: Pin<&mut Self>,
220        cx: &mut Context<'_>,
221        buf: &mut ReadBuf<'_>,
222    ) -> Poll<io::Result<()>> {
223        let prev = buf.remaining();
224
225        while buf.remaining() != 0 {
226            let mut would_block = false;
227
228            // read a packet
229            while self.session.wants_read() {
230                match self.read_io(cx) {
231                    Poll::Ready(Ok(0)) => {
232                        self.eof = true;
233                        break;
234                    }
235                    Poll::Ready(Ok(_)) => (),
236                    Poll::Pending => {
237                        would_block = true;
238                        break;
239                    }
240                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
241                }
242            }
243
244            return match self.session.read(buf.initialize_unfilled()) {
245                Ok(0) if prev == buf.remaining() && would_block => Poll::Pending,
246                Ok(n) => {
247                    buf.advance(n);
248
249                    if self.eof || would_block {
250                        break;
251                    } else {
252                        continue;
253                    }
254                }
255                Err(ref err)
256                    if err.kind() == io::ErrorKind::ConnectionAborted
257                        && prev != buf.remaining() =>
258                {
259                    break
260                }
261                Err(err) => Poll::Ready(Err(err)),
262            };
263        }
264
265        Poll::Ready(Ok(()))
266    }
267}
268
269impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
270    fn poll_write(
271        mut self: Pin<&mut Self>,
272        cx: &mut Context,
273        buf: &[u8],
274    ) -> Poll<io::Result<usize>> {
275        let mut pos = 0;
276
277        while pos != buf.len() {
278            let mut would_block = false;
279
280            match self.session.write(&buf[pos..]) {
281                Ok(n) => pos += n,
282                Err(err) => return Poll::Ready(Err(err)),
283            };
284
285            while self.session.wants_write() {
286                match self.write_io(cx) {
287                    Poll::Ready(Ok(0)) | Poll::Pending => {
288                        would_block = true;
289                        break;
290                    }
291                    Poll::Ready(Ok(_)) => (),
292                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
293                }
294            }
295
296            return match (pos, would_block) {
297                (0, true) => Poll::Pending,
298                (n, true) => Poll::Ready(Ok(n)),
299                (_, false) => continue,
300            };
301        }
302
303        Poll::Ready(Ok(pos))
304    }
305
306    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
307        self.session.flush()?;
308        while self.session.wants_write() {
309            ready!(self.write_io(cx))?;
310        }
311        Pin::new(&mut self.io).poll_flush(cx)
312    }
313
314    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
315        while self.session.wants_write() {
316            ready!(self.write_io(cx))?;
317        }
318        Pin::new(&mut self.io).poll_shutdown(cx)
319    }
320}
321
322#[cfg(test)]
323mod test_stream;