1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio::time::timeout;
9use tokio_io_timeout::TimeoutStream;
10
11use hyper::client::connect::{Connected, Connection};
12use hyper::{service::Service, Uri};
13
14mod stream;
15
16use stream::TimeoutConnectorStream;
17
/// Boxed, thread-safe error type used as the connector's `Service::Error`.
///
/// Any inner connector error that implements `Into<BoxError>` can be surfaced
/// through it, as can `io::Error`s produced by this crate (e.g. connect timeouts).
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
19
/// Wraps an inner connector `T` and applies optional timeouts to the
/// connections it produces.
///
/// - The connect timeout bounds how long the inner connector's future may run.
/// - The read and write timeouts are installed on every stream returned by
///   `Service::call`.
///
/// All three start as `None` when built with `TimeoutConnector::new`.
#[derive(Debug, Clone)]
pub struct TimeoutConnector<T> {
    /// The wrapped connector that actually establishes connections.
    connector: T,
    /// Upper bound on the inner connector's connect future.
    /// `None` awaits the connect without any time limit.
    connect_timeout: Option<Duration>,
    /// Read timeout passed to `TimeoutConnectorStream::set_read_timeout` for
    /// each new stream (presumably `None` disables it; see `stream.rs`).
    read_timeout: Option<Duration>,
    /// Write timeout passed to `TimeoutConnectorStream::set_write_timeout` for
    /// each new stream (presumably `None` disables it; see `stream.rs`).
    write_timeout: Option<Duration>,
}
32
33impl<T> TimeoutConnector<T>
34where
35 T: Service<Uri> + Send,
36 T::Response: AsyncRead + AsyncWrite + Send + Unpin,
37 T::Future: Send + 'static,
38 T::Error: Into<BoxError>,
39{
40 pub fn new(connector: T) -> Self {
42 TimeoutConnector {
43 connector,
44 connect_timeout: None,
45 read_timeout: None,
46 write_timeout: None,
47 }
48 }
49}
50
51impl<T> Service<Uri> for TimeoutConnector<T>
52where
53 T: Service<Uri> + Send,
54 T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin,
55 T::Future: Send + 'static,
56 T::Error: Into<BoxError>,
57{
58 type Response = Pin<Box<TimeoutConnectorStream<T::Response>>>;
59 type Error = BoxError;
60 #[allow(clippy::type_complexity)]
61 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
62
63 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
64 self.connector.poll_ready(cx).map_err(Into::into)
65 }
66
67 fn call(&mut self, dst: Uri) -> Self::Future {
68 let connect_timeout = self.connect_timeout;
69 let read_timeout = self.read_timeout;
70 let write_timeout = self.write_timeout;
71 let connecting = self.connector.call(dst);
72
73 let fut = async move {
74 let stream = match connect_timeout {
75 None => {
76 let io = connecting.await.map_err(Into::into)?;
77 TimeoutStream::new(io)
78 }
79 Some(connect_timeout) => {
80 let timeout = timeout(connect_timeout, connecting);
81 let connecting = timeout
82 .await
83 .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
84 let io = connecting.map_err(Into::into)?;
85 TimeoutStream::new(io)
86 }
87 };
88
89 let mut tm = TimeoutConnectorStream::new(stream);
90 tm.set_read_timeout(read_timeout);
91 tm.set_write_timeout(write_timeout);
92 Ok(Box::pin(tm))
93 };
94
95 Box::pin(fut)
96 }
97}
98
99impl<T> TimeoutConnector<T> {
100 #[inline]
104 pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
105 self.connect_timeout = val;
106 }
107
108 #[inline]
112 pub fn set_read_timeout(&mut self, val: Option<Duration>) {
113 self.read_timeout = val;
114 }
115
116 #[inline]
120 pub fn set_write_timeout(&mut self, val: Option<Duration>) {
121 self.write_timeout = val;
122 }
123}
124
125impl<T> Connection for TimeoutConnector<T>
126where
127 T: AsyncRead + AsyncWrite + Connection + Service<Uri> + Send + Unpin,
128 T::Response: AsyncRead + AsyncWrite + Send + Unpin,
129 T::Future: Send + 'static,
130 T::Error: Into<BoxError>,
131{
132 fn connected(&self) -> Connected {
133 self.connector.connected()
134 }
135}
136
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::io;
    use std::time::Duration;

    use hyper::client::HttpConnector;
    use hyper::Client;

    use super::TimeoutConnector;

    /// Asserts that `res` is an error whose source is an `io::Error` of kind
    /// `TimedOut`.
    ///
    /// Panics with a descriptive message in every other case, including an
    /// error that has no source at all.
    fn assert_timed_out<R>(res: Result<R, hyper::Error>) {
        let err = match res {
            Ok(_) => panic!("Expected a timeout"),
            Err(e) => e,
        };
        let source = err.source().unwrap_or_else(|| {
            panic!("Expected timeout error, got error without a source: {:?}", err)
        });
        match source.downcast_ref::<io::Error>() {
            Some(io_e) => assert_eq!(io_e.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error, got {:?}", source),
        }
    }

    #[tokio::test]
    async fn test_timeout_connector() {
        // 10.255.255.1 is expected not to answer, so the TCP connect should
        // still be pending when the 1ms connect timeout fires.
        let url = "http://10.255.255.1".parse().unwrap();

        let http = HttpConnector::new();
        let mut connector = TimeoutConnector::new(http);
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);

        assert_timed_out(client.get(url).await);
    }

    #[tokio::test]
    async fn test_read_timeout() {
        // NOTE: requires network access to example.com. The 1ms read timeout
        // is expected to fire before the response arrives.
        let url = "http://example.com".parse().unwrap();

        let http = HttpConnector::new();
        let mut connector = TimeoutConnector::new(http);
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);

        assert_timed_out(client.get(url).await);
    }
}