1use super::*;
2use crate::common::IoSession;
3use rustls::Session;
4
/// A client-side TLS stream: the underlying transport `io` paired with the
/// rustls [`ClientSession`] that runs the TLS protocol over it.
#[derive(Debug)]
pub struct TlsStream<IO> {
    /// The underlying (plaintext) transport the TLS records are sent over.
    pub(crate) io: IO,
    /// The rustls client session driving the handshake and record layer.
    pub(crate) session: ClientSession,
    /// Shutdown/early-data state of this stream (`Stream`, `ReadShutdown`,
    /// `WriteShutdown`, `FullyShutdown`, and `EarlyData` when the
    /// `early-data` feature is enabled).
    pub(crate) state: TlsState,
}
13
14impl<IO> TlsStream<IO> {
15 #[inline]
16 pub fn get_ref(&self) -> (&IO, &ClientSession) {
17 (&self.io, &self.session)
18 }
19
20 #[inline]
21 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientSession) {
22 (&mut self.io, &mut self.session)
23 }
24
25 #[inline]
26 pub fn into_inner(self) -> (IO, ClientSession) {
27 (self.io, self.session)
28 }
29}
30
31impl<IO> IoSession for TlsStream<IO> {
32 type Io = IO;
33 type Session = ClientSession;
34
35 #[inline]
36 fn skip_handshake(&self) -> bool {
37 self.state.is_early_data()
38 }
39
40 #[inline]
41 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
42 (&mut self.state, &mut self.io, &mut self.session)
43 }
44
45 #[inline]
46 fn into_io(self) -> Self::Io {
47 self.io
48 }
49}
50
51impl<IO> AsyncRead for TlsStream<IO>
52where
53 IO: AsyncRead + AsyncWrite + Unpin,
54{
55 fn poll_read(
56 self: Pin<&mut Self>,
57 cx: &mut Context<'_>,
58 buf: &mut ReadBuf<'_>,
59 ) -> Poll<io::Result<()>> {
60 match self.state {
61 #[cfg(feature = "early-data")]
62 TlsState::EarlyData(..) => Poll::Pending,
63 TlsState::Stream | TlsState::WriteShutdown => {
64 let this = self.get_mut();
65 let mut stream =
66 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
67 let prev = buf.remaining();
68
69 match stream.as_mut_pin().poll_read(cx, buf) {
70 Poll::Ready(Ok(())) => {
71 if prev == buf.remaining() {
72 this.state.shutdown_read();
73 }
74
75 Poll::Ready(Ok(()))
76 }
77 Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
78 this.state.shutdown_read();
79 Poll::Ready(Ok(()))
80 }
81 output => output,
82 }
83 }
84 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
85 }
86 }
87}
88
89impl<IO> AsyncWrite for TlsStream<IO>
90where
91 IO: AsyncRead + AsyncWrite + Unpin,
92{
93 fn poll_write(
96 self: Pin<&mut Self>,
97 cx: &mut Context<'_>,
98 buf: &[u8],
99 ) -> Poll<io::Result<usize>> {
100 let this = self.get_mut();
101 let mut stream =
102 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
103
104 #[allow(clippy::match_single_binding)]
105 match this.state {
106 #[cfg(feature = "early-data")]
107 TlsState::EarlyData(ref mut pos, ref mut data) => {
108 use std::io::Write;
109
110 if let Some(mut early_data) = stream.session.early_data() {
112 let len = match early_data.write(buf) {
113 Ok(n) => n,
114 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
115 return Poll::Pending
116 }
117 Err(err) => return Poll::Ready(Err(err)),
118 };
119 if len != 0 {
120 data.extend_from_slice(&buf[..len]);
121 return Poll::Ready(Ok(len));
122 }
123 }
124
125 while stream.session.is_handshaking() {
127 ready!(stream.handshake(cx))?;
128 }
129
130 if !stream.session.is_early_data_accepted() {
132 while *pos < data.len() {
133 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
134 *pos += len;
135 }
136 }
137
138 this.state = TlsState::Stream;
140 stream.as_mut_pin().poll_write(cx, buf)
141 }
142 _ => stream.as_mut_pin().poll_write(cx, buf),
143 }
144 }
145
146 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147 let this = self.get_mut();
148 let mut stream =
149 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
150
151 #[cfg(feature = "early-data")]
152 {
153 if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
154 while stream.session.is_handshaking() {
156 ready!(stream.handshake(cx))?;
157 }
158
159 if !stream.session.is_early_data_accepted() {
161 while *pos < data.len() {
162 let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
163 *pos += len;
164 }
165 }
166
167 this.state = TlsState::Stream;
168 }
169 }
170
171 stream.as_mut_pin().poll_flush(cx)
172 }
173
174 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
175 if self.state.writeable() {
176 self.session.send_close_notify();
177 self.state.shutdown_write();
178 }
179
180 #[cfg(feature = "early-data")]
181 {
182 if let TlsState::EarlyData(..) = self.state {
184 return Pin::new(&mut self.io).poll_shutdown(cx);
185 }
186 }
187
188 let this = self.get_mut();
189 let mut stream =
190 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
191 stream.as_mut_pin().poll_shutdown(cx)
192 }
193}