tonic/transport/service/
connector.rs1use super::super::BoxFuture;
2use super::io::BoxedIo;
3#[cfg(feature = "tls")]
4use super::tls::TlsConnector;
5use http::Uri;
6use std::task::{Context, Poll};
7use tower::make::MakeConnection;
8use tower_service::Service;
9
10#[cfg(not(feature = "tls"))]
11pub(crate) fn connector<C>(inner: C) -> Connector<C> {
12 Connector::new(inner)
13}
14
/// Builds a [`Connector`] around `inner`, storing the optional `tls` connector
/// that is used to wrap each established connection.
#[cfg(feature = "tls")]
pub(crate) fn connector<C>(inner: C, tls: Option<TlsConnector>) -> Connector<C> {
    Connector::<C>::new(inner, tls)
}
19
/// A connection factory that dials a [`Uri`] through an inner `MakeConnection`
/// and, when TLS support is compiled in, optionally wraps the result in TLS.
pub(crate) struct Connector<C> {
    /// Explicitly configured TLS connector, if any.
    #[cfg(feature = "tls")]
    tls: Option<TlsConnector>,
    /// Placeholder when TLS is compiled out. It is written by the constructor but
    /// never read in this configuration, hence the `dead_code` allowance.
    #[cfg(not(feature = "tls"))]
    #[allow(dead_code)]
    tls: Option<()>,
    /// The underlying connection maker that opens the raw transport.
    inner: C,
}
28
29impl<C> Connector<C> {
30 #[cfg(not(feature = "tls"))]
31 pub(crate) fn new(inner: C) -> Self {
32 Self { inner, tls: None }
33 }
34
35 #[cfg(feature = "tls")]
36 fn new(inner: C, tls: Option<TlsConnector>) -> Self {
37 Self { inner, tls }
38 }
39
40 #[cfg(feature = "tls-roots-common")]
41 fn tls_or_default(&self, scheme: Option<&str>, host: Option<&str>) -> Option<TlsConnector> {
42 use tokio_rustls::webpki::DNSNameRef;
43
44 if self.tls.is_some() {
45 return self.tls.clone();
46 }
47
48 match (scheme, host) {
49 (Some("https"), Some(host)) => {
50 if DNSNameRef::try_from_ascii(host.as_bytes()).is_ok() {
51 TlsConnector::new_with_rustls_cert(None, None, host.to_owned()).ok()
52 } else {
53 None
54 }
55 }
56 _ => None,
57 }
58 }
59}
60
61impl<C> Service<Uri> for Connector<C>
62where
63 C: MakeConnection<Uri>,
64 C::Connection: Unpin + Send + 'static,
65 C::Future: Send + 'static,
66 crate::Error: From<C::Error> + Send + 'static,
67{
68 type Response = BoxedIo;
69 type Error = crate::Error;
70 type Future = BoxFuture<Self::Response, Self::Error>;
71
72 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73 MakeConnection::poll_ready(&mut self.inner, cx).map_err(Into::into)
74 }
75
76 fn call(&mut self, uri: Uri) -> Self::Future {
77 #[cfg(all(feature = "tls", not(feature = "tls-roots-common")))]
78 let tls = self.tls.clone();
79
80 #[cfg(feature = "tls-roots-common")]
81 let tls = self.tls_or_default(uri.scheme_str(), uri.host());
82
83 let connect = self.inner.make_connection(uri);
84
85 Box::pin(async move {
86 let io = connect.await?;
87
88 #[cfg(feature = "tls")]
89 {
90 if let Some(tls) = tls {
91 let conn = tls.connect(io).await?;
92 return Ok(BoxedIo::new(conn));
93 }
94 }
95
96 Ok(BoxedIo::new(io))
97 })
98 }
99}