tonic/service/
interceptor.rs

//! gRPC interceptors, which are a kind of middleware.
//!
//! See [`Interceptor`] for more details.
4
5use crate::{request::SanitizeHeaders, Status};
6use pin_project::pin_project;
7use std::{
8    fmt,
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16/// A gRPC interceptor.
17///
18/// gRPC interceptors are similar to middleware but have less flexibility. An interceptor allows
19/// you to do two main things, one is to add/remove/check items in the `MetadataMap` of each
20/// request. Two, cancel a request with a `Status`.
21///
22/// Any function that satisfies the bound `FnMut(Request<()>) -> Result<Request<()>, Status>` can be
23/// used as an `Interceptor`.
24///
25/// An interceptor can be used on both the server and client side through the `tonic-build` crate's
26/// generated structs.
27///
28/// See the [interceptor example][example] for more details.
29///
30/// If you need more powerful middleware, [tower] is the recommended approach. You can find
31/// examples of how to use tower with tonic [here][tower-example].
32///
33/// Additionally, interceptors is not the recommended way to add logging to your service. For that
34/// a [tower] middleware is more appropriate since it can also act on the response. For example
35/// tower-http's [`Trace`](https://docs.rs/tower-http/latest/tower_http/trace/index.html)
36/// middleware supports gRPC out of the box.
37///
38/// [tower]: https://crates.io/crates/tower
39/// [example]: https://github.com/hyperium/tonic/tree/master/examples/src/interceptor
40/// [tower-example]: https://github.com/hyperium/tonic/tree/master/examples/src/tower
41pub trait Interceptor {
42    /// Intercept a request before it is sent, optionally cancelling it.
43    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
44}
45
46impl<F> Interceptor for F
47where
48    F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
49{
50    fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
51        self(request)
52    }
53}
54
55/// Create a new interceptor layer.
56///
57/// See [`Interceptor`] for more details.
58pub fn interceptor<F>(f: F) -> InterceptorLayer<F>
59where
60    F: Interceptor,
61{
62    InterceptorLayer { f }
63}
64
65#[deprecated(
66    since = "0.5.1",
67    note = "Please use the `interceptor` function instead"
68)]
69/// Create a new interceptor layer.
70///
71/// See [`Interceptor`] for more details.
72pub fn interceptor_fn<F>(f: F) -> InterceptorLayer<F>
73where
74    F: Interceptor,
75{
76    interceptor(f)
77}
78
/// A [`Layer`] that wraps services with a gRPC interceptor.
///
/// Values of this type are created by calling [`interceptor`].
///
/// See [`Interceptor`] for more details.
#[derive(Clone, Copy, Debug)]
pub struct InterceptorLayer<F> {
    // The interceptor handed to every service this layer wraps.
    f: F,
}
87
88impl<S, F> Layer<S> for InterceptorLayer<F>
89where
90    F: Interceptor + Clone,
91{
92    type Service = InterceptedService<S, F>;
93
94    fn layer(&self, service: S) -> Self::Service {
95        InterceptedService::new(service, self.f.clone())
96    }
97}
98
99#[deprecated(
100    since = "0.5.1",
101    note = "Please use the `InterceptorLayer` type instead"
102)]
103/// A gRPC interceptor that can be used as a [`Layer`],
104/// created by calling [`interceptor`].
105///
106/// See [`Interceptor`] for more details.
107pub type InterceptorFn<F> = InterceptorLayer<F>;
108
/// A service wrapped in an interceptor middleware.
///
/// Each request is passed to the interceptor before it reaches the wrapped service.
///
/// See [`Interceptor`] for more details.
#[derive(Clone, Copy)]
pub struct InterceptedService<S, F> {
    // The wrapped service that receives requests the interceptor accepted.
    inner: S,
    // The interceptor applied to every request.
    f: F,
}
117
118impl<S, F> InterceptedService<S, F> {
119    /// Create a new `InterceptedService` that wraps `S` and intercepts each request with the
120    /// function `F`.
121    pub fn new(service: S, f: F) -> Self
122    where
123        F: Interceptor,
124    {
125        Self { inner: service, f }
126    }
127}
128
129impl<S, F> fmt::Debug for InterceptedService<S, F>
130where
131    S: fmt::Debug,
132{
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("InterceptedService")
135            .field("inner", &self.inner)
136            .field("f", &format_args!("{}", std::any::type_name::<F>()))
137            .finish()
138    }
139}
140
141impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
142where
143    F: Interceptor,
144    S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
145    S::Error: Into<crate::Error>,
146{
147    type Response = http::Response<ResBody>;
148    type Error = crate::Error;
149    type Future = ResponseFuture<S::Future>;
150
151    #[inline]
152    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
153        self.inner.poll_ready(cx).map_err(Into::into)
154    }
155
156    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
157        let uri = req.uri().clone();
158        let req = crate::Request::from_http(req);
159        let (metadata, extensions, msg) = req.into_parts();
160
161        match self
162            .f
163            .call(crate::Request::from_parts(metadata, extensions, ()))
164        {
165            Ok(req) => {
166                let (metadata, extensions, _) = req.into_parts();
167                let req = crate::Request::from_parts(metadata, extensions, msg);
168                let req = req.into_http(uri, SanitizeHeaders::No);
169                ResponseFuture::future(self.inner.call(req))
170            }
171            Err(status) => ResponseFuture::error(status),
172        }
173    }
174}
175
// required to use `InterceptedService` with `Router`
//
// The wrapper reports the name of the service it wraps.
#[cfg(feature = "transport")]
impl<S, F> crate::transport::NamedService for InterceptedService<S, F>
where
    S: crate::transport::NamedService,
{
    const NAME: &'static str = <S as crate::transport::NamedService>::NAME;
}
184
185/// Response future for [`InterceptedService`].
186#[pin_project]
187#[derive(Debug)]
188pub struct ResponseFuture<F> {
189    #[pin]
190    kind: Kind<F>,
191}
192
193impl<F> ResponseFuture<F> {
194    fn future(future: F) -> Self {
195        Self {
196            kind: Kind::Future(future),
197        }
198    }
199
200    fn error(status: Status) -> Self {
201        Self {
202            kind: Kind::Error(Some(status)),
203        }
204    }
205}
206
207#[pin_project(project = KindProj)]
208#[derive(Debug)]
209enum Kind<F> {
210    Future(#[pin] F),
211    Error(Option<Status>),
212}
213
214impl<F, E, B> Future for ResponseFuture<F>
215where
216    F: Future<Output = Result<http::Response<B>, E>>,
217    E: Into<crate::Error>,
218{
219    type Output = Result<http::Response<B>, crate::Error>;
220
221    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
222        match self.project().kind.project() {
223            KindProj::Future(future) => future.poll(cx).map_err(Into::into),
224            KindProj::Error(status) => {
225                let error = status.take().unwrap().into();
226                Poll::Ready(Err(error))
227            }
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use tower::ServiceExt;

    /// A header set on the incoming HTTP request must be visible both to the interceptor (as
    /// metadata) and to the wrapped service (as an HTTP header).
    #[tokio::test]
    async fn doesnt_remove_headers() {
        // Innermost service: checks the header survived the trip through the interceptor.
        let leaf = tower::service_fn(|request: http::Request<hyper::Body>| async move {
            let user_agent = request
                .headers()
                .get("user-agent")
                .expect("missing in leaf service");
            assert_eq!(user_agent, "test-tonic");

            Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
        });

        // Interceptor: checks the header is exposed as metadata, then lets the request through.
        let intercepted = InterceptedService::new(leaf, |request: crate::Request<()>| {
            let user_agent = request
                .metadata()
                .get("user-agent")
                .expect("missing in interceptor");
            assert_eq!(user_agent, "test-tonic");
            Ok(request)
        });

        let request = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(hyper::Body::empty())
            .unwrap();

        intercepted.oneshot(request).await.unwrap();
    }
}