1#[cfg(feature = "compression")]
2use crate::codec::compression::{
3 CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
4};
5use crate::{
6 body::BoxBody,
7 codec::{encode_server, Codec, Streaming},
8 server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
9 Code, Request, Status,
10};
11use futures_core::TryStream;
12use futures_util::{future, stream, TryStreamExt};
13use http_body::Body;
14use std::fmt;
15
// Evaluates a `Result` whose error type has a `to_http()` method.
// On `Ok(value)` it yields `value`; on `Err(status)` it returns
// `status.to_http()` from the *enclosing* function, so it may only be used
// inside functions whose return type matches what `to_http()` produces.
macro_rules! t {
    ($result:expr) => {
        match $result {
            Err(status) => return status.to_http(),
            Ok(value) => value,
        }
    };
}
24
/// Server-side adapter that turns `http` requests into calls on the service
/// traits (`UnaryService`, `ServerStreamingService`, `ClientStreamingService`,
/// `StreamingService`) and turns their results back into `http` responses,
/// using the codec `T` for message (de)serialization.
pub struct Grpc<T> {
    /// Supplies the request decoder and the response encoder.
    codec: T,
    /// Encodings allowed on incoming request messages; checked against the
    /// request's encoding header.
    #[cfg(feature = "compression")]
    accept_compression_encodings: EnabledCompressionEncodings,
    /// Encodings the server may use for responses; matched against the
    /// client's accept-encoding header.
    #[cfg(feature = "compression")]
    send_compression_encodings: EnabledCompressionEncodings,
}
43
44impl<T> Grpc<T>
45where
46 T: Codec,
47{
48 pub fn new(codec: T) -> Self {
50 Self {
51 codec,
52 #[cfg(feature = "compression")]
53 accept_compression_encodings: EnabledCompressionEncodings::default(),
54 #[cfg(feature = "compression")]
55 send_compression_encodings: EnabledCompressionEncodings::default(),
56 }
57 }
58
59 #[cfg(feature = "compression")]
86 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
87 pub fn accept_gzip(mut self) -> Self {
88 self.accept_compression_encodings.enable_gzip();
89 self
90 }
91
92 #[doc(hidden)]
93 #[cfg(not(feature = "compression"))]
94 pub fn accept_gzip(self) -> Self {
95 panic!("`accept_gzip` called on a server but the `compression` feature is not enabled on tonic");
96 }
97
98 #[cfg(feature = "compression")]
124 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
125 pub fn send_gzip(mut self) -> Self {
126 self.send_compression_encodings.enable_gzip();
127 self
128 }
129
130 #[doc(hidden)]
131 #[cfg(not(feature = "compression"))]
132 pub fn send_gzip(self) -> Self {
133 panic!(
134 "`send_gzip` called on a server but the `compression` feature is not enabled on tonic"
135 );
136 }
137
138 #[cfg(feature = "compression")]
139 #[doc(hidden)]
140 pub fn apply_compression_config(
141 self,
142 accept_encodings: EnabledCompressionEncodings,
143 send_encodings: EnabledCompressionEncodings,
144 ) -> Self {
145 let mut this = self;
146
147 let EnabledCompressionEncodings { gzip: accept_gzip } = accept_encodings;
148 if accept_gzip {
149 this = this.accept_gzip();
150 }
151
152 let EnabledCompressionEncodings { gzip: send_gzip } = send_encodings;
153 if send_gzip {
154 this = this.send_gzip();
155 }
156
157 this
158 }
159
160 #[cfg(not(feature = "compression"))]
161 #[doc(hidden)]
162 #[allow(unused_variables)]
163 pub fn apply_compression_config(self, accept_encodings: (), send_encodings: ()) -> Self {
164 self
165 }
166
167 pub async fn unary<S, B>(
169 &mut self,
170 mut service: S,
171 req: http::Request<B>,
172 ) -> http::Response<BoxBody>
173 where
174 S: UnaryService<T::Decode, Response = T::Encode>,
175 B: Body + Send + 'static,
176 B::Error: Into<crate::Error> + Send,
177 {
178 #[cfg(feature = "compression")]
179 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
180 req.headers(),
181 self.send_compression_encodings,
182 );
183
184 let request = match self.map_request_unary(req).await {
185 Ok(r) => r,
186 Err(status) => {
187 return self
188 .map_response::<stream::Once<future::Ready<Result<T::Encode, Status>>>>(
189 Err(status),
190 #[cfg(feature = "compression")]
191 accept_encoding,
192 #[cfg(feature = "compression")]
193 SingleMessageCompressionOverride::default(),
194 );
195 }
196 };
197
198 let response = service
199 .call(request)
200 .await
201 .map(|r| r.map(|m| stream::once(future::ok(m))));
202
203 #[cfg(feature = "compression")]
204 let compression_override = compression_override_from_response(&response);
205
206 self.map_response(
207 response,
208 #[cfg(feature = "compression")]
209 accept_encoding,
210 #[cfg(feature = "compression")]
211 compression_override,
212 )
213 }
214
215 pub async fn server_streaming<S, B>(
217 &mut self,
218 mut service: S,
219 req: http::Request<B>,
220 ) -> http::Response<BoxBody>
221 where
222 S: ServerStreamingService<T::Decode, Response = T::Encode>,
223 S::ResponseStream: Send + 'static,
224 B: Body + Send + 'static,
225 B::Error: Into<crate::Error> + Send,
226 {
227 #[cfg(feature = "compression")]
228 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
229 req.headers(),
230 self.send_compression_encodings,
231 );
232
233 let request = match self.map_request_unary(req).await {
234 Ok(r) => r,
235 Err(status) => {
236 return self.map_response::<S::ResponseStream>(
237 Err(status),
238 #[cfg(feature = "compression")]
239 accept_encoding,
240 #[cfg(feature = "compression")]
241 SingleMessageCompressionOverride::default(),
242 );
243 }
244 };
245
246 let response = service.call(request).await;
247
248 self.map_response(
249 response,
250 #[cfg(feature = "compression")]
251 accept_encoding,
252 #[cfg(feature = "compression")]
255 SingleMessageCompressionOverride::default(),
256 )
257 }
258
259 pub async fn client_streaming<S, B>(
261 &mut self,
262 mut service: S,
263 req: http::Request<B>,
264 ) -> http::Response<BoxBody>
265 where
266 S: ClientStreamingService<T::Decode, Response = T::Encode>,
267 B: Body + Send + 'static,
268 B::Error: Into<crate::Error> + Send + 'static,
269 {
270 #[cfg(feature = "compression")]
271 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
272 req.headers(),
273 self.send_compression_encodings,
274 );
275
276 let request = t!(self.map_request_streaming(req));
277
278 let response = service
279 .call(request)
280 .await
281 .map(|r| r.map(|m| stream::once(future::ok(m))));
282
283 #[cfg(feature = "compression")]
284 let compression_override = compression_override_from_response(&response);
285
286 self.map_response(
287 response,
288 #[cfg(feature = "compression")]
289 accept_encoding,
290 #[cfg(feature = "compression")]
291 compression_override,
292 )
293 }
294
295 pub async fn streaming<S, B>(
297 &mut self,
298 mut service: S,
299 req: http::Request<B>,
300 ) -> http::Response<BoxBody>
301 where
302 S: StreamingService<T::Decode, Response = T::Encode> + Send,
303 S::ResponseStream: Send + 'static,
304 B: Body + Send + 'static,
305 B::Error: Into<crate::Error> + Send,
306 {
307 #[cfg(feature = "compression")]
308 let accept_encoding = CompressionEncoding::from_accept_encoding_header(
309 req.headers(),
310 self.send_compression_encodings,
311 );
312
313 let request = t!(self.map_request_streaming(req));
314
315 let response = service.call(request).await;
316
317 self.map_response(
318 response,
319 #[cfg(feature = "compression")]
320 accept_encoding,
321 #[cfg(feature = "compression")]
322 SingleMessageCompressionOverride::default(),
323 )
324 }
325
326 async fn map_request_unary<B>(
327 &mut self,
328 request: http::Request<B>,
329 ) -> Result<Request<T::Decode>, Status>
330 where
331 B: Body + Send + 'static,
332 B::Error: Into<crate::Error> + Send,
333 {
334 #[cfg(feature = "compression")]
335 let request_compression_encoding = self.request_encoding_if_supported(&request)?;
336
337 let (parts, body) = request.into_parts();
338
339 #[cfg(feature = "compression")]
340 let stream =
341 Streaming::new_request(self.codec.decoder(), body, request_compression_encoding);
342
343 #[cfg(not(feature = "compression"))]
344 let stream = Streaming::new_request(self.codec.decoder(), body);
345
346 futures_util::pin_mut!(stream);
347
348 let message = stream
349 .try_next()
350 .await?
351 .ok_or_else(|| Status::new(Code::Internal, "Missing request message."))?;
352
353 let mut req = Request::from_http_parts(parts, message);
354
355 if let Some(trailers) = stream.trailers().await? {
356 req.metadata_mut().merge(trailers);
357 }
358
359 Ok(req)
360 }
361
362 fn map_request_streaming<B>(
363 &mut self,
364 request: http::Request<B>,
365 ) -> Result<Request<Streaming<T::Decode>>, Status>
366 where
367 B: Body + Send + 'static,
368 B::Error: Into<crate::Error> + Send,
369 {
370 #[cfg(feature = "compression")]
371 let encoding = self.request_encoding_if_supported(&request)?;
372
373 #[cfg(feature = "compression")]
374 let request =
375 request.map(|body| Streaming::new_request(self.codec.decoder(), body, encoding));
376
377 #[cfg(not(feature = "compression"))]
378 let request = request.map(|body| Streaming::new_request(self.codec.decoder(), body));
379
380 Ok(Request::from_http(request))
381 }
382
383 fn map_response<B>(
384 &mut self,
385 response: Result<crate::Response<B>, Status>,
386 #[cfg(feature = "compression")] accept_encoding: Option<CompressionEncoding>,
387 #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
388 ) -> http::Response<BoxBody>
389 where
390 B: TryStream<Ok = T::Encode, Error = Status> + Send + 'static,
391 {
392 let response = match response {
393 Ok(r) => r,
394 Err(status) => return status.to_http(),
395 };
396
397 let (mut parts, body) = response.into_http().into_parts();
398
399 parts.headers.insert(
401 http::header::CONTENT_TYPE,
402 http::header::HeaderValue::from_static("application/grpc"),
403 );
404
405 #[cfg(feature = "compression")]
406 if let Some(encoding) = accept_encoding {
407 parts.headers.insert(
409 crate::codec::compression::ENCODING_HEADER,
410 encoding.into_header_value(),
411 );
412 }
413
414 let body = encode_server(
415 self.codec.encoder(),
416 body.into_stream(),
417 #[cfg(feature = "compression")]
418 accept_encoding,
419 #[cfg(feature = "compression")]
420 compression_override,
421 );
422
423 http::Response::from_parts(parts, BoxBody::new(body))
424 }
425
426 #[cfg(feature = "compression")]
427 fn request_encoding_if_supported<B>(
428 &self,
429 request: &http::Request<B>,
430 ) -> Result<Option<CompressionEncoding>, Status> {
431 CompressionEncoding::from_encoding_header(
432 request.headers(),
433 self.accept_compression_encodings,
434 )
435 }
436}
437
438impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440 let mut f = f.debug_struct("Grpc");
441
442 f.field("codec", &self.codec);
443
444 #[cfg(feature = "compression")]
445 f.field(
446 "accept_compression_encodings",
447 &self.accept_compression_encodings,
448 );
449
450 #[cfg(feature = "compression")]
451 f.field(
452 "send_compression_encodings",
453 &self.send_compression_encodings,
454 );
455
456 f.finish()
457 }
458}
459
/// Returns the `SingleMessageCompressionOverride` stored in a successful
/// response's extensions.
///
/// Falls back to `SingleMessageCompressionOverride::default()` in two cases:
/// the result is an `Err`, or no override was set.
#[cfg(feature = "compression")]
fn compression_override_from_response<B, E>(
    res: &Result<crate::Response<B>, E>,
) -> SingleMessageCompressionOverride {
    match res {
        Ok(response) => response
            .extensions()
            .get::<SingleMessageCompressionOverride>()
            .copied()
            .unwrap_or_default(),
        Err(_) => SingleMessageCompressionOverride::default(),
    }
}