tonic/transport/service/
tls.rs1use super::io::BoxedIo;
2use crate::transport::{
3 server::{Connected, TlsStream},
4 Certificate, Identity,
5};
6#[cfg(feature = "tls-roots")]
7use rustls_native_certs;
8use std::{fmt, sync::Arc};
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "tls")]
11use tokio_rustls::{
12 rustls::{ClientConfig, NoClientAuth, ServerConfig, Session},
13 webpki::DNSNameRef,
14 TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
15};
16
/// ALPN protocol identifier for HTTP/2, advertised by the rustls configs this module builds.
#[cfg(feature = "tls")]
const ALPN_H2: &str = "h2";
20
/// Errors produced by this module while building TLS configs or completing a handshake.
#[derive(Debug)]
enum TlsError {
    /// The server did not select `h2` during ALPN negotiation (checked in `TlsConnector::connect`).
    H2NotNegotiated,
    /// A certificate supplied as PEM could not be loaded.
    #[cfg(feature = "tls")]
    CertificateParseError,
    /// No PKCS#8 or RSA private key could be found in the supplied PEM data.
    #[cfg(feature = "tls")]
    PrivateKeyParseError,
}
30
31#[derive(Clone)]
32pub(crate) struct TlsConnector {
33 config: Arc<ClientConfig>,
34 domain: Arc<String>,
35}
36
37impl TlsConnector {
38 #[cfg(feature = "tls")]
39 pub(crate) fn new_with_rustls_cert(
40 ca_cert: Option<Certificate>,
41 identity: Option<Identity>,
42 domain: String,
43 ) -> Result<Self, crate::Error> {
44 let mut config = ClientConfig::new();
45 config.set_protocols(&[Vec::from(ALPN_H2)]);
46
47 if let Some(identity) = identity {
48 let (client_cert, client_key) = rustls_keys::load_identity(identity)?;
49 config.set_single_client_cert(client_cert, client_key)?;
50 }
51
52 #[cfg(feature = "tls-roots")]
53 {
54 config.root_store = match rustls_native_certs::load_native_certs() {
55 Ok(store) | Err((Some(store), _)) => store,
56 Err((None, error)) => return Err(error.into()),
57 };
58 }
59
60 #[cfg(feature = "tls-webpki-roots")]
61 {
62 config
63 .root_store
64 .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
65 }
66
67 if let Some(cert) = ca_cert {
68 let mut buf = std::io::Cursor::new(&cert.pem[..]);
69 config.root_store.add_pem_file(&mut buf).unwrap();
70 }
71
72 Ok(Self {
73 config: Arc::new(config),
74 domain: Arc::new(domain),
75 })
76 }
77
78 #[cfg(feature = "tls")]
79 pub(crate) fn new_with_rustls_raw(
80 config: tokio_rustls::rustls::ClientConfig,
81 domain: String,
82 ) -> Result<Self, crate::Error> {
83 Ok(Self {
84 config: Arc::new(config),
85 domain: Arc::new(domain),
86 })
87 }
88
89 pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
90 where
91 I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
92 {
93 let tls_io = {
94 let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())?.to_owned();
95
96 let io = RustlsConnector::from(self.config.clone())
97 .connect(dns.as_ref(), io)
98 .await?;
99
100 let (_, session) = io.get_ref();
101
102 match session.get_alpn_protocol() {
103 Some(b) if b == b"h2" => (),
104 _ => return Err(TlsError::H2NotNegotiated.into()),
105 };
106
107 BoxedIo::new(io)
108 };
109
110 Ok(tls_io)
111 }
112}
113
114impl fmt::Debug for TlsConnector {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("TlsConnector").finish()
117 }
118}
119
/// Server-side TLS acceptor wrapping a shared rustls `ServerConfig`.
///
/// Cloning only bumps the `Arc` reference count.
#[derive(Clone)]
pub(crate) struct TlsAcceptor {
    // Handed to `tokio_rustls::TlsAcceptor` on every `accept`.
    inner: Arc<ServerConfig>,
}
124
125impl TlsAcceptor {
126 #[cfg(feature = "tls")]
127 pub(crate) fn new_with_rustls_identity(
128 identity: Identity,
129 client_ca_root: Option<Certificate>,
130 ) -> Result<Self, crate::Error> {
131 let (cert, key) = rustls_keys::load_identity(identity)?;
132
133 let mut config = match client_ca_root {
134 None => ServerConfig::new(NoClientAuth::new()),
135 Some(cert) => {
136 let mut cert = std::io::Cursor::new(&cert.pem[..]);
137
138 let mut client_root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
139 if client_root_cert_store.add_pem_file(&mut cert).is_err() {
140 return Err(Box::new(TlsError::CertificateParseError));
141 }
142
143 let client_auth =
144 tokio_rustls::rustls::AllowAnyAuthenticatedClient::new(client_root_cert_store);
145 ServerConfig::new(client_auth)
146 }
147 };
148 config.set_single_cert(cert, key)?;
149 config.set_protocols(&[Vec::from(ALPN_H2)]);
150
151 Ok(Self {
152 inner: Arc::new(config),
153 })
154 }
155
156 #[cfg(feature = "tls")]
157 pub(crate) fn new_with_rustls_raw(
158 config: tokio_rustls::rustls::ServerConfig,
159 ) -> Result<Self, crate::Error> {
160 Ok(Self {
161 inner: Arc::new(config),
162 })
163 }
164
165 pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
166 where
167 IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
168 {
169 let acceptor = RustlsAcceptor::from(self.inner.clone());
170 acceptor.accept(io).await.map_err(Into::into)
171 }
172}
173
174impl fmt::Debug for TlsAcceptor {
175 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176 f.debug_struct("TlsAcceptor").finish()
177 }
178}
179
180impl fmt::Display for TlsError {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 match self {
183 TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
184 TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
185 TlsError::PrivateKeyParseError => write!(
186 f,
187 "Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
188 ),
189 }
190 }
191}
192
193impl std::error::Error for TlsError {}
194
/// Helpers that turn PEM-encoded identities into the types rustls consumes.
#[cfg(feature = "tls")]
mod rustls_keys {
    use tokio_rustls::rustls::{internal::pemfile, Certificate, PrivateKey};

    use crate::transport::service::tls::TlsError;
    use crate::transport::Identity;

    /// Returns the first private key found in the PEM data behind `cursor`.
    ///
    /// PKCS#8 keys are tried first. If none are found, the input is rescanned for RSA
    /// (PKCS#1) keys.
    ///
    /// # Errors
    ///
    /// Returns `TlsError::PrivateKeyParseError` if neither scan yields a key.
    fn load_rustls_private_key(
        mut cursor: std::io::Cursor<&[u8]>,
    ) -> Result<PrivateKey, crate::Error> {
        if let Some(key) = pemfile::pkcs8_private_keys(&mut cursor)
            .ok()
            .and_then(|keys| keys.into_iter().next())
        {
            return Ok(key);
        }

        // The PKCS#8 scan advanced the cursor; rewind before trying the RSA parser.
        cursor.set_position(0);
        if let Some(key) = pemfile::rsa_private_keys(&mut cursor)
            .ok()
            .and_then(|keys| keys.into_iter().next())
        {
            return Ok(key);
        }

        Err(Box::new(TlsError::PrivateKeyParseError))
    }

    /// Splits an `Identity` into the certificate chain and private key rustls expects.
    ///
    /// # Errors
    ///
    /// Returns:
    /// - `TlsError::CertificateParseError` if the certificate PEM cannot be decoded or contains
    ///   no certificates;
    /// - `TlsError::PrivateKeyParseError` if no usable private key is found.
    pub(crate) fn load_identity(
        identity: Identity,
    ) -> Result<(Vec<Certificate>, PrivateKey), crate::Error> {
        let cert = {
            let mut cert = std::io::Cursor::new(&identity.cert.pem[..]);
            match pemfile::certs(&mut cert) {
                // `pemfile::certs` returns an empty list (not an error) when the input has no
                // CERTIFICATE blocks. An empty chain almost certainly means a misconfigured
                // identity, so reject it here instead of handing it to rustls.
                Ok(certs) if !certs.is_empty() => certs,
                _ => return Err(Box::new(TlsError::CertificateParseError)),
            }
        };

        let key = load_rustls_private_key(std::io::Cursor::new(&identity.key[..]))?;

        Ok((cert, key))
    }
}