tonic/service/
interceptor.rs1use crate::{request::SanitizeHeaders, Status};
6use pin_project::pin_project;
7use std::{
8 fmt,
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tower_layer::Layer;
14use tower_service::Service;
15
16pub trait Interceptor {
42 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status>;
44}
45
46impl<F> Interceptor for F
47where
48 F: FnMut(crate::Request<()>) -> Result<crate::Request<()>, Status>,
49{
50 fn call(&mut self, request: crate::Request<()>) -> Result<crate::Request<()>, Status> {
51 self(request)
52 }
53}
54
55pub fn interceptor<F>(f: F) -> InterceptorLayer<F>
59where
60 F: Interceptor,
61{
62 InterceptorLayer { f }
63}
64
65#[deprecated(
66 since = "0.5.1",
67 note = "Please use the `interceptor` function instead"
68)]
69pub fn interceptor_fn<F>(f: F) -> InterceptorLayer<F>
73where
74 F: Interceptor,
75{
76 interceptor(f)
77}
78
/// A `Layer` that wraps each inner service in an [`InterceptedService`]
/// driven by a clone of the stored interceptor.
#[derive(Debug, Clone, Copy)]
pub struct InterceptorLayer<F> {
    /// Interceptor that is cloned into every service this layer wraps.
    f: F,
}
87
88impl<S, F> Layer<S> for InterceptorLayer<F>
89where
90 F: Interceptor + Clone,
91{
92 type Service = InterceptedService<S, F>;
93
94 fn layer(&self, service: S) -> Self::Service {
95 InterceptedService::new(service, self.f.clone())
96 }
97}
98
99#[deprecated(
100 since = "0.5.1",
101 note = "Please use the `InterceptorLayer` type instead"
102)]
103pub type InterceptorFn<F> = InterceptorLayer<F>;
108
/// A service that runs an [`Interceptor`] over each request's metadata and
/// extensions before forwarding the request to the wrapped service.
#[derive(Clone, Copy)]
pub struct InterceptedService<S, F> {
    /// The wrapped service; it only receives requests the interceptor accepts.
    inner: S,
    /// The interceptor consulted on every call.
    f: F,
}
117
118impl<S, F> InterceptedService<S, F> {
119 pub fn new(service: S, f: F) -> Self
122 where
123 F: Interceptor,
124 {
125 Self { inner: service, f }
126 }
127}
128
129impl<S, F> fmt::Debug for InterceptedService<S, F>
130where
131 S: fmt::Debug,
132{
133 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134 f.debug_struct("InterceptedService")
135 .field("inner", &self.inner)
136 .field("f", &format_args!("{}", std::any::type_name::<F>()))
137 .finish()
138 }
139}
140
141impl<S, F, ReqBody, ResBody> Service<http::Request<ReqBody>> for InterceptedService<S, F>
142where
143 F: Interceptor,
144 S: Service<http::Request<ReqBody>, Response = http::Response<ResBody>>,
145 S::Error: Into<crate::Error>,
146{
147 type Response = http::Response<ResBody>;
148 type Error = crate::Error;
149 type Future = ResponseFuture<S::Future>;
150
151 #[inline]
152 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
153 self.inner.poll_ready(cx).map_err(Into::into)
154 }
155
156 fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
157 let uri = req.uri().clone();
158 let req = crate::Request::from_http(req);
159 let (metadata, extensions, msg) = req.into_parts();
160
161 match self
162 .f
163 .call(crate::Request::from_parts(metadata, extensions, ()))
164 {
165 Ok(req) => {
166 let (metadata, extensions, _) = req.into_parts();
167 let req = crate::Request::from_parts(metadata, extensions, msg);
168 let req = req.into_http(uri, SanitizeHeaders::No);
169 ResponseFuture::future(self.inner.call(req))
170 }
171 Err(status) => ResponseFuture::error(status),
172 }
173 }
174}
175
/// Reports the inner service's `NAME`, so wrapping a service in an
/// interceptor does not change its name.
#[cfg(feature = "transport")]
impl<S, F> crate::transport::NamedService for InterceptedService<S, F>
where
    S: crate::transport::NamedService,
{
    const NAME: &'static str = <S as crate::transport::NamedService>::NAME;
}
184
185#[pin_project]
187#[derive(Debug)]
188pub struct ResponseFuture<F> {
189 #[pin]
190 kind: Kind<F>,
191}
192
193impl<F> ResponseFuture<F> {
194 fn future(future: F) -> Self {
195 Self {
196 kind: Kind::Future(future),
197 }
198 }
199
200 fn error(status: Status) -> Self {
201 Self {
202 kind: Kind::Error(Some(status)),
203 }
204 }
205}
206
207#[pin_project(project = KindProj)]
208#[derive(Debug)]
209enum Kind<F> {
210 Future(#[pin] F),
211 Error(Option<Status>),
212}
213
214impl<F, E, B> Future for ResponseFuture<F>
215where
216 F: Future<Output = Result<http::Response<B>, E>>,
217 E: Into<crate::Error>,
218{
219 type Output = Result<http::Response<B>, crate::Error>;
220
221 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
222 match self.project().kind.project() {
223 KindProj::Future(future) => future.poll(cx).map_err(Into::into),
224 KindProj::Error(status) => {
225 let error = status.take().unwrap().into();
226 Poll::Ready(Err(error))
227 }
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use tower::ServiceExt;

    /// A header on the incoming HTTP request must be visible to the
    /// interceptor as metadata and must still reach the inner service.
    #[tokio::test]
    async fn doesnt_remove_headers() {
        let leaf = tower::service_fn(|req: http::Request<hyper::Body>| async move {
            let agent = req
                .headers()
                .get("user-agent")
                .expect("missing in leaf service");
            assert_eq!(agent, "test-tonic");

            Ok::<_, hyper::Error>(hyper::Response::new(hyper::Body::empty()))
        });

        let check_metadata = |req: crate::Request<()>| {
            let agent = req
                .metadata()
                .get("user-agent")
                .expect("missing in interceptor");
            assert_eq!(agent, "test-tonic");
            Ok::<_, crate::Status>(req)
        };
        let intercepted = InterceptedService::new(leaf, check_metadata);

        let incoming = http::Request::builder()
            .header("user-agent", "test-tonic")
            .body(hyper::Body::empty())
            .unwrap();

        intercepted.oneshot(incoming).await.unwrap();
    }
}