tonic/transport/service/
connector.rs

1use super::super::BoxFuture;
2use super::io::BoxedIo;
3#[cfg(feature = "tls")]
4use super::tls::TlsConnector;
5use http::Uri;
6use std::task::{Context, Poll};
7use tower::make::MakeConnection;
8use tower_service::Service;
9
10#[cfg(not(feature = "tls"))]
11pub(crate) fn connector<C>(inner: C) -> Connector<C> {
12    Connector::new(inner)
13}
14
/// Wraps `inner` in a [`Connector`], attaching the optional TLS connector.
///
/// This variant is compiled only when the `tls` feature is enabled. Passing
/// `None` for `tls` leaves the connector without an explicit TLS setup.
#[cfg(feature = "tls")]
pub(crate) fn connector<C>(inner: C, tls: Option<TlsConnector>) -> Connector<C> {
    Connector::<C>::new(inner, tls)
}
19
/// Connection maker that pairs an inner connector `C` with an optional TLS
/// configuration.
pub(crate) struct Connector<C> {
    /// Underlying connector that connection requests are delegated to.
    inner: C,
    // Without the `tls` feature there is no `TlsConnector` type, so the field
    // is kept as a unit placeholder; nothing in the non-TLS code paths of
    // this file reads it, hence the `dead_code` allowance.
    #[cfg(not(feature = "tls"))]
    #[allow(dead_code)]
    tls: Option<()>,
    /// Explicitly configured TLS connector, if any.
    #[cfg(feature = "tls")]
    tls: Option<TlsConnector>,
}
28
29impl<C> Connector<C> {
30    #[cfg(not(feature = "tls"))]
31    pub(crate) fn new(inner: C) -> Self {
32        Self { inner, tls: None }
33    }
34
35    #[cfg(feature = "tls")]
36    fn new(inner: C, tls: Option<TlsConnector>) -> Self {
37        Self { inner, tls }
38    }
39
40    #[cfg(feature = "tls-roots-common")]
41    fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector> {
42        use tokio_rustls::webpki::DNSNameRef;
43
44        if self.tls.is_some() {
45            return self.tls.clone();
46        }
47
48        match (scheme, host) {
49            (Some("https"), Some(host)) => {
50                if DNSNameRef::try_from_ascii(host.as_bytes()).is_ok() {
51                    TlsConnector::new_with_rustls_cert(None, None, host.to_owned()).ok()
52                } else {
53                    None
54                }
55            }
56            _ => None,
57        }
58    }
59}
60
61impl<C> Service<Uri> for Connector<C>
62where
63    C: MakeConnection<Uri>,
64    C::Connection: Unpin + Send + 'static,
65    C::Future: Send + 'static,
66    crate::Error: From<C::Error> + Send + 'static,
67{
68    type Response = BoxedIo;
69    type Error = crate::Error;
70    type Future = BoxFuture<Self::Response, Self::Error>;
71
72    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into)
74    }
75
76    fn call(&mut self, uri: Uri) -> Self::Future {
77        #[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
78        let tls = self.tls.clone();
79
80        #[cfg(feature = "tls-roots-common")]
81        let tls = self.tls_or_default(uri.scheme_str(), uri.host());
82
83        let connect = self.inner.make_connection(uri);
84
85        Box::pin(async move {
86            let io = connect.await?;
87
88            #[cfg(feature = "tls")]
89            {
90                if let Some(tls) = tls {
91                    let conn = tls.connect(io).await?;
92                    return Ok(BoxedIo::new(conn));
93                }
94            }
95
96            Ok(BoxedIo::new(io))
97        })
98    }
99}