tokio_rustls/
server.rs

1use super::*;
2use crate::common::IoSession;
3use rustls::Session;
4
5/// A wrapper around an underlying raw stream which implements the TLS or SSL
6/// protocol.
7#[derive(Debug)]
8pub struct TlsStream<IO> {
9    pub(crate) io: IO,
10    pub(crate) session: ServerSession,
11    pub(crate) state: TlsState,
12}
13
14impl<IO> TlsStream<IO> {
15    #[inline]
16    pub fn get_ref(&self) -> (&IO, &ServerSession) {
17        (&self.io, &self.session)
18    }
19
20    #[inline]
21    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) {
22        (&mut self.io, &mut self.session)
23    }
24
25    #[inline]
26    pub fn into_inner(self) -> (IO, ServerSession) {
27        (self.io, self.session)
28    }
29}
30
31impl<IO> IoSession for TlsStream<IO> {
32    type Io = IO;
33    type Session = ServerSession;
34
35    #[inline]
36    fn skip_handshake(&self) -> bool {
37        false
38    }
39
40    #[inline]
41    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
42        (&mut self.state, &mut self.io, &mut self.session)
43    }
44
45    #[inline]
46    fn into_io(self) -> Self::Io {
47        self.io
48    }
49}
50
51impl<IO> AsyncRead for TlsStream<IO>
52where
53    IO: AsyncRead + AsyncWrite + Unpin,
54{
55    fn poll_read(
56        self: Pin<&mut Self>,
57        cx: &mut Context<'_>,
58        buf: &mut ReadBuf<'_>,
59    ) -> Poll<io::Result<()>> {
60        let this = self.get_mut();
61        let mut stream =
62            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
63
64        match &this.state {
65            TlsState::Stream | TlsState::WriteShutdown => {
66                let prev = buf.remaining();
67
68                match stream.as_mut_pin().poll_read(cx, buf) {
69                    Poll::Ready(Ok(())) => {
70                        if prev == buf.remaining() {
71                            this.state.shutdown_read();
72                        }
73
74                        Poll::Ready(Ok(()))
75                    }
76                    Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
77                        this.state.shutdown_read();
78                        Poll::Ready(Ok(()))
79                    }
80                    Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
81                    Poll::Pending => Poll::Pending,
82                }
83            }
84            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
85            #[cfg(feature = "early-data")]
86            s => unreachable!("server TLS can not hit this state: {:?}", s),
87        }
88    }
89}
90
91impl<IO> AsyncWrite for TlsStream<IO>
92where
93    IO: AsyncRead + AsyncWrite + Unpin,
94{
95    /// Note: that it does not guarantee the final data to be sent.
96    /// To be cautious, you must manually call `flush`.
97    fn poll_write(
98        self: Pin<&mut Self>,
99        cx: &mut Context<'_>,
100        buf: &[u8],
101    ) -> Poll<io::Result<usize>> {
102        let this = self.get_mut();
103        let mut stream =
104            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
105        stream.as_mut_pin().poll_write(cx, buf)
106    }
107
108    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
109        let this = self.get_mut();
110        let mut stream =
111            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
112        stream.as_mut_pin().poll_flush(cx)
113    }
114
115    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
116        if self.state.writeable() {
117            self.session.send_close_notify();
118            self.state.shutdown_write();
119        }
120
121        let this = self.get_mut();
122        let mut stream =
123            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
124        stream.as_mut_pin().poll_shutdown(cx)
125    }
126}