tonic/transport/service/
tls.rs

1use super::io::BoxedIo;
2use crate::transport::{
3    server::{Connected, TlsStream},
4    Certificate, Identity,
5};
6#[cfg(feature = "tls-roots")]
7use rustls_native_certs;
8use std::{fmt, sync::Arc};
9use tokio::io::{AsyncRead, AsyncWrite};
10#[cfg(feature = "tls")]
11use tokio_rustls::{
12    rustls::{ClientConfig, NoClientAuth, ServerConfig, Session},
13    webpki::DNSNameRef,
14    TlsAcceptor as RustlsAcceptor, TlsConnector as RustlsConnector,
15};
16
/// ALPN protocol identifier for HTTP/2.
///
/// Kept as a plain `&str` so it can be converted into the byte vector passed to
/// rustls' `set_protocols`.
#[cfg(feature = "tls")]
const ALPN_H2: &str = "h2";
20
/// Errors produced while building a TLS configuration or establishing a TLS session.
#[derive(Debug)]
enum TlsError {
    /// The handshake completed but `h2` was not the ALPN protocol selected by the peer.
    #[allow(dead_code)]
    H2NotNegotiated,
    /// A PEM-encoded certificate (identity chain or CA root) could not be parsed.
    #[cfg(feature = "tls")]
    CertificateParseError,
    /// No PKCS8 or RSA private key could be extracted from the supplied PEM data.
    #[cfg(feature = "tls")]
    PrivateKeyParseError,
}
30
31#[derive(Clone)]
32pub(crate) struct TlsConnector {
33    config: Arc<ClientConfig>,
34    domain: Arc<String>,
35}
36
37impl TlsConnector {
38    #[cfg(feature = "tls")]
39    pub(crate) fn new_with_rustls_cert(
40        ca_cert: Option<Certificate>,
41        identity: Option<Identity>,
42        domain: String,
43    ) -> Result<Self, crate::Error> {
44        let mut config = ClientConfig::new();
45        config.set_protocols(&[Vec::from(ALPN_H2)]);
46
47        if let Some(identity) = identity {
48            let (client_cert, client_key) = rustls_keys::load_identity(identity)?;
49            config.set_single_client_cert(client_cert, client_key)?;
50        }
51
52        #[cfg(feature = "tls-roots")]
53        {
54            config.root_store = match rustls_native_certs::load_native_certs() {
55                Ok(store) | Err((Some(store), _)) => store,
56                Err((None, error)) => return Err(error.into()),
57            };
58        }
59
60        #[cfg(feature = "tls-webpki-roots")]
61        {
62            config
63                .root_store
64                .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);
65        }
66
67        if let Some(cert) = ca_cert {
68            let mut buf = std::io::Cursor::new(&cert.pem[..]);
69            config.root_store.add_pem_file(&mut buf).unwrap();
70        }
71
72        Ok(Self {
73            config: Arc::new(config),
74            domain: Arc::new(domain),
75        })
76    }
77
78    #[cfg(feature = "tls")]
79    pub(crate) fn new_with_rustls_raw(
80        config: tokio_rustls::rustls::ClientConfig,
81        domain: String,
82    ) -> Result<Self, crate::Error> {
83        Ok(Self {
84            config: Arc::new(config),
85            domain: Arc::new(domain),
86        })
87    }
88
89    pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
90    where
91        I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
92    {
93        let tls_io = {
94            let dns = DNSNameRef::try_from_ascii_str(self.domain.as_str())?.to_owned();
95
96            let io = RustlsConnector::from(self.config.clone())
97                .connect(dns.as_ref(), io)
98                .await?;
99
100            let (_, session) = io.get_ref();
101
102            match session.get_alpn_protocol() {
103                Some(b) if b == b"h2" => (),
104                _ => return Err(TlsError::H2NotNegotiated.into()),
105            };
106
107            BoxedIo::new(io)
108        };
109
110        Ok(tls_io)
111    }
112}
113
114impl fmt::Debug for TlsConnector {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("TlsConnector").finish()
117    }
118}
119
120#[derive(Clone)]
121pub(crate) struct TlsAcceptor {
122    inner: Arc<ServerConfig>,
123}
124
125impl TlsAcceptor {
126    #[cfg(feature = "tls")]
127    pub(crate) fn new_with_rustls_identity(
128        identity: Identity,
129        client_ca_root: Option<Certificate>,
130    ) -> Result<Self, crate::Error> {
131        let (cert, key) = rustls_keys::load_identity(identity)?;
132
133        let mut config = match client_ca_root {
134            None => ServerConfig::new(NoClientAuth::new()),
135            Some(cert) => {
136                let mut cert = std::io::Cursor::new(&cert.pem[..]);
137
138                let mut client_root_cert_store = tokio_rustls::rustls::RootCertStore::empty();
139                if client_root_cert_store.add_pem_file(&mut cert).is_err() {
140                    return Err(Box::new(TlsError::CertificateParseError));
141                }
142
143                let client_auth =
144                    tokio_rustls::rustls::AllowAnyAuthenticatedClient::new(client_root_cert_store);
145                ServerConfig::new(client_auth)
146            }
147        };
148        config.set_single_cert(cert, key)?;
149        config.set_protocols(&[Vec::from(ALPN_H2)]);
150
151        Ok(Self {
152            inner: Arc::new(config),
153        })
154    }
155
156    #[cfg(feature = "tls")]
157    pub(crate) fn new_with_rustls_raw(
158        config: tokio_rustls::rustls::ServerConfig,
159    ) -> Result<Self, crate::Error> {
160        Ok(Self {
161            inner: Arc::new(config),
162        })
163    }
164
165    pub(crate) async fn accept<IO>(&self, io: IO) -> Result<TlsStream<IO>, crate::Error>
166    where
167        IO: AsyncRead + AsyncWrite + Connected + Unpin + Send + 'static,
168    {
169        let acceptor = RustlsAcceptor::from(self.inner.clone());
170        acceptor.accept(io).await.map_err(Into::into)
171    }
172}
173
174impl fmt::Debug for TlsAcceptor {
175    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
176        f.debug_struct("TlsAcceptor").finish()
177    }
178}
179
180impl fmt::Display for TlsError {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        match self {
183            TlsError::H2NotNegotiated => write!(f, "HTTP/2 was not negotiated."),
184            TlsError::CertificateParseError => write!(f, "Error parsing TLS certificate."),
185            TlsError::PrivateKeyParseError => write!(
186                f,
187                "Error parsing TLS private key - no RSA or PKCS8-encoded keys found."
188            ),
189        }
190    }
191}
192
193impl std::error::Error for TlsError {}
194
#[cfg(feature = "tls")]
mod rustls_keys {
    use tokio_rustls::rustls::{internal::pemfile, Certificate, PrivateKey};

    use crate::transport::service::tls::TlsError;
    use crate::transport::Identity;

    /// Takes the first key out of a pemfile parse result, treating a parse error like
    /// "no keys found".
    fn first_key<E>(parsed: Result<Vec<PrivateKey>, E>) -> Option<PrivateKey> {
        parsed.ok().and_then(|keys| keys.into_iter().next())
    }

    /// Returns the first private key found in `cursor`.
    ///
    /// PKCS8-encoded keys are tried first. If none are found, the cursor is rewound and
    /// the same bytes are scanned again for RSA keys.
    ///
    /// # Errors
    ///
    /// Returns `TlsError::PrivateKeyParseError` when neither scan yields a key.
    fn load_rustls_private_key(
        mut cursor: std::io::Cursor<&[u8]>,
    ) -> Result<PrivateKey, crate::Error> {
        if let Some(key) = first_key(pemfile::pkcs8_private_keys(&mut cursor)) {
            return Ok(key);
        }

        // The PKCS8 scan may have advanced the cursor; rewind before the RSA scan.
        cursor.set_position(0);
        first_key(pemfile::rsa_private_keys(&mut cursor))
            .ok_or_else(|| TlsError::PrivateKeyParseError.into())
    }

    /// Splits an [`Identity`] into the rustls certificate chain and private key.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the certificate PEM is unparseable, or contains no certificates at all
    ///   (`TlsError::CertificateParseError`);
    /// - no private key can be found (`TlsError::PrivateKeyParseError`).
    pub(crate) fn load_identity(
        identity: Identity,
    ) -> Result<(Vec<Certificate>, PrivateKey), crate::Error> {
        let mut cert_cursor = std::io::Cursor::new(&identity.cert.pem[..]);
        let certs =
            pemfile::certs(&mut cert_cursor).map_err(|_| TlsError::CertificateParseError)?;
        // `pemfile::certs` can succeed with an empty list when the input has no
        // certificate sections. An empty chain is never a usable identity, so report
        // it here instead of letting rustls fail later with a less specific error.
        if certs.is_empty() {
            return Err(Box::new(TlsError::CertificateParseError));
        }

        let key = load_rustls_private_key(std::io::Cursor::new(&identity.key[..]))?;

        Ok((certs, key))
    }
}