tonic/transport/service/
grpc_timeout.rs

1use crate::metadata::GRPC_TIMEOUT_HEADER;
2use crate::util::{OptionPin, OptionPinProj};
3use http::{HeaderMap, HeaderValue, Request};
4use pin_project::pin_project;
5use std::{
6    fmt,
7    future::Future,
8    pin::Pin,
9    task::{Context, Poll},
10    time::Duration,
11};
12use tokio::time::Sleep;
13use tower_service::Service;
14
/// Wrapper around an inner service that carries an optional server-side timeout.
///
/// The wrapper itself only stores its two parts; the timeout is applied by the
/// `Service` implementation for this type.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service.
    inner: S,
    /// Timeout configured on the server side, if any.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, remembering `server_timeout` as the optional server-side limit.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        GrpcTimeout {
            inner,
            server_timeout,
        }
    }
}
29
30impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
31where
32    S: Service<Request<ReqBody>>,
33    S::Error: Into<crate::Error>,
34{
35    type Response = S::Response;
36    type Error = crate::Error;
37    type Future = ResponseFuture<S::Future>;
38
39    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
40        self.inner.poll_ready(cx).map_err(Into::into)
41    }
42
43    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
44        let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
45            tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
46            None
47        });
48
49        // Use the shorter of the two durations, if either are set
50        let timeout_duration = match (client_timeout, self.server_timeout) {
51            (None, None) => None,
52            (Some(dur), None) => Some(dur),
53            (None, Some(dur)) => Some(dur),
54            (Some(header), Some(server)) => {
55                let shorter_duration = std::cmp::min(header, server);
56                Some(shorter_duration)
57            }
58        };
59
60        ResponseFuture {
61            inner: self.inner.call(req),
62            sleep: timeout_duration
63                .map(tokio::time::sleep)
64                .map(OptionPin::Some)
65                .unwrap_or(OptionPin::None),
66        }
67    }
68}
69
70#[pin_project]
71pub(crate) struct ResponseFuture<F> {
72    #[pin]
73    inner: F,
74    #[pin]
75    sleep: OptionPin<Sleep>,
76}
77
78impl<F, Res, E> Future for ResponseFuture<F>
79where
80    F: Future<Output = Result<Res, E>>,
81    E: Into<crate::Error>,
82{
83    type Output = Result<Res, crate::Error>;
84
85    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86        let this = self.project();
87
88        if let Poll::Ready(result) = this.inner.poll(cx) {
89            return Poll::Ready(result.map_err(Into::into));
90        }
91
92        if let OptionPinProj::Some(sleep) = this.sleep.project() {
93            futures_util::ready!(sleep.poll(cx));
94            return Poll::Ready(Err(TimeoutExpired(()).into()));
95        }
96
97        Poll::Pending
98    }
99}
100
/// Number of seconds in one minute.
const SECONDS_IN_MINUTE: u64 = 60;
/// Number of seconds in one hour.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
103
104/// Tries to parse the `grpc-timeout` header if it is present. If we fail to parse, returns
105/// the value we attempted to parse.
106///
107/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
108fn try_parse_grpc_timeout(
109    headers: &HeaderMap<HeaderValue>,
110) -> Result<Option<Duration>, &HeaderValue> {
111    match headers.get(GRPC_TIMEOUT_HEADER) {
112        Some(val) => {
113            let (timeout_value, timeout_unit) = val
114                .to_str()
115                .map_err(|_| val)
116                .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
117                // `HeaderValue::to_str` only returns `Ok` if the header contains ASCII so this
118                // `split_at` will never panic from trying to split in the middle of a character.
119                // See https://docs.rs/http/0.2.4/http/header/struct.HeaderValue.html#method.to_str
120                //
121                // `len - 1` also wont panic since we just checked `s.is_empty`.
122                .split_at(val.len() - 1);
123
124            // gRPC spec specifies `TimeoutValue` will be at most 8 digits
125            // Caping this at 8 digits also prevents integer overflow from ever occurring
126            if timeout_value.len() > 8 {
127                return Err(val);
128            }
129
130            let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
131
132            let duration = match timeout_unit {
133                // Hours
134                "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
135                // Minutes
136                "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
137                // Seconds
138                "S" => Duration::from_secs(timeout_value),
139                // Milliseconds
140                "m" => Duration::from_millis(timeout_value),
141                // Microseconds
142                "u" => Duration::from_micros(timeout_value),
143                // Nanoseconds
144                "n" => Duration::from_nanos(timeout_value),
145                _ => return Err(val),
146            };
147
148            Ok(Some(duration))
149        }
150        None => Ok(None),
151    }
152}
153
/// Error returned if a request didn't complete within the configured timeout.
///
/// Timeouts can be configured either with [`Endpoint::timeout`], [`Server::timeout`], or by
/// setting the [`grpc-timeout` metadata value][spec].
///
/// [`Endpoint::timeout`]: crate::transport::channel::Endpoint::timeout
/// [`Server::timeout`]: crate::transport::server::Server::timeout
/// [spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
#[derive(Debug)]
pub struct TimeoutExpired(());

impl fmt::Display for TimeoutExpired {
    /// Writes the fixed, user-facing message for an expired timeout.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Timeout expired")
    }
}

// std::error::Error only requires a type to impl Debug and Display
impl std::error::Error for TimeoutExpired {}
173
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Puts `val` (when given) into a fresh header map under the timeout header and runs
    /// the parser on it, returning an owned copy of the rejected value on failure.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut headers = HeaderMap::new();
        if let Some(raw) = val {
            headers.insert(GRPC_TIMEOUT_HEADER, HeaderValue::from_str(raw).unwrap());
        }

        let parsed = try_parse_grpc_timeout(&headers);
        parsed.map_err(HeaderValue::clone)
    }

    /// Parses `raw` and unwraps both layers, panicking with the rejected header value if
    /// parsing fails.
    fn parse_required(raw: &str) -> Duration {
        setup_map_try_parse(Some(raw)).unwrap().unwrap()
    }

    #[test]
    fn test_hours() {
        let actual = parse_required("3H");
        assert_eq!(Duration::from_secs(3 * 60 * 60), actual);
    }

    #[test]
    fn test_minutes() {
        let actual = parse_required("1M");
        assert_eq!(Duration::from_secs(60), actual);
    }

    #[test]
    fn test_seconds() {
        let actual = parse_required("42S");
        assert_eq!(Duration::from_secs(42), actual);
    }

    #[test]
    fn test_milliseconds() {
        let actual = parse_required("13m");
        assert_eq!(Duration::from_millis(13), actual);
    }

    #[test]
    fn test_microseconds() {
        let actual = parse_required("2u");
        assert_eq!(Duration::from_micros(2), actual);
    }

    #[test]
    fn test_nanoseconds() {
        let actual = parse_required("82n");
        assert_eq!(Duration::from_nanos(82), actual);
    }

    #[test]
    fn test_header_not_present() {
        let outcome = setup_map_try_parse(None).unwrap();
        assert!(outcome.is_none());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "f" is not a valid TimeoutUnit
        parse_required("82f");
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        parse_required("123456789H");
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // "one" is not a numeric TimeoutValue
        parse_required("oneH");
    }

    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let HeaderValueGen(raw) = header_value;

        // The outcome is irrelevant here; parsing must simply never panic.
        let _ = setup_map_try_parse(Some(raw.as_str()));

        true
    }

    /// Newtype to implement `Arbitrary` for generating `String`s that are valid `HeaderValue`s.
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            // Pick an upper length bound from 1 through 69.
            let lengths: Vec<usize> = (1..70).collect();
            let max = *g.choose(&lengths).unwrap();
            HeaderValueGen(gen_string(g, 0, max))
        }
    }

    // copied from https://github.com/hyperium/http/blob/master/tests/header_map_fuzz.rs
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        // Chars to pick from
        const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----";

        let mut bytes = Vec::new();
        for _ in min..max {
            bytes.push(*g.choose(ALPHABET).unwrap());
        }

        String::from_utf8(bytes).unwrap()
    }
}
288}