1use super::*;
2use crate::common::IoSession;
3use rustls::Session;
4
/// Server-side TLS stream: an underlying transport `io` paired with the
/// `ServerSession` that carries the TLS connection state for it, plus a
/// record of which halves of the stream have been shut down.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying raw transport that TLS records are read from / written to.
    pub(crate) io: IO,
    /// The server TLS session driven over `io`.
    pub(crate) session: ServerSession,
    /// Shutdown progress of the stream; the variants used in this file are
    /// `Stream`, `ReadShutdown`, `WriteShutdown` and `FullyShutdown`.
    pub(crate) state: TlsState,
}
13
14impl<IO> TlsStream<IO> {
15 #[inline]
16 pub fn get_ref(&self) -> (&IO, &ServerSession) {
17 (&self.io, &self.session)
18 }
19
20 #[inline]
21 pub fn get_mut(&mut self) -> (&mut IO, &mut ServerSession) {
22 (&mut self.io, &mut self.session)
23 }
24
25 #[inline]
26 pub fn into_inner(self) -> (IO, ServerSession) {
27 (self.io, self.session)
28 }
29}
30
31impl<IO> IoSession for TlsStream<IO> {
32 type Io = IO;
33 type Session = ServerSession;
34
35 #[inline]
36 fn skip_handshake(&self) -> bool {
37 false
38 }
39
40 #[inline]
41 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
42 (&mut self.state, &mut self.io, &mut self.session)
43 }
44
45 #[inline]
46 fn into_io(self) -> Self::Io {
47 self.io
48 }
49}
50
51impl<IO> AsyncRead for TlsStream<IO>
52where
53 IO: AsyncRead + AsyncWrite + Unpin,
54{
55 fn poll_read(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<io::Result<()>> {
60 let this = self.get_mut();
61 let mut stream =
62 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
63
64 match &this.state {
65 TlsState::Stream | TlsState::WriteShutdown => {
66 let prev = buf.remaining();
67
68 match stream.as_mut_pin().poll_read(cx, buf) {
69 Poll::Ready(Ok(())) => {
70 if prev == buf.remaining() {
71 this.state.shutdown_read();
72 }
73
74 Poll::Ready(Ok(()))
75 }
76 Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
77 this.state.shutdown_read();
78 Poll::Ready(Ok(()))
79 }
80 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
81 Poll::Pending => Poll::Pending,
82 }
83 }
84 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
85 #[cfg(feature = "early-data")]
86 s => unreachable!("server TLS can not hit this state: {:?}", s),
87 }
88 }
89}
90
91impl<IO> AsyncWrite for TlsStream<IO>
92where
93 IO: AsyncRead + AsyncWrite + Unpin,
94{
95 fn poll_write(
98 self: Pin<&mut Self>,
99 cx: &mut Context<'_>,
100 buf: &[u8],
101 ) -> Poll<io::Result<usize>> {
102 let this = self.get_mut();
103 let mut stream =
104 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
105 stream.as_mut_pin().poll_write(cx, buf)
106 }
107
108 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
109 let this = self.get_mut();
110 let mut stream =
111 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
112 stream.as_mut_pin().poll_flush(cx)
113 }
114
115 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
116 if self.state.writeable() {
117 self.session.send_close_notify();
118 self.state.shutdown_write();
119 }
120
121 let this = self.get_mut();
122 let mut stream =
123 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
124 stream.as_mut_pin().poll_shutdown(cx)
125 }
126}