hyper_timeout/
lib.rs

1use std::future::Future;
2use std::io;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5use std::time::Duration;
6
7use tokio::io::{AsyncRead, AsyncWrite};
8use tokio::time::timeout;
9use tokio_io_timeout::TimeoutStream;
10
11use hyper::client::connect::{Connected, Connection};
12use hyper::{service::Service, Uri};
13
14mod stream;
15
16use stream::TimeoutConnectorStream;
17
/// Boxed `Send + Sync` error trait object; used as the `Error` type of this
/// module's `Service` implementation and as the conversion target required of
/// the wrapped connector's own error type.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
19
/// A connector wrapper that enforces a connection timeout and carries the
/// read/write timeouts handed to the streams it produces.
#[derive(Clone, Debug)]
pub struct TimeoutConnector<T> {
    /// The wrapped connector that actually establishes connections.
    connector: T,
    /// How long establishing a connection may take; `None` means no limit.
    connect_timeout: Option<Duration>,
    /// How long to wait while reading the response; `None` means no limit.
    read_timeout: Option<Duration>,
    /// How long to wait while writing the request; `None` means no limit.
    write_timeout: Option<Duration>,
}
32
33impl<T> TimeoutConnector<T>
34where
35    T: Service<Uri> + Send,
36    T::Response: AsyncRead + AsyncWrite + Send + Unpin,
37    T::Future: Send + 'static,
38    T::Error: Into<BoxError>,
39{
40    /// Construct a new TimeoutConnector with a given connector implementing the `Connect` trait
41    pub fn new(connector: T) -> Self {
42        TimeoutConnector {
43            connector,
44            connect_timeout: None,
45            read_timeout: None,
46            write_timeout: None,
47        }
48    }
49}
50
51impl<T> Service<Uri> for TimeoutConnector<T>
52where
53    T: Service<Uri> + Send,
54    T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin,
55    T::Future: Send + 'static,
56    T::Error: Into<BoxError>,
57{
58    type Response = Pin<Box<TimeoutConnectorStream<T::Response>>>;
59    type Error = BoxError;
60    #[allow(clippy::type_complexity)]
61    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
62
63    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
64        self.connector.poll_ready(cx).map_err(Into::into)
65    }
66
67    fn call(&mut self, dst: Uri) -> Self::Future {
68        let connect_timeout = self.connect_timeout;
69        let read_timeout = self.read_timeout;
70        let write_timeout = self.write_timeout;
71        let connecting = self.connector.call(dst);
72
73        let fut = async move {
74            let stream = match connect_timeout {
75                None => {
76                    let io = connecting.await.map_err(Into::into)?;
77                    TimeoutStream::new(io)
78                }
79                Some(connect_timeout) => {
80                    let timeout = timeout(connect_timeout, connecting);
81                    let connecting = timeout
82                        .await
83                        .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
84                    let io = connecting.map_err(Into::into)?;
85                    TimeoutStream::new(io)
86                }
87            };
88
89            let mut tm = TimeoutConnectorStream::new(stream);
90            tm.set_read_timeout(read_timeout);
91            tm.set_write_timeout(write_timeout);
92            Ok(Box::pin(tm))
93        };
94
95        Box::pin(fut)
96    }
97}
98
99impl<T> TimeoutConnector<T> {
100    /// Set the timeout for connecting to a URL.
101    ///
102    /// Default is no timeout.
103    #[inline]
104    pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
105        self.connect_timeout = val;
106    }
107
108    /// Set the timeout for the response.
109    ///
110    /// Default is no timeout.
111    #[inline]
112    pub fn set_read_timeout(&mut self, val: Option<Duration>) {
113        self.read_timeout = val;
114    }
115
116    /// Set the timeout for the request.
117    ///
118    /// Default is no timeout.
119    #[inline]
120    pub fn set_write_timeout(&mut self, val: Option<Duration>) {
121        self.write_timeout = val;
122    }
123}
124
125impl<T> Connection for TimeoutConnector<T>
126where
127    T: AsyncRead + AsyncWrite + Connection + Service<Uri> + Send + Unpin,
128    T::Response: AsyncRead + AsyncWrite + Send + Unpin,
129    T::Future: Send + 'static,
130    T::Error: Into<BoxError>,
131{
132    fn connected(&self) -> Connected {
133        self.connector.connected()
134    }
135}
136
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::io;
    use std::time::Duration;

    use hyper::client::HttpConnector;
    use hyper::Client;

    use super::TimeoutConnector;

    /// Panics unless `res` is an error whose source is an `io::Error` of kind
    /// `TimedOut`.
    fn assert_timed_out<T>(res: Result<T, hyper::Error>) {
        let err = match res {
            Err(err) => err,
            Ok(_) => panic!("Expected a timeout"),
        };

        match err.source().unwrap().downcast_ref::<io::Error>() {
            Some(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }

    /// Connecting to an unroutable address with a 1 ms connect timeout must
    /// surface a timeout error.
    #[tokio::test]
    async fn test_timeout_connector() {
        // 10.255.255.1 is not a routable IP address
        let url = "http://10.255.255.1".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);

        assert_timed_out(client.get(url).await);
    }

    /// A request issued with a 1 ms read timeout must surface a timeout error.
    #[tokio::test]
    async fn test_read_timeout() {
        let url = "http://example.com".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        // A 1 ms read timeout should be so short that we trigger a timeout error
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);

        assert_timed_out(client.get(url).await);
    }
}