/// Extracts the value from a `Poll::Ready`, or makes the enclosing function
/// return `Poll::Pending` immediately when the polled operation is not done.
///
/// It is declared before the `mod` items below, so by `macro_rules!` textual
/// scoping it is visible inside the submodules declared after it.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            ::std::task::Poll::Pending => return ::std::task::Poll::Pending,
            ::std::task::Poll::Ready(value) => value,
        }
    };
}
11
12pub mod client;
13mod common;
14pub mod server;
15
16use common::{MidHandshake, Stream, TlsState};
17use rustls::{ClientConfig, ClientSession, ServerConfig, ServerSession, Session};
18use std::future::Future;
19use std::io;
20use std::pin::Pin;
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
24use webpki::DNSNameRef;
25
26pub use rustls;
27pub use webpki;
28
29#[derive(Clone)]
31pub struct TlsConnector {
32 inner: Arc<ClientConfig>,
33 #[cfg(feature = "early-data")]
34 early_data: bool,
35}
36
37#[derive(Clone)]
39pub struct TlsAcceptor {
40 inner: Arc<ServerConfig>,
41}
42
43impl From<Arc<ClientConfig>> for TlsConnector {
44 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
45 TlsConnector {
46 inner,
47 #[cfg(feature = "early-data")]
48 early_data: false,
49 }
50 }
51}
52
53impl From<Arc<ServerConfig>> for TlsAcceptor {
54 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
55 TlsAcceptor { inner }
56 }
57}
58
59impl TlsConnector {
60 #[cfg(feature = "early-data")]
65 pub fn early_data(mut self, flag: bool) -> TlsConnector {
66 self.early_data = flag;
67 self
68 }
69
70 #[inline]
71 pub fn connect<IO>(&self, domain: DNSNameRef, stream: IO) -> Connect<IO>
72 where
73 IO: AsyncRead + AsyncWrite + Unpin,
74 {
75 self.connect_with(domain, stream, |_| ())
76 }
77
78 pub fn connect_with<IO, F>(&self, domain: DNSNameRef, stream: IO, f: F) -> Connect<IO>
79 where
80 IO: AsyncRead + AsyncWrite + Unpin,
81 F: FnOnce(&mut ClientSession),
82 {
83 let mut session = ClientSession::new(&self.inner, domain);
84 f(&mut session);
85
86 Connect(MidHandshake::Handshaking(client::TlsStream {
87 io: stream,
88
89 #[cfg(not(feature = "early-data"))]
90 state: TlsState::Stream,
91
92 #[cfg(feature = "early-data")]
93 state: if self.early_data && session.early_data().is_some() {
94 TlsState::EarlyData(0, Vec::new())
95 } else {
96 TlsState::Stream
97 },
98
99 session,
100 }))
101 }
102}
103
104impl TlsAcceptor {
105 #[inline]
106 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
107 where
108 IO: AsyncRead + AsyncWrite + Unpin,
109 {
110 self.accept_with(stream, |_| ())
111 }
112
113 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
114 where
115 IO: AsyncRead + AsyncWrite + Unpin,
116 F: FnOnce(&mut ServerSession),
117 {
118 let mut session = ServerSession::new(&self.inner);
119 f(&mut session);
120
121 Accept(MidHandshake::Handshaking(server::TlsStream {
122 session,
123 io: stream,
124 state: TlsState::Stream,
125 }))
126 }
127}
128
129pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
132
133pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
136
137pub struct FailableConnect<IO>(MidHandshake<client::TlsStream<IO>>);
139
140pub struct FailableAccept<IO>(MidHandshake<server::TlsStream<IO>>);
142
143impl<IO> Connect<IO> {
144 #[inline]
145 pub fn into_failable(self) -> FailableConnect<IO> {
146 FailableConnect(self.0)
147 }
148}
149
150impl<IO> Accept<IO> {
151 #[inline]
152 pub fn into_failable(self) -> FailableAccept<IO> {
153 FailableAccept(self.0)
154 }
155}
156
157impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
158 type Output = io::Result<client::TlsStream<IO>>;
159
160 #[inline]
161 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
162 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
163 }
164}
165
166impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
167 type Output = io::Result<server::TlsStream<IO>>;
168
169 #[inline]
170 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
171 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
172 }
173}
174
175impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableConnect<IO> {
176 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
177
178 #[inline]
179 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
180 Pin::new(&mut self.0).poll(cx)
181 }
182}
183
184impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FailableAccept<IO> {
185 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
186
187 #[inline]
188 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
189 Pin::new(&mut self.0).poll(cx)
190 }
191}
192
193#[derive(Debug)]
198pub enum TlsStream<T> {
199 Client(client::TlsStream<T>),
200 Server(server::TlsStream<T>),
201}
202
203impl<T> TlsStream<T> {
204 pub fn get_ref(&self) -> (&T, &dyn Session) {
205 use TlsStream::*;
206 match self {
207 Client(io) => {
208 let (io, session) = io.get_ref();
209 (io, &*session)
210 }
211 Server(io) => {
212 let (io, session) = io.get_ref();
213 (io, &*session)
214 }
215 }
216 }
217
218 pub fn get_mut(&mut self) -> (&mut T, &mut dyn Session) {
219 use TlsStream::*;
220 match self {
221 Client(io) => {
222 let (io, session) = io.get_mut();
223 (io, &mut *session)
224 }
225 Server(io) => {
226 let (io, session) = io.get_mut();
227 (io, &mut *session)
228 }
229 }
230 }
231}
232
233impl<T> From<client::TlsStream<T>> for TlsStream<T> {
234 fn from(s: client::TlsStream<T>) -> Self {
235 Self::Client(s)
236 }
237}
238
239impl<T> From<server::TlsStream<T>> for TlsStream<T> {
240 fn from(s: server::TlsStream<T>) -> Self {
241 Self::Server(s)
242 }
243}
244
245impl<T> AsyncRead for TlsStream<T>
246where
247 T: AsyncRead + AsyncWrite + Unpin,
248{
249 #[inline]
250 fn poll_read(
251 self: Pin<&mut Self>,
252 cx: &mut Context<'_>,
253 buf: &mut ReadBuf<'_>,
254 ) -> Poll<io::Result<()>> {
255 match self.get_mut() {
256 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
257 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
258 }
259 }
260}
261
262impl<T> AsyncWrite for TlsStream<T>
263where
264 T: AsyncRead + AsyncWrite + Unpin,
265{
266 #[inline]
267 fn poll_write(
268 self: Pin<&mut Self>,
269 cx: &mut Context<'_>,
270 buf: &[u8],
271 ) -> Poll<io::Result<usize>> {
272 match self.get_mut() {
273 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
274 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
275 }
276 }
277
278 #[inline]
279 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
280 match self.get_mut() {
281 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
282 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
283 }
284 }
285
286 #[inline]
287 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
288 match self.get_mut() {
289 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
290 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
291 }
292 }
293}