tokio_rustls/
client.rs

1use super::*;
2use crate::common::IoSession;
3use rustls::Session;
4
5/// A wrapper around an underlying raw stream which implements the TLS or SSL
6/// protocol.
7#[derive(Debug)]
8pub struct TlsStream<IO> {
9    pub(crate) io: IO,
10    pub(crate) session: ClientSession,
11    pub(crate) state: TlsState,
12}
13
14impl<IO> TlsStream<IO> {
15    #[inline]
16    pub fn get_ref(&self) -> (&IO, &ClientSession) {
17        (&self.io, &self.session)
18    }
19
20    #[inline]
21    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
22        (&mut self.io, &mut self.session)
23    }
24
25    #[inline]
26    pub fn into_inner(self) -> (IO, ClientSession) {
27        (self.io, self.session)
28    }
29}
30
31impl<IO> IoSession for TlsStream<IO> {
32    type Io = IO;
33    type Session = ClientSession;
34
35    #[inline]
36    fn skip_handshake(&self) -> bool {
37        self.state.is_early_data()
38    }
39
40    #[inline]
41    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
42        (&mut self.state, &mut self.io, &mut self.session)
43    }
44
45    #[inline]
46    fn into_io(self) -> Self::Io {
47        self.io
48    }
49}
50
51impl<IO> AsyncRead for TlsStream<IO>
52where
53    IO: AsyncRead + AsyncWrite + Unpin,
54{
55    fn poll_read(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &mut ReadBuf<'_>,
59    ) -> Poll<io::Result<()>> {
60        match self.state {
61            #[cfg(feature = "early-data")]
62            TlsState::EarlyData(..) => Poll::Pending,
63            TlsState::Stream | TlsState::WriteShutdown => {
64                let this = self.get_mut();
65                let mut stream =
66                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
67                let prev = buf.remaining();
68
69                match stream.as_mut_pin().poll_read(cx, buf) {
70                    Poll::Ready(Ok(())) => {
71                        if prev == buf.remaining() {
72                            this.state.shutdown_read();
73                        }
74
75                        Poll::Ready(Ok(()))
76                    }
77                    Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
78                        this.state.shutdown_read();
79                        Poll::Ready(Ok(()))
80                    }
81                    output => output,
82                }
83            }
84            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
85        }
86    }
87}
88
89impl<IO> AsyncWrite for TlsStream<IO>
90where
91    IO: AsyncRead + AsyncWrite + Unpin,
92{
93    /// Note: that it does not guarantee the final data to be sent.
94    /// To be cautious, you must manually call `flush`.
95    fn poll_write(
96        self: Pin<&mut Self>,
97        cx: &mut Context<'_>,
98        buf: &[u8],
99    ) -> Poll<io::Result<usize>> {
100        let this = self.get_mut();
101        let mut stream =
102            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
103
104        #[allow(clippy::match_single_binding)]
105        match this.state {
106            #[cfg(feature = "early-data")]
107            TlsState::EarlyData(ref mut pos, ref mut data) => {
108                use std::io::Write;
109
110                // write early data
111                if let Some(mut early_data) = stream.session.early_data() {
112                    let len = match early_data.write(buf) {
113                        Ok(n) => n,
114                        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
115                            return Poll::Pending
116                        }
117                        Err(err) => return Poll::Ready(Err(err)),
118                    };
119                    if len != 0 {
120                        data.extend_from_slice(&buf[..len]);
121                        return Poll::Ready(Ok(len));
122                    }
123                }
124
125                // complete handshake
126                while stream.session.is_handshaking() {
127                    ready!(stream.handshake(cx))?;
128                }
129
130                // write early data (fallback)
131                if !stream.session.is_early_data_accepted() {
132                    while *pos < data.len() {
133                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
134                        *pos += len;
135                    }
136                }
137
138                // end
139                this.state = TlsState::Stream;
140                stream.as_mut_pin().poll_write(cx, buf)
141            }
142            _ => stream.as_mut_pin().poll_write(cx, buf),
143        }
144    }
145
146    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147        let this = self.get_mut();
148        let mut stream =
149            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
150
151        #[cfg(feature = "early-data")]
152        {
153            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
154                // complete handshake
155                while stream.session.is_handshaking() {
156                    ready!(stream.handshake(cx))?;
157                }
158
159                // write early data (fallback)
160                if !stream.session.is_early_data_accepted() {
161                    while *pos < data.len() {
162                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
163                        *pos += len;
164                    }
165                }
166
167                this.state = TlsState::Stream;
168            }
169        }
170
171        stream.as_mut_pin().poll_flush(cx)
172    }
173
174    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175        if self.state.writeable() {
176            self.session.send_close_notify();
177            self.state.shutdown_write();
178        }
179
180        #[cfg(feature = "early-data")]
181        {
182            // we skip the handshake
183            if let TlsState::EarlyData(..) = self.state {
184                return Pin::new(&mut self.io).poll_shutdown(cx);
185            }
186        }
187
188        let this = self.get_mut();
189        let mut stream =
190            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
191        stream.as_mut_pin().poll_shutdown(cx)
192    }
193}