tonic/transport/service/
grpc_timeout.rs1use crate::metadata::GRPC_TIMEOUT_HEADER;
2use crate::util::{OptionPin, OptionPinProj};
3use http::{HeaderMap, HeaderValue, Request};
4use pin_project::pin_project;
5use std::{
6 fmt,
7 future::Future,
8 pin::Pin,
9 task::{Context, Poll},
10 time::Duration,
11};
12use tokio::time::Sleep;
13use tower_service::Service;
14
/// Middleware that bounds how long each request to the wrapped service may take.
///
/// The bound applied to a request is the shorter of the optional server-side timeout stored
/// here and the client-supplied `grpc-timeout` header, whichever are present.
#[derive(Debug, Clone)]
pub(crate) struct GrpcTimeout<S> {
    /// The wrapped service.
    inner: S,
    /// Server-side upper bound on request duration; `None` means no server-side limit.
    server_timeout: Option<Duration>,
}

impl<S> GrpcTimeout<S> {
    /// Wraps `inner`, applying `server_timeout` (if any) as a limit on every request.
    pub(crate) fn new(inner: S, server_timeout: Option<Duration>) -> Self {
        Self { inner, server_timeout }
    }
}
29
30impl<S, ReqBody> Service<Request<ReqBody>> for GrpcTimeout<S>
31where
32 S: Service<Request<ReqBody>>,
33 S::Error: Into<crate::Error>,
34{
35 type Response = S::Response;
36 type Error = crate::Error;
37 type Future = ResponseFuture<S::Future>;
38
39 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
40 self.inner.poll_ready(cx).map_err(Into::into)
41 }
42
43 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
44 let client_timeout = try_parse_grpc_timeout(req.headers()).unwrap_or_else(|e| {
45 tracing::trace!("Error parsing `grpc-timeout` header {:?}", e);
46 None
47 });
48
49 let timeout_duration = match (client_timeout, self.server_timeout) {
51 (None, None) => None,
52 (Some(dur), None) => Some(dur),
53 (None, Some(dur)) => Some(dur),
54 (Some(header), Some(server)) => {
55 let shorter_duration = std::cmp::min(header, server);
56 Some(shorter_duration)
57 }
58 };
59
60 ResponseFuture {
61 inner: self.inner.call(req),
62 sleep: timeout_duration
63 .map(tokio::time::sleep)
64 .map(OptionPin::Some)
65 .unwrap_or(OptionPin::None),
66 }
67 }
68}
69
70#[pin_project]
71pub(crate) struct ResponseFuture<F> {
72 #[pin]
73 inner: F,
74 #[pin]
75 sleep: OptionPin<Sleep>,
76}
77
78impl<F, Res, E> Future for ResponseFuture<F>
79where
80 F: Future<Output = Result<Res, E>>,
81 E: Into<crate::Error>,
82{
83 type Output = Result<Res, crate::Error>;
84
85 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
86 let this = self.project();
87
88 if let Poll::Ready(result) = this.inner.poll(cx) {
89 return Poll::Ready(result.map_err(Into::into));
90 }
91
92 if let OptionPinProj::Some(sleep) = this.sleep.project() {
93 futures_util::ready!(sleep.poll(cx));
94 return Poll::Ready(Err(TimeoutExpired(()).into()));
95 }
96
97 Poll::Pending
98 }
99}
100
/// Seconds per minute; converts the `M` (minutes) `grpc-timeout` unit.
const SECONDS_IN_MINUTE: u64 = 60;
/// Seconds per hour; converts the `H` (hours) `grpc-timeout` unit.
const SECONDS_IN_HOUR: u64 = 60 * SECONDS_IN_MINUTE;
103
104fn try_parse_grpc_timeout(
109 headers: &HeaderMap<HeaderValue>,
110) -> Result<Option<Duration>, &HeaderValue> {
111 match headers.get(GRPC_TIMEOUT_HEADER) {
112 Some(val) => {
113 let (timeout_value, timeout_unit) = val
114 .to_str()
115 .map_err(|_| val)
116 .and_then(|s| if s.is_empty() { Err(val) } else { Ok(s) })?
117 .split_at(val.len() - 1);
123
124 if timeout_value.len() > 8 {
127 return Err(val);
128 }
129
130 let timeout_value: u64 = timeout_value.parse().map_err(|_| val)?;
131
132 let duration = match timeout_unit {
133 "H" => Duration::from_secs(timeout_value * SECONDS_IN_HOUR),
135 "M" => Duration::from_secs(timeout_value * SECONDS_IN_MINUTE),
137 "S" => Duration::from_secs(timeout_value),
139 "m" => Duration::from_millis(timeout_value),
141 "u" => Duration::from_micros(timeout_value),
143 "n" => Duration::from_nanos(timeout_value),
145 _ => return Err(val),
146 };
147
148 Ok(Some(duration))
149 }
150 None => Ok(None),
151 }
152}
153
/// Error indicating that a request did not complete before its deadline.
///
/// In this module it is produced by the timeout middleware's response future when the
/// deadline timer finishes before the wrapped service's response future does. Its
/// `Display` text is `"Timeout expired"`.
#[derive(Debug)]
pub struct TimeoutExpired(());

impl fmt::Display for TimeoutExpired {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Timeout expired")
    }
}

/// Leaf error: it wraps no underlying cause.
impl std::error::Error for TimeoutExpired {}
173
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::{Arbitrary, Gen};
    use quickcheck_macros::quickcheck;

    /// Builds a header map holding `val` as `grpc-timeout` (no header for `None`) and parses
    /// it. The error is cloned so the result does not borrow the local map.
    fn setup_map_try_parse(val: Option<&str>) -> Result<Option<Duration>, HeaderValue> {
        let mut hm = HeaderMap::new();
        if let Some(v) = val {
            let hv = HeaderValue::from_str(v).unwrap();
            hm.insert(GRPC_TIMEOUT_HEADER, hv);
        };

        try_parse_grpc_timeout(&hm).map_err(|e| e.clone())
    }

    #[test]
    fn test_hours() {
        let parsed_duration = setup_map_try_parse(Some("3H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(3 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_minutes() {
        let parsed_duration = setup_map_try_parse(Some("1M")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(60), parsed_duration);
    }

    #[test]
    fn test_seconds() {
        let parsed_duration = setup_map_try_parse(Some("42S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(42), parsed_duration);
    }

    #[test]
    fn test_milliseconds() {
        let parsed_duration = setup_map_try_parse(Some("13m")).unwrap().unwrap();
        assert_eq!(Duration::from_millis(13), parsed_duration);
    }

    #[test]
    fn test_microseconds() {
        let parsed_duration = setup_map_try_parse(Some("2u")).unwrap().unwrap();
        assert_eq!(Duration::from_micros(2), parsed_duration);
    }

    #[test]
    fn test_nanoseconds() {
        let parsed_duration = setup_map_try_parse(Some("82n")).unwrap().unwrap();
        assert_eq!(Duration::from_nanos(82), parsed_duration);
    }

    #[test]
    fn test_zero_value() {
        let parsed_duration = setup_map_try_parse(Some("0S")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(0), parsed_duration);
    }

    #[test]
    fn test_max_digits_largest_unit() {
        // Eight digits is the longest value the spec allows; combined with the largest unit
        // it must still be converted without overflow.
        let parsed_duration = setup_map_try_parse(Some("99999999H")).unwrap().unwrap();
        assert_eq!(Duration::from_secs(99_999_999 * 60 * 60), parsed_duration);
    }

    #[test]
    fn test_header_not_present() {
        let parsed_duration = setup_map_try_parse(None).unwrap();
        assert!(parsed_duration.is_none());
    }

    #[test]
    fn test_empty_value() {
        assert!(setup_map_try_parse(Some("")).is_err());
    }

    #[test]
    fn test_missing_value() {
        assert!(setup_map_try_parse(Some("S")).is_err());
    }

    #[test]
    fn test_missing_unit() {
        assert!(setup_map_try_parse(Some("42")).is_err());
    }

    #[test]
    #[should_panic(expected = "82f")]
    fn test_invalid_unit() {
        // "82f" is a valid value with an invalid unit
        setup_map_try_parse(Some("82f")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "123456789H")]
    fn test_too_many_digits() {
        // gRPC spec states TimeoutValue will be at most 8 digits
        setup_map_try_parse(Some("123456789H")).unwrap().unwrap();
    }

    #[test]
    #[should_panic(expected = "oneH")]
    fn test_invalid_digits() {
        // gRPC spec states TimeoutValue will be represented as an ASCII string of digits
        setup_map_try_parse(Some("oneH")).unwrap().unwrap();
    }

    #[quickcheck]
    fn fuzz(header_value: HeaderValueGen) -> bool {
        let header_value = header_value.0;

        // Arbitrary input must never panic. The header is always inserted, so `Ok(None)` is
        // impossible, and anything accepted must end with a valid unit character.
        match setup_map_try_parse(Some(&header_value)) {
            Ok(Some(_)) => header_value.ends_with(|c: char| "HMSmun".contains(c)),
            Ok(None) => false,
            Err(_) => true,
        }
    }

    /// Generates header values that are always valid `HeaderValue`s (visible ASCII only).
    #[derive(Clone, Debug)]
    struct HeaderValueGen(String);

    impl Arbitrary for HeaderValueGen {
        fn arbitrary(g: &mut Gen) -> Self {
            let max = g.choose(&(1..70).collect::<Vec<_>>()).copied().unwrap();
            Self(gen_string(g, 0, max))
        }
    }

    /// Builds a string of `max - min` characters drawn from a fixed visible-ASCII alphabet.
    fn gen_string(g: &mut Gen, min: usize, max: usize) -> String {
        // Digits and every valid unit letter are included alongside noise characters so the
        // fuzzer reaches the numeric parsing and unit matching paths, not only early rejection.
        let bytes: Vec<_> = (min..max)
            .map(|_| {
                g.choose(b"0123456789HMSmunABCDEFGHIJKLMNOPQRSTUVabcdefghilpqrstuvwxyz----+")
                    .copied()
                    .unwrap()
            })
            .collect();

        String::from_utf8(bytes).unwrap()
    }
}