tokio_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/ctz/rustls).
2
/// Unwraps a `Poll::Ready(value)` into `value`, or makes the enclosing
/// function return `Poll::Pending` immediately.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
11
12pub mod client;
13mod common;
14pub mod server;
15
16use common::{MidHandshake, Stream, TlsState};
17use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
18use std::future::Future;
19use std::io;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24use webpki::DNSNameRef;
25
26pub use rustls;
27pub use webpki;
28
29/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
30#[derive(Clone)]
31pub struct TlsConnector {
32    inner: Arc<ClientConfig>,
33    #[cfg(feature = "early-data")]
34    early_data: bool,
35}
36
37/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
38#[derive(Clone)]
39pub struct TlsAcceptor {
40    inner: Arc<ServerConfig>,
41}
42
43impl From<Arc<ClientConfig>> for TlsConnector {
44    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
45        TlsConnector {
46            inner,
47            #[cfg(feature = "early-data")]
48            early_data: false,
49        }
50    }
51}
52
53impl From<Arc<ServerConfig>> for TlsAcceptor {
54    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
55        TlsAcceptor { inner }
56    }
57}
58
59impl TlsConnector {
60    /// Enable 0-RTT.
61    ///
62    /// If you want to use 0-RTT,
63    /// You must also set `ClientConfig.enable_early_data` to `true`.
64    #[cfg(feature = "early-data")]
65    pub fn early_data(mut self, flag: bool) -> TlsConnector {
66        self.early_data = flag;
67        self
68    }
69
70    #[inline]
71    pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
72    where
73        IO: AsyncRead + AsyncWrite + Unpin,
74    {
75        self.connect_with(domain, stream, |_| ())
76    }
77
78    pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
79    where
80        IO: AsyncRead + AsyncWrite + Unpin,
81        F: FnOnce(&mut ClientSession),
82    {
83        let mut session = ClientSession::new(&self.inner, domain);
84        f(&mut session);
85
86        Connect(MidHandshake::Handshaking(client::TlsStream {
87            io: stream,
88
89            #[cfg(not(feature = "early-data"))]
90            state: TlsState::Stream,
91
92            #[cfg(feature = "early-data")]
93            state: if self.early_data && session.early_data().is_some() {
94                TlsState::EarlyData(0, Vec::new())
95            } else {
96                TlsState::Stream
97            },
98
99            session,
100        }))
101    }
102}
103
104impl TlsAcceptor {
105    #[inline]
106    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
107    where
108        IO: AsyncRead + AsyncWrite + Unpin,
109    {
110        self.accept_with(stream, |_| ())
111    }
112
113    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
114    where
115        IO: AsyncRead + AsyncWrite + Unpin,
116        F: FnOnce(&mut ServerSession),
117    {
118        let mut session = ServerSession::new(&self.inner);
119        f(&mut session);
120
121        Accept(MidHandshake::Handshaking(server::TlsStream {
122            session,
123            io: stream,
124            state: TlsState::Stream,
125        }))
126    }
127}
128
129/// Future returned from `TlsConnector::connect` which will resolve
130/// once the connection handshake has finished.
131pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
132
133/// Future returned from `TlsAcceptor::accept` which will resolve
134/// once the accept handshake has finished.
135pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
136
137/// Like [Connect], but returns `IO` on failure.
138pub struct FailableConnect<IO>(MidHandshake<client::TlsStream<IO>>);
139
140/// Like [Accept], but returns `IO` on failure.
141pub struct FailableAccept<IO>(MidHandshake<server::TlsStream<IO>>);
142
143impl<IO> Connect<IO> {
144    #[inline]
145    pub fn into_failable(self) -> FailableConnect<IO> {
146        FailableConnect(self.0)
147    }
148}
149
150impl<IO> Accept<IO> {
151    #[inline]
152    pub fn into_failable(self) -> FailableAccept<IO> {
153        FailableAccept(self.0)
154    }
155}
156
157impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
158    type Output = io::Result<client::TlsStream<IO>>;
159
160    #[inline]
161    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
163    }
164}
165
166impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
167    type Output = io::Result<server::TlsStream<IO>>;
168
169    #[inline]
170    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
171        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
172    }
173}
174
175impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableConnect<IO> {
176    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
177
178    #[inline]
179    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180        Pin::new(&mut self.0).poll(cx)
181    }
182}
183
184impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableAccept<IO> {
185    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
186
187    #[inline]
188    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189        Pin::new(&mut self.0).poll(cx)
190    }
191}
192
193/// Unified TLS stream type
194///
195/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
196/// a single type to keep both client- and server-initiated TLS-encrypted connections.
197#[derive(Debug)]
198pub enum TlsStream<T> {
199    Client(client::TlsStream<T>),
200    Server(server::TlsStream<T>),
201}
202
203impl<T> TlsStream<T> {
204    pub fn get_ref(&self) -> (&T, &dyn Session) {
205        use TlsStream::*;
206        match self {
207            Client(io) => {
208                let (io, session) = io.get_ref();
209                (io, &*session)
210            }
211            Server(io) => {
212                let (io, session) = io.get_ref();
213                (io, &*session)
214            }
215        }
216    }
217
218    pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) {
219        use TlsStream::*;
220        match self {
221            Client(io) => {
222                let (io, session) = io.get_mut();
223                (io, &mut *session)
224            }
225            Server(io) => {
226                let (io, session) = io.get_mut();
227                (io, &mut *session)
228            }
229        }
230    }
231}
232
233impl<T> From<client::TlsStream<T>> for TlsStream<T> {
234    fn from(s: client::TlsStream<T>) -> Self {
235        Self::Client(s)
236    }
237}
238
239impl<T> From<server::TlsStream<T>> for TlsStream<T> {
240    fn from(s: server::TlsStream<T>) -> Self {
241        Self::Server(s)
242    }
243}
244
245impl<T> AsyncRead for TlsStream<T>
246where
247    T: AsyncRead + AsyncWrite + Unpin,
248{
249    #[inline]
250    fn poll_read(
251        self: Pin<&mut Self>,
252        cx: &mut Context<'_>,
253        buf: &mut ReadBuf<'_>,
254    ) -> Poll<io::Result<()>> {
255        match self.get_mut() {
256            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
257            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
258        }
259    }
260}
261
262impl<T> AsyncWrite for TlsStream<T>
263where
264    T: AsyncRead + AsyncWrite + Unpin,
265{
266    #[inline]
267    fn poll_write(
268        self: Pin<&mut Self>,
269        cx: &mut Context<'_>,
270        buf: &[u8],
271    ) -> Poll<io::Result<usize>> {
272        match self.get_mut() {
273            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
274            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
275        }
276    }
277
278    #[inline]
279    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
280        match self.get_mut() {
281            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
282            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
283        }
284    }
285
286    #[inline]
287    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
288        match self.get_mut() {
289            TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
290            TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
291        }
292    }
293}