1#[cfg(feature = "compression")]
2use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
3use crate::{
4 body::BoxBody,
5 client::GrpcService,
6 codec::{encode_client, Codec, Streaming},
7 request::SanitizeHeaders,
8 Code, Request, Response, Status,
9};
10use futures_core::Stream;
11use futures_util::{future, stream, TryStreamExt};
12use http::{
13 header::{HeaderValue, CONTENT_TYPE, TE},
14 uri::{Parts, PathAndQuery, Uri},
15};
16use http_body::Body;
17use std::fmt;
18
/// A generic gRPC client that issues calls through an inner service `T`.
///
/// When the `compression` feature is enabled the client also records which
/// encoding (if any) to apply to outgoing messages and which encodings it is
/// willing to accept for incoming ones.
pub struct Grpc<T> {
    /// The underlying service every request is dispatched to.
    inner: T,
    /// Encoding applied to outgoing request messages; `None` means no
    /// request compression is configured.
    #[cfg(feature = "compression")]
    send_compression_encodings: Option<CompressionEncoding>,
    /// Set of encodings the client accepts for response messages.
    #[cfg(feature = "compression")]
    accept_compression_encodings: EnabledCompressionEncodings,
}
41
42impl<T> Grpc<T> {
43 pub fn new(inner: T) -> Self {
45 Self {
46 inner,
47 #[cfg(feature = "compression")]
48 send_compression_encodings: None,
49 #[cfg(feature = "compression")]
50 accept_compression_encodings: EnabledCompressionEncodings::default(),
51 }
52 }
53
54 #[cfg(feature = "compression")]
80 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
81 pub fn send_gzip(mut self) -> Self {
82 self.send_compression_encodings = Some(CompressionEncoding::Gzip);
83 self
84 }
85
86 #[doc(hidden)]
87 #[cfg(not(feature = "compression"))]
88 pub fn send_gzip(self) -> Self {
89 panic!(
90 "`send_gzip` called on a client but the `compression` feature is not enabled on tonic"
91 );
92 }
93
94 #[cfg(feature = "compression")]
120 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
121 pub fn accept_gzip(mut self) -> Self {
122 self.accept_compression_encodings.enable_gzip();
123 self
124 }
125
126 #[doc(hidden)]
127 #[cfg(not(feature = "compression"))]
128 pub fn accept_gzip(self) -> Self {
129 panic!("`accept_gzip` called on a client but the `compression` feature is not enabled on tonic");
130 }
131
132 pub async fn ready(&mut self) -> Result<(), T::Error>
138 where
139 T: GrpcService<BoxBody>,
140 {
141 future::poll_fn(|cx| self.inner.poll_ready(cx)).await
142 }
143
144 pub async fn unary<M1, M2, C>(
146 &mut self,
147 request: Request<M1>,
148 path: PathAndQuery,
149 codec: C,
150 ) -> Result<Response<M2>, Status>
151 where
152 T: GrpcService<BoxBody>,
153 T::ResponseBody: Body + Send + 'static,
154 <T::ResponseBody as Body>::Error: Into<crate::Error>,
155 C: Codec<Encode = M1, Decode = M2>,
156 M1: Send + Sync + 'static,
157 M2: Send + Sync + 'static,
158 {
159 let request = request.map(|m| stream::once(future::ready(m)));
160 self.client_streaming(request, path, codec).await
161 }
162
163 pub async fn client_streaming<S, M1, M2, C>(
165 &mut self,
166 request: Request<S>,
167 path: PathAndQuery,
168 codec: C,
169 ) -> Result<Response<M2>, Status>
170 where
171 T: GrpcService<BoxBody>,
172 T::ResponseBody: Body + Send + 'static,
173 <T::ResponseBody as Body>::Error: Into<crate::Error>,
174 S: Stream<Item = M1> + Send + 'static,
175 C: Codec<Encode = M1, Decode = M2>,
176 M1: Send + Sync + 'static,
177 M2: Send + Sync + 'static,
178 {
179 let (mut parts, body, extensions) =
180 self.streaming(request, path, codec).await?.into_parts();
181
182 futures_util::pin_mut!(body);
183
184 let message = body
185 .try_next()
186 .await
187 .map_err(|mut status| {
188 status.metadata_mut().merge(parts.clone());
189 status
190 })?
191 .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?;
192
193 if let Some(trailers) = body.trailers().await? {
194 parts.merge(trailers);
195 }
196
197 Ok(Response::from_parts(parts, message, extensions))
198 }
199
200 pub async fn server_streaming<M1, M2, C>(
202 &mut self,
203 request: Request<M1>,
204 path: PathAndQuery,
205 codec: C,
206 ) -> Result<Response<Streaming<M2>>, Status>
207 where
208 T: GrpcService<BoxBody>,
209 T::ResponseBody: Body + Send + 'static,
210 <T::ResponseBody as Body>::Error: Into<crate::Error>,
211 C: Codec<Encode = M1, Decode = M2>,
212 M1: Send + Sync + 'static,
213 M2: Send + Sync + 'static,
214 {
215 let request = request.map(|m| stream::once(future::ready(m)));
216 self.streaming(request, path, codec).await
217 }
218
219 pub async fn streaming<S, M1, M2, C>(
221 &mut self,
222 request: Request<S>,
223 path: PathAndQuery,
224 mut codec: C,
225 ) -> Result<Response<Streaming<M2>>, Status>
226 where
227 T: GrpcService<BoxBody>,
228 T::ResponseBody: Body + Send + 'static,
229 <T::ResponseBody as Body>::Error: Into<crate::Error>,
230 S: Stream<Item = M1> + Send + 'static,
231 C: Codec<Encode = M1, Decode = M2>,
232 M1: Send + Sync + 'static,
233 M2: Send + Sync + 'static,
234 {
235 let mut parts = Parts::default();
236 parts.path_and_query = Some(path);
237
238 let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
239
240 let request = request
241 .map(|s| {
242 encode_client(
243 codec.encoder(),
244 s,
245 #[cfg(feature = "compression")]
246 self.send_compression_encodings,
247 )
248 })
249 .map(BoxBody::new);
250
251 let mut request = request.into_http(uri, SanitizeHeaders::Yes);
252
253 request
255 .headers_mut()
256 .insert(TE, HeaderValue::from_static("trailers"));
257
258 request
260 .headers_mut()
261 .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc"));
262
263 #[cfg(feature = "compression")]
264 {
265 if let Some(encoding) = self.send_compression_encodings {
266 request.headers_mut().insert(
267 crate::codec::compression::ENCODING_HEADER,
268 encoding.into_header_value(),
269 );
270 }
271
272 if let Some(header_value) = self
273 .accept_compression_encodings
274 .into_accept_encoding_header_value()
275 {
276 request.headers_mut().insert(
277 crate::codec::compression::ACCEPT_ENCODING_HEADER,
278 header_value,
279 );
280 }
281 }
282
283 let response = self
284 .inner
285 .call(request)
286 .await
287 .map_err(|err| Status::from_error(err.into()))?;
288
289 #[cfg(feature = "compression")]
290 let encoding = CompressionEncoding::from_encoding_header(
291 response.headers(),
292 self.accept_compression_encodings,
293 )?;
294
295 let status_code = response.status();
296 let trailers_only_status = Status::from_header_map(response.headers());
297
298 let expect_additional_trailers = if let Some(status) = trailers_only_status {
301 if status.code() != Code::Ok {
302 return Err(status);
303 }
304
305 false
306 } else {
307 true
308 };
309
310 let response = response.map(|body| {
311 if expect_additional_trailers {
312 Streaming::new_response(
313 codec.decoder(),
314 body,
315 status_code,
316 #[cfg(feature = "compression")]
317 encoding,
318 )
319 } else {
320 Streaming::new_empty(codec.decoder(), body)
321 }
322 });
323
324 Ok(Response::from_http(response))
325 }
326}
327
328impl<T: Clone> Clone for Grpc<T> {
329 fn clone(&self) -> Self {
330 Self {
331 inner: self.inner.clone(),
332 #[cfg(feature = "compression")]
333 send_compression_encodings: self.send_compression_encodings,
334 #[cfg(feature = "compression")]
335 accept_compression_encodings: self.accept_compression_encodings,
336 }
337 }
338}
339
340impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
341 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
342 let mut f = f.debug_struct("Grpc");
343
344 f.field("inner", &self.inner);
345
346 #[cfg(feature = "compression")]
347 f.field("compression_encoding", &self.send_compression_encodings);
348
349 #[cfg(feature = "compression")]
350 f.field(
351 "accept_compression_encodings",
352 &self.accept_compression_encodings,
353 );
354
355 f.finish()
356 }
357}