// tonic/client/grpc.rs

1#[cfg(feature = "compression")]
2use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
3use crate::{
4    body::BoxBody,
5    client::GrpcService,
6    codec::{encode_client, Codec, Streaming},
7    request::SanitizeHeaders,
8    Code, Request, Response, Status,
9};
10use futures_core::Stream;
11use futures_util::{future, stream, TryStreamExt};
12use http::{
13    header::{HeaderValue, CONTENT_TYPE, TE},
14    uri::{Parts, PathAndQuery, Uri},
15};
16use http_body::Body;
17use std::fmt;
18
/// A gRPC client dispatcher.
///
/// This will wrap some inner [`GrpcService`] and will encode/decode
/// messages via the provided codec.
///
/// Each request method takes a [`Request`], a [`PathAndQuery`], and a
/// [`Codec`]. The request contains the message to send via the
/// [`Codec::encoder`]. The path determines the fully qualified path
/// that will be appended to the outgoing URI. The path must follow
/// the conventions explained in the [gRPC protocol definition] under `Path →`. An
/// example of this path could look like `/greeter.Greeter/SayHello`.
///
/// [gRPC protocol definition]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests
pub struct Grpc<T> {
    /// The wrapped service that encoded requests are dispatched to.
    inner: T,
    /// Which compression encodings does the client accept?
    #[cfg(feature = "compression")]
    accept_compression_encodings: EnabledCompressionEncodings,
    /// The compression encoding that will be applied to requests.
    #[cfg(feature = "compression")]
    send_compression_encodings: Option<CompressionEncoding>,
}
41
42impl<T> Grpc<T> {
43    /// Creates a new gRPC client with the provided [`GrpcService`].
44    pub fn new(inner: T) -> Self {
45        Self {
46            inner,
47            #[cfg(feature = "compression")]
48            send_compression_encodings: None,
49            #[cfg(feature = "compression")]
50            accept_compression_encodings: EnabledCompressionEncodings::default(),
51        }
52    }
53
54    /// Compress requests with `gzip`.
55    ///
56    /// Requires the server to accept `gzip` otherwise it might return an error.
57    ///
58    /// # Example
59    ///
60    /// The most common way of using this is through a client generated by tonic-build:
61    ///
62    /// ```rust
63    /// use tonic::transport::Channel;
64    /// # struct TestClient<T>(T);
65    /// # impl<T> TestClient<T> {
66    /// #     fn new(channel: T) -> Self { Self(channel) }
67    /// #     fn send_gzip(self) -> Self { self }
68    /// # }
69    ///
70    /// # async {
71    /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
72    ///     .connect()
73    ///     .await
74    ///     .unwrap();
75    ///
76    /// let client = TestClient::new(channel).send_gzip();
77    /// # };
78    /// ```
79    #[cfg(feature = "compression")]
80    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
81    pub fn send_gzip(mut self) -> Self {
82        self.send_compression_encodings = Some(CompressionEncoding::Gzip);
83        self
84    }
85
86    #[doc(hidden)]
87    #[cfg(not(feature = "compression"))]
88    pub fn send_gzip(self) -> Self {
89        panic!(
90            "`send_gzip` called on a client but the `compression` feature is not enabled on tonic"
91        );
92    }
93
94    /// Enable accepting `gzip` compressed responses.
95    ///
96    /// Requires the server to also support sending compressed responses.
97    ///
98    /// # Example
99    ///
100    /// The most common way of using this is through a client generated by tonic-build:
101    ///
102    /// ```rust
103    /// use tonic::transport::Channel;
104    /// # struct TestClient<T>(T);
105    /// # impl<T> TestClient<T> {
106    /// #     fn new(channel: T) -> Self { Self(channel) }
107    /// #     fn accept_gzip(self) -> Self { self }
108    /// # }
109    ///
110    /// # async {
111    /// let channel = Channel::builder("127.0.0.1:3000".parse().unwrap())
112    ///     .connect()
113    ///     .await
114    ///     .unwrap();
115    ///
116    /// let client = TestClient::new(channel).accept_gzip();
117    /// # };
118    /// ```
119    #[cfg(feature = "compression")]
120    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
121    pub fn accept_gzip(mut self) -> Self {
122        self.accept_compression_encodings.enable_gzip();
123        self
124    }
125
126    #[doc(hidden)]
127    #[cfg(not(feature = "compression"))]
128    pub fn accept_gzip(self) -> Self {
129        panic!("`accept_gzip` called on a client but the `compression` feature is not enabled on tonic");
130    }
131
132    /// Check if the inner [`GrpcService`] is able to accept a  new request.
133    ///
134    /// This will call [`GrpcService::poll_ready`] until it returns ready or
135    /// an error. If this returns ready the inner [`GrpcService`] is ready to
136    /// accept one more request.
137    pub async fn ready(&mut self) -> Result<(), T::Error>
138    where
139        T: GrpcService<BoxBody>,
140    {
141        future::poll_fn(|cx| self.inner.poll_ready(cx)).await
142    }
143
144    /// Send a single unary gRPC request.
145    pub async fn unary<M1, M2, C>(
146        &mut self,
147        request: Request<M1>,
148        path: PathAndQuery,
149        codec: C,
150    ) -> Result<Response<M2>, Status>
151    where
152        T: GrpcService<BoxBody>,
153        T::ResponseBody: Body + Send + 'static,
154        <T::ResponseBody as Body>::Error: Into<crate::Error>,
155        C: Codec<Encode = M1, Decode = M2>,
156        M1: Send + Sync + 'static,
157        M2: Send + Sync + 'static,
158    {
159        let request = request.map(|m| stream::once(future::ready(m)));
160        self.client_streaming(request, path, codec).await
161    }
162
163    /// Send a client side streaming gRPC request.
164    pub async fn client_streaming<S, M1, M2, C>(
165        &mut self,
166        request: Request<S>,
167        path: PathAndQuery,
168        codec: C,
169    ) -> Result<Response<M2>, Status>
170    where
171        T: GrpcService<BoxBody>,
172        T::ResponseBody: Body + Send + 'static,
173        <T::ResponseBody as Body>::Error: Into<crate::Error>,
174        S: Stream<Item = M1> + Send + 'static,
175        C: Codec<Encode = M1, Decode = M2>,
176        M1: Send + Sync + 'static,
177        M2: Send + Sync + 'static,
178    {
179        let (mut parts, body, extensions) =
180            self.streaming(request, path, codec).await?.into_parts();
181
182        futures_util::pin_mut!(body);
183
184        let message = body
185            .try_next()
186            .await
187            .map_err(|mut status| {
188                status.metadata_mut().merge(parts.clone());
189                status
190            })?
191            .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?;
192
193        if let Some(trailers) = body.trailers().await? {
194            parts.merge(trailers);
195        }
196
197        Ok(Response::from_parts(parts, message, extensions))
198    }
199
200    /// Send a server side streaming gRPC request.
201    pub async fn server_streaming<M1, M2, C>(
202        &mut self,
203        request: Request<M1>,
204        path: PathAndQuery,
205        codec: C,
206    ) -> Result<Response<Streaming<M2>>, Status>
207    where
208        T: GrpcService<BoxBody>,
209        T::ResponseBody: Body + Send + 'static,
210        <T::ResponseBody as Body>::Error: Into<crate::Error>,
211        C: Codec<Encode = M1, Decode = M2>,
212        M1: Send + Sync + 'static,
213        M2: Send + Sync + 'static,
214    {
215        let request = request.map(|m| stream::once(future::ready(m)));
216        self.streaming(request, path, codec).await
217    }
218
219    /// Send a bi-directional streaming gRPC request.
220    pub async fn streaming<S, M1, M2, C>(
221        &mut self,
222        request: Request<S>,
223        path: PathAndQuery,
224        mut codec: C,
225    ) -> Result<Response<Streaming<M2>>, Status>
226    where
227        T: GrpcService<BoxBody>,
228        T::ResponseBody: Body + Send + 'static,
229        <T::ResponseBody as Body>::Error: Into<crate::Error>,
230        S: Stream<Item = M1> + Send + 'static,
231        C: Codec<Encode = M1, Decode = M2>,
232        M1: Send + Sync + 'static,
233        M2: Send + Sync + 'static,
234    {
235        let mut parts = Parts::default();
236        parts.path_and_query = Some(path);
237
238        let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
239
240        let request = request
241            .map(|s| {
242                encode_client(
243                    codec.encoder(),
244                    s,
245                    #[cfg(feature = "compression")]
246                    self.send_compression_encodings,
247                )
248            })
249            .map(BoxBody::new);
250
251        let mut request = request.into_http(uri, SanitizeHeaders::Yes);
252
253        // Add the gRPC related HTTP headers
254        request
255            .headers_mut()
256            .insert(TE, HeaderValue::from_static("trailers"));
257
258        // Set the content type
259        request
260            .headers_mut()
261            .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc"));
262
263        #[cfg(feature = "compression")]
264        {
265            if let Some(encoding) = self.send_compression_encodings {
266                request.headers_mut().insert(
267                    crate::codec::compression::ENCODING_HEADER,
268                    encoding.into_header_value(),
269                );
270            }
271
272            if let Some(header_value) = self
273                .accept_compression_encodings
274                .into_accept_encoding_header_value()
275            {
276                request.headers_mut().insert(
277                    crate::codec::compression::ACCEPT_ENCODING_HEADER,
278                    header_value,
279                );
280            }
281        }
282
283        let response = self
284            .inner
285            .call(request)
286            .await
287            .map_err(|err| Status::from_error(err.into()))?;
288
289        #[cfg(feature = "compression")]
290        let encoding = CompressionEncoding::from_encoding_header(
291            response.headers(),
292            self.accept_compression_encodings,
293        )?;
294
295        let status_code = response.status();
296        let trailers_only_status = Status::from_header_map(response.headers());
297
298        // We do not need to check for trailers if the `grpc-status` header is present
299        // with a valid code.
300        let expect_additional_trailers = if let Some(status) = trailers_only_status {
301            if status.code() != Code::Ok {
302                return Err(status);
303            }
304
305            false
306        } else {
307            true
308        };
309
310        let response = response.map(|body| {
311            if expect_additional_trailers {
312                Streaming::new_response(
313                    codec.decoder(),
314                    body,
315                    status_code,
316                    #[cfg(feature = "compression")]
317                    encoding,
318                )
319            } else {
320                Streaming::new_empty(codec.decoder(), body)
321            }
322        });
323
324        Ok(Response::from_http(response))
325    }
326}
327
328impl<T: Clone> Clone for Grpc<T> {
329    fn clone(&self) -> Self {
330        Self {
331            inner: self.inner.clone(),
332            #[cfg(feature = "compression")]
333            send_compression_encodings: self.send_compression_encodings,
334            #[cfg(feature = "compression")]
335            accept_compression_encodings: self.accept_compression_encodings,
336        }
337    }
338}
339
340impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
341    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342        let mut f = f.debug_struct("Grpc");
343
344        f.field("inner", &self.inner);
345
346        #[cfg(feature = "compression")]
347        f.field("compression_encoding", &self.send_compression_encodings);
348
349        #[cfg(feature = "compression")]
350        f.field(
351            "accept_compression_encodings",
352            &self.accept_compression_encodings,
353        );
354
355        f.finish()
356    }
357}