1mod handshake;
2
3pub(crate) use handshake::{IoSession, MidHandshake};
4use rustls::Session;
5use std::io::{self, IoSlice, Read, Write};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
9
/// Tracks which halves of a TLS stream are still open.
#[derive(Debug)]
pub enum TlsState {
    /// Only compiled with the `early-data` feature. The code that builds this
    /// variant is not visible here; the fields are presumably an offset and the
    /// buffered early-data bytes — confirm before relying on that.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Neither half has been shut down.
    Stream,
    /// The read half is closed; writing is still permitted.
    ReadShutdown,
    /// The write half is closed; reading is still permitted.
    WriteShutdown,
    /// Both halves are closed.
    FullyShutdown,
}

impl TlsState {
    /// Closes the read half.
    ///
    /// If the write half was already closed, the state becomes
    /// `FullyShutdown`; otherwise it becomes `ReadShutdown`.
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Closes the write half.
    ///
    /// If the read half was already closed, the state becomes
    /// `FullyShutdown`; otherwise it becomes `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Returns `true` while the write half is still open.
    #[inline]
    pub fn writeable(&self) -> bool {
        let write_half_closed = matches!(self, TlsState::WriteShutdown | TlsState::FullyShutdown);
        !write_half_closed
    }

    /// Returns `true` while the read half is still open.
    #[inline]
    pub fn readable(&self) -> bool {
        let read_half_closed = matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown);
        !read_half_closed
    }

    /// Returns `true` when the state is `EarlyData`.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(*self, TlsState::EarlyData(..))
    }

    /// Without the `early-data` feature the `EarlyData` variant does not
    /// exist, so this is always `false`.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
59
/// Borrowed pairing of a transport `io` with a TLS `session`.
///
/// The `impl` blocks in this module (which require `IO: AsyncRead + AsyncWrite
/// + Unpin` and `S: rustls::Session`) use it to move TLS records between the
/// two and to expose `AsyncRead`/`AsyncWrite` over the decrypted data.
pub struct Stream<'a, IO, S> {
    /// Underlying transport that TLS records are read from and written to.
    pub io: &'a mut IO,
    /// TLS session that ciphertext is fed into and plaintext is read from.
    pub session: &'a mut S,
    /// Set to `true` by the read paths when `read_io` reports `0` bytes
    /// (treated as transport EOF); can also be preset via `set_eof`.
    pub eof: bool,
}
65
66impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
67 pub fn new(io: &'a mut IO, session: &'a mut S) -> Self {
68 Stream {
69 io,
70 session,
71 eof: false,
74 }
75 }
76
77 pub fn set_eof(mut self, eof: bool) -> Self {
78 self.eof = eof;
79 self
80 }
81
82 pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
83 Pin::new(self)
84 }
85
86 pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
87 struct Reader<'a, 'b, T> {
88 io: &'a mut T,
89 cx: &'a mut Context<'b>,
90 }
91
92 impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
93 #[inline]
94 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
95 let mut buf = ReadBuf::new(buf);
96 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
97 Poll::Ready(Ok(())) => Ok(buf.filled().len()),
98 Poll::Ready(Err(err)) => Err(err),
99 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
100 }
101 }
102 }
103
104 let mut reader = Reader { io: self.io, cx };
105
106 let n = match self.session.read_tls(&mut reader) {
107 Ok(n) => n,
108 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
109 Err(err) => return Poll::Ready(Err(err)),
110 };
111
112 self.session.process_new_packets().map_err(|err| {
113 let _ = self.write_io(cx);
117
118 io::Error::new(io::ErrorKind::InvalidData, err)
119 })?;
120
121 Poll::Ready(Ok(n))
122 }
123
124 pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
125 struct Writer<'a, 'b, T> {
126 io: &'a mut T,
127 cx: &'a mut Context<'b>,
128 }
129
130 impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
131 #[inline]
132 fn poll_with<U>(
133 &mut self,
134 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
135 ) -> io::Result<U> {
136 match f(Pin::new(&mut self.io), self.cx) {
137 Poll::Ready(result) => result,
138 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
139 }
140 }
141 }
142
143 impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
144 #[inline]
145 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
146 self.poll_with(|io, cx| io.poll_write(cx, buf))
147 }
148
149 #[inline]
150 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
151 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
152 }
153
154 fn flush(&mut self) -> io::Result<()> {
155 self.poll_with(|io, cx| io.poll_flush(cx))
156 }
157 }
158
159 let mut writer = Writer { io: self.io, cx };
160
161 match self.session.write_tls(&mut writer) {
162 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
163 result => Poll::Ready(result),
164 }
165 }
166
167 pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
168 let mut wrlen = 0;
169 let mut rdlen = 0;
170
171 loop {
172 let mut write_would_block = false;
173 let mut read_would_block = false;
174
175 while self.session.wants_write() {
176 match self.write_io(cx) {
177 Poll::Ready(Ok(n)) => wrlen += n,
178 Poll::Pending => {
179 write_would_block = true;
180 break;
181 }
182 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
183 }
184 }
185
186 while !self.eof && self.session.wants_read() {
187 match self.read_io(cx) {
188 Poll::Ready(Ok(0)) => self.eof = true,
189 Poll::Ready(Ok(n)) => rdlen += n,
190 Poll::Pending => {
191 read_would_block = true;
192 break;
193 }
194 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
195 }
196 }
197
198 return match (self.eof, self.session.is_handshaking()) {
199 (true, true) => {
200 let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
201 Poll::Ready(Err(err))
202 }
203 (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
204 (_, true) if write_would_block || read_would_block => {
205 if rdlen != 0 || wrlen != 0 {
206 Poll::Ready(Ok((rdlen, wrlen)))
207 } else {
208 Poll::Pending
209 }
210 }
211 (..) => continue,
212 };
213 }
214 }
215}
216
217impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
218 fn poll_read(
219 mut self: Pin<&mut Self>,
220 cx: &mut Context<'_>,
221 buf: &mut ReadBuf<'_>,
222 ) -> Poll<io::Result<()>> {
223 let prev = buf.remaining();
224
225 while buf.remaining() != 0 {
226 let mut would_block = false;
227
228 while self.session.wants_read() {
230 match self.read_io(cx) {
231 Poll::Ready(Ok(0)) => {
232 self.eof = true;
233 break;
234 }
235 Poll::Ready(Ok(_)) => (),
236 Poll::Pending => {
237 would_block = true;
238 break;
239 }
240 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
241 }
242 }
243
244 return match self.session.read(buf.initialize_unfilled()) {
245 Ok(0) if prev == buf.remaining() && would_block => Poll::Pending,
246 Ok(n) => {
247 buf.advance(n);
248
249 if self.eof || would_block {
250 break;
251 } else {
252 continue;
253 }
254 }
255 Err(ref err)
256 if err.kind() == io::ErrorKind::ConnectionAborted
257 && prev != buf.remaining() =>
258 {
259 break
260 }
261 Err(err) => Poll::Ready(Err(err)),
262 };
263 }
264
265 Poll::Ready(Ok(()))
266 }
267}
268
269impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'a, IO, S> {
270 fn poll_write(
271 mut self: Pin<&mut Self>,
272 cx: &mut Context,
273 buf: &[u8],
274 ) -> Poll<io::Result<usize>> {
275 let mut pos = 0;
276
277 while pos != buf.len() {
278 let mut would_block = false;
279
280 match self.session.write(&buf[pos..]) {
281 Ok(n) => pos += n,
282 Err(err) => return Poll::Ready(Err(err)),
283 };
284
285 while self.session.wants_write() {
286 match self.write_io(cx) {
287 Poll::Ready(Ok(0)) | Poll::Pending => {
288 would_block = true;
289 break;
290 }
291 Poll::Ready(Ok(_)) => (),
292 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
293 }
294 }
295
296 return match (pos, would_block) {
297 (0, true) => Poll::Pending,
298 (n, true) => Poll::Ready(Ok(n)),
299 (_, false) => continue,
300 };
301 }
302
303 Poll::Ready(Ok(pos))
304 }
305
306 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
307 self.session.flush()?;
308 while self.session.wants_write() {
309 ready!(self.write_io(cx))?;
310 }
311 Pin::new(&mut self.io).poll_flush(cx)
312 }
313
314 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
315 while self.session.wants_write() {
316 ready!(self.write_io(cx))?;
317 }
318 Pin::new(&mut self.io).poll_shutdown(cx)
319 }
320}
321
322#[cfg(test)]
323mod test_stream;