tonic/server/
grpc.rs

1#[cfg(feature = "compression")]
2use crate::codec::compression::{
3    CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
4};
5use crate::{
6    body::BoxBody,
7    codec::{encode_server, Codec, Streaming},
8    server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
9    Code, Request, Status,
10};
11use futures_core::TryStream;
12use futures_util::{future, stream, TryStreamExt};
13use http_body::Body;
14use std::fmt;
15
// Evaluates a `Result` expression: yields the `Ok` payload, or returns early
// from the enclosing function with the HTTP response obtained by calling
// `to_http()` on the error value.
macro_rules! t {
    ($result:expr) => {
        match $result {
            Err(failure) => return failure.to_http(),
            Ok(payload) => payload,
        }
    };
}
24
/// Server-side gRPC request handler built around a [`Codec`].
///
/// The handler offers one entry point per gRPC call shape: unary, client
/// side streaming, server side streaming, and bi-directional streaming.
///
/// Every entry point takes a service implementing the matching service trait
/// together with an http request whose body implements [`Body`].
pub struct Grpc<T> {
    /// Codec supplying the decoder for requests and the encoder for responses.
    codec: T,
    /// Compression encodings the server accepts on incoming requests.
    #[cfg(feature = "compression")]
    accept_compression_encodings: EnabledCompressionEncodings,
    /// Compression encodings the server may use for outgoing responses.
    #[cfg(feature = "compression")]
    send_compression_encodings: EnabledCompressionEncodings,
}
43
44impl<T> Grpc<T>
45where
46    T: Codec,
47{
48    /// Creates a new gRPC server with the provided [`Codec`].
49    pub fn new(codec: T) -> Self {
50        Self {
51            codec,
52            #[cfg(feature = "compression")]
53            accept_compression_encodings: EnabledCompressionEncodings::default(),
54            #[cfg(feature = "compression")]
55            send_compression_encodings: EnabledCompressionEncodings::default(),
56        }
57    }
58
59    /// Enable accepting `gzip` compressed requests.
60    ///
61    /// If a request with an unsupported encoding is received the server will respond with
62    /// [`Code::UnUnimplemented`](crate::Code).
63    ///
64    /// # Example
65    ///
66    /// The most common way of using this is through a server generated by tonic-build:
67    ///
68    /// ```rust
69    /// # struct Svc;
70    /// # struct ExampleServer<T>(T);
71    /// # impl<T> ExampleServer<T> {
72    /// #     fn new(svc: T) -> Self { Self(svc) }
73    /// #     fn accept_gzip(self) -> Self { self }
74    /// # }
75    /// # #[tonic::async_trait]
76    /// # trait Example {}
77    ///
78    /// #[tonic::async_trait]
79    /// impl Example for Svc {
80    ///     // ...
81    /// }
82    ///
83    /// let service = ExampleServer::new(Svc).accept_gzip();
84    /// ```
85    #[cfg(feature = "compression")]
86    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
87    pub fn accept_gzip(mut self) -> Self {
88        self.accept_compression_encodings.enable_gzip();
89        self
90    }
91
92    #[doc(hidden)]
93    #[cfg(not(feature = "compression"))]
94    pub fn accept_gzip(self) -> Self {
95        panic!("`accept_gzip` called on a server but the `compression` feature is not enabled on tonic");
96    }
97
98    /// Enable sending `gzip` compressed responses.
99    ///
100    /// Requires the client to also support receiving compressed responses.
101    ///
102    /// # Example
103    ///
104    /// The most common way of using this is through a server generated by tonic-build:
105    ///
106    /// ```rust
107    /// # struct Svc;
108    /// # struct ExampleServer<T>(T);
109    /// # impl<T> ExampleServer<T> {
110    /// #     fn new(svc: T) -> Self { Self(svc) }
111    /// #     fn send_gzip(self) -> Self { self }
112    /// # }
113    /// # #[tonic::async_trait]
114    /// # trait Example {}
115    ///
116    /// #[tonic::async_trait]
117    /// impl Example for Svc {
118    ///     // ...
119    /// }
120    ///
121    /// let service = ExampleServer::new(Svc).send_gzip();
122    /// ```
123    #[cfg(feature = "compression")]
124    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
125    pub fn send_gzip(mut self) -> Self {
126        self.send_compression_encodings.enable_gzip();
127        self
128    }
129
130    #[doc(hidden)]
131    #[cfg(not(feature = "compression"))]
132    pub fn send_gzip(self) -> Self {
133        panic!(
134            "`send_gzip` called on a server but the `compression` feature is not enabled on tonic"
135        );
136    }
137
138    #[cfg(feature = "compression")]
139    #[doc(hidden)]
140    pub fn apply_compression_config(
141        self,
142        accept_encodings: EnabledCompressionEncodings,
143        send_encodings: EnabledCompressionEncodings,
144    ) -> Self {
145        let mut this = self;
146
147        let EnabledCompressionEncodings { gzip: accept_gzip } = accept_encodings;
148        if accept_gzip {
149            this = this.accept_gzip();
150        }
151
152        let EnabledCompressionEncodings { gzip: send_gzip } = send_encodings;
153        if send_gzip {
154            this = this.send_gzip();
155        }
156
157        this
158    }
159
160    #[cfg(not(feature = "compression"))]
161    #[doc(hidden)]
162    #[allow(unused_variables)]
163    pub fn apply_compression_config(self, accept_encodings: (), send_encodings: ()) -> Self {
164        self
165    }
166
167    /// Handle a single unary gRPC request.
168    pub async fn unary<S, B>(
169        &mut self,
170        mut service: S,
171        req: http::Request<B>,
172    ) -> http::Response<BoxBody>
173    where
174        S: UnaryService<T::Decode, Response = T::Encode>,
175        B: Body + Send + 'static,
176        B::Error: Into<crate::Error> + Send,
177    {
178        #[cfg(feature = "compression")]
179        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
180            req.headers(),
181            self.send_compression_encodings,
182        );
183
184        let request = match self.map_request_unary(req).await {
185            Ok(r) => r,
186            Err(status) => {
187                return self
188                    .map_response::<stream::Once<future::Ready<Result<T::Encode, Status>>>>(
189                        Err(status),
190                        #[cfg(feature = "compression")]
191                        accept_encoding,
192                        #[cfg(feature = "compression")]
193                        SingleMessageCompressionOverride::default(),
194                    );
195            }
196        };
197
198        let response = service
199            .call(request)
200            .await
201            .map(|r| r.map(|m| stream::once(future::ok(m))));
202
203        #[cfg(feature = "compression")]
204        let compression_override = compression_override_from_response(&response);
205
206        self.map_response(
207            response,
208            #[cfg(feature = "compression")]
209            accept_encoding,
210            #[cfg(feature = "compression")]
211            compression_override,
212        )
213    }
214
215    /// Handle a server side streaming request.
216    pub async fn server_streaming<S, B>(
217        &mut self,
218        mut service: S,
219        req: http::Request<B>,
220    ) -> http::Response<BoxBody>
221    where
222        S: ServerStreamingService<T::Decode, Response = T::Encode>,
223        S::ResponseStream: Send + 'static,
224        B: Body + Send + 'static,
225        B::Error: Into<crate::Error> + Send,
226    {
227        #[cfg(feature = "compression")]
228        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
229            req.headers(),
230            self.send_compression_encodings,
231        );
232
233        let request = match self.map_request_unary(req).await {
234            Ok(r) => r,
235            Err(status) => {
236                return self.map_response::<S::ResponseStream>(
237                    Err(status),
238                    #[cfg(feature = "compression")]
239                    accept_encoding,
240                    #[cfg(feature = "compression")]
241                    SingleMessageCompressionOverride::default(),
242                );
243            }
244        };
245
246        let response = service.call(request).await;
247
248        self.map_response(
249            response,
250            #[cfg(feature = "compression")]
251            accept_encoding,
252            // disabling compression of individual stream items must be done on
253            // the items themselves
254            #[cfg(feature = "compression")]
255            SingleMessageCompressionOverride::default(),
256        )
257    }
258
259    /// Handle a client side streaming gRPC request.
260    pub async fn client_streaming<S, B>(
261        &mut self,
262        mut service: S,
263        req: http::Request<B>,
264    ) -> http::Response<BoxBody>
265    where
266        S: ClientStreamingService<T::Decode, Response = T::Encode>,
267        B: Body + Send + 'static,
268        B::Error: Into<crate::Error> + Send + 'static,
269    {
270        #[cfg(feature = "compression")]
271        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
272            req.headers(),
273            self.send_compression_encodings,
274        );
275
276        let request = t!(self.map_request_streaming(req));
277
278        let response = service
279            .call(request)
280            .await
281            .map(|r| r.map(|m| stream::once(future::ok(m))));
282
283        #[cfg(feature = "compression")]
284        let compression_override = compression_override_from_response(&response);
285
286        self.map_response(
287            response,
288            #[cfg(feature = "compression")]
289            accept_encoding,
290            #[cfg(feature = "compression")]
291            compression_override,
292        )
293    }
294
295    /// Handle a bi-directional streaming gRPC request.
296    pub async fn streaming<S, B>(
297        &mut self,
298        mut service: S,
299        req: http::Request<B>,
300    ) -> http::Response<BoxBody>
301    where
302        S: StreamingService<T::Decode, Response = T::Encode> + Send,
303        S::ResponseStream: Send + 'static,
304        B: Body + Send + 'static,
305        B::Error: Into<crate::Error> + Send,
306    {
307        #[cfg(feature = "compression")]
308        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
309            req.headers(),
310            self.send_compression_encodings,
311        );
312
313        let request = t!(self.map_request_streaming(req));
314
315        let response = service.call(request).await;
316
317        self.map_response(
318            response,
319            #[cfg(feature = "compression")]
320            accept_encoding,
321            #[cfg(feature = "compression")]
322            SingleMessageCompressionOverride::default(),
323        )
324    }
325
326    async fn map_request_unary<B>(
327        &mut self,
328        request: http::Request<B>,
329    ) -> Result<Request<T::Decode>, Status>
330    where
331        B: Body + Send + 'static,
332        B::Error: Into<crate::Error> + Send,
333    {
334        #[cfg(feature = "compression")]
335        let request_compression_encoding = self.request_encoding_if_supported(&request)?;
336
337        let (parts, body) = request.into_parts();
338
339        #[cfg(feature = "compression")]
340        let stream =
341            Streaming::new_request(self.codec.decoder(), body, request_compression_encoding);
342
343        #[cfg(not(feature = "compression"))]
344        let stream = Streaming::new_request(self.codec.decoder(), body);
345
346        futures_util::pin_mut!(stream);
347
348        let message = stream
349            .try_next()
350            .await?
351            .ok_or_else(|| Status::new(Code::Internal, "Missing request message."))?;
352
353        let mut req = Request::from_http_parts(parts, message);
354
355        if let Some(trailers) = stream.trailers().await? {
356            req.metadata_mut().merge(trailers);
357        }
358
359        Ok(req)
360    }
361
362    fn map_request_streaming<B>(
363        &mut self,
364        request: http::Request<B>,
365    ) -> Result<Request<Streaming<T::Decode>>, Status>
366    where
367        B: Body + Send + 'static,
368        B::Error: Into<crate::Error> + Send,
369    {
370        #[cfg(feature = "compression")]
371        let encoding = self.request_encoding_if_supported(&request)?;
372
373        #[cfg(feature = "compression")]
374        let request =
375            request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding));
376
377        #[cfg(not(feature = "compression"))]
378        let request = request.map(|body| Streaming::new_request(self.codec.decoder(), body));
379
380        Ok(Request::from_http(request))
381    }
382
383    fn map_response<B>(
384        &mut self,
385        response: Result<crate::Response<B>, Status>,
386        #[cfg(feature = "compression")] accept_encoding: Option<CompressionEncoding>,
387        #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
388    ) -> http::Response<BoxBody>
389    where
390        B: TryStream<Ok = T::Encode, Error = Status> + Send + 'static,
391    {
392        let response = match response {
393            Ok(r) => r,
394            Err(status) => return status.to_http(),
395        };
396
397        let (mut parts, body) = response.into_http().into_parts();
398
399        // Set the content type
400        parts.headers.insert(
401            http::header::CONTENT_TYPE,
402            http::header::HeaderValue::from_static("application/grpc"),
403        );
404
405        #[cfg(feature = "compression")]
406        if let Some(encoding) = accept_encoding {
407            // Set the content encoding
408            parts.headers.insert(
409                crate::codec::compression::ENCODING_HEADER,
410                encoding.into_header_value(),
411            );
412        }
413
414        let body = encode_server(
415            self.codec.encoder(),
416            body.into_stream(),
417            #[cfg(feature = "compression")]
418            accept_encoding,
419            #[cfg(feature = "compression")]
420            compression_override,
421        );
422
423        http::Response::from_parts(parts, BoxBody::new(body))
424    }
425
426    #[cfg(feature = "compression")]
427    fn request_encoding_if_supported<B>(
428        &self,
429        request: &http::Request<B>,
430    ) -> Result<Option<CompressionEncoding>, Status> {
431        CompressionEncoding::from_encoding_header(
432            request.headers(),
433            self.accept_compression_encodings,
434        )
435    }
436}
437
438impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
439    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440        let mut f = f.debug_struct("Grpc");
441
442        f.field("codec", &self.codec);
443
444        #[cfg(feature = "compression")]
445        f.field(
446            "accept_compression_encodings",
447            &self.accept_compression_encodings,
448        );
449
450        #[cfg(feature = "compression")]
451        f.field(
452            "send_compression_encodings",
453            &self.send_compression_encodings,
454        );
455
456        f.finish()
457    }
458}
459
/// Reads the [`SingleMessageCompressionOverride`] stored in the extensions of
/// a successful response.
///
/// Falls back to the default override when the result is an `Err` or when the
/// response carries no such extension.
#[cfg(feature = "compression")]
fn compression_override_from_response<B, E>(
    res: &Result<crate::Response<B>, E>,
) -> SingleMessageCompressionOverride {
    match res {
        Ok(response) => response
            .extensions()
            .get::<SingleMessageCompressionOverride>()
            .copied()
            .unwrap_or_default(),
        Err(_) => SingleMessageCompressionOverride::default(),
    }
}