1#[cfg(feature = "compression")]
2use super::compression::{decompress, CompressionEncoding};
3use super::{DecodeBuf, Decoder, HEADER_SIZE};
4use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
5use bytes::{Buf, BufMut, BytesMut};
6use futures_core::Stream;
7use futures_util::{future, ready};
8use http::StatusCode;
9use http_body::Body;
10use std::{
11 fmt,
12 pin::Pin,
13 task::{Context, Poll},
14};
15use tracing::{debug, trace};
16
/// Initial capacity, in bytes, of the receive buffer allocated by `Streaming::new` (8 KiB).
const BUFFER_SIZE: usize = 8_192;
18
19pub struct Streaming<T> {
24 decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>,
25 body: BoxBody,
26 state: State,
27 direction: Direction,
28 buf: BytesMut,
29 trailers: Option<MetadataMap>,
30 #[cfg(feature = "compression")]
31 decompress_buf: BytesMut,
32 #[cfg(feature = "compression")]
33 encoding: Option<CompressionEncoding>,
34}
35
36impl<T> Unpin for Streaming<T> {}
37
/// Position of the frame parser driven by `Streaming::decode_chunk`.
#[derive(Debug)]
enum State {
    /// Waiting until a full frame header is buffered.
    ReadHeader,
    /// Header consumed; waiting until `len` payload bytes are buffered.
    ReadBody {
        /// Whether the header's compression flag was set.
        compression: bool,
        /// Payload length announced by the header, in bytes.
        len: usize,
    },
}
43
44#[derive(Debug)]
45enum Direction {
46 Request,
47 Response(StatusCode),
48 EmptyResponse,
49}
50
51impl<T> Streaming<T> {
52 pub(crate) fn new_response<B, D>(
53 decoder: D,
54 body: B,
55 status_code: StatusCode,
56 #[cfg(feature = "compression")] encoding: Option<CompressionEncoding>,
57 ) -> Self
58 where
59 B: Body + Send + 'static,
60 B::Error: Into<crate::Error>,
61 D: Decoder<Item = T, Error = Status> + Send + 'static,
62 {
63 Self::new(
64 decoder,
65 body,
66 Direction::Response(status_code),
67 #[cfg(feature = "compression")]
68 encoding,
69 )
70 }
71
72 pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self
73 where
74 B: Body + Send + 'static,
75 B::Error: Into<crate::Error>,
76 D: Decoder<Item = T, Error = Status> + Send + 'static,
77 {
78 Self::new(
79 decoder,
80 body,
81 Direction::EmptyResponse,
82 #[cfg(feature = "compression")]
83 None,
84 )
85 }
86
87 #[doc(hidden)]
88 pub fn new_request<B, D>(
89 decoder: D,
90 body: B,
91 #[cfg(feature = "compression")] encoding: Option<CompressionEncoding>,
92 ) -> Self
93 where
94 B: Body + Send + 'static,
95 B::Error: Into<crate::Error>,
96 D: Decoder<Item = T, Error = Status> + Send + 'static,
97 {
98 Self::new(
99 decoder,
100 body,
101 Direction::Request,
102 #[cfg(feature = "compression")]
103 encoding,
104 )
105 }
106
107 fn new<B, D>(
108 decoder: D,
109 body: B,
110 direction: Direction,
111 #[cfg(feature = "compression")] encoding: Option<CompressionEncoding>,
112 ) -> Self
113 where
114 B: Body + Send + 'static,
115 B::Error: Into<crate::Error>,
116 D: Decoder<Item = T, Error = Status> + Send + 'static,
117 {
118 Self {
119 decoder: Box::new(decoder),
120 body: body
121 .map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
122 .map_err(|err| Status::map_error(err.into()))
123 .boxed_unsync(),
124 state: State::ReadHeader,
125 direction,
126 buf: BytesMut::with_capacity(BUFFER_SIZE),
127 trailers: None,
128 #[cfg(feature = "compression")]
129 decompress_buf: BytesMut::new(),
130 #[cfg(feature = "compression")]
131 encoding,
132 }
133 }
134}
135
136impl<T> Streaming<T> {
137 pub async fn message(&mut self) -> Result<Option<T>, Status> {
152 match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
153 Some(Ok(m)) => Ok(Some(m)),
154 Some(Err(e)) => Err(e),
155 None => Ok(None),
156 }
157 }
158
159 pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
175 if let Some(trailers) = self.trailers.take() {
178 return Ok(Some(trailers));
179 }
180
181 while self.message().await?.is_some() {}
183
184 if let Some(trailers) = self.trailers.take() {
187 return Ok(Some(trailers));
188 }
189
190 let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx))
193 .await
194 .map_err(|e| Status::from_error(Box::new(e)));
195
196 map.map(|x| x.map(MetadataMap::from_headers))
197 }
198
199 fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
200 if let State::ReadHeader = self.state {
201 if self.buf.remaining() < HEADER_SIZE {
202 return Ok(None);
203 }
204
205 let is_compressed = match self.buf.get_u8() {
206 0 => false,
207 1 => {
208 #[cfg(feature = "compression")]
209 {
210 if self.encoding.is_some() {
211 true
212 } else {
213 return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified"));
218 }
219 }
220 #[cfg(not(feature = "compression"))]
221 {
222 return Err(Status::new(
223 Code::Unimplemented,
224 "Message compressed, compression support not enabled.".to_string(),
225 ));
226 }
227 }
228 f => {
229 trace!("unexpected compression flag");
230 let message = if let Direction::Response(status) = self.direction {
231 format!(
232 "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}",
233 f, status
234 )
235 } else {
236 format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f)
237 };
238 return Err(Status::new(Code::Internal, message));
239 }
240 };
241 let len = self.buf.get_u32() as usize;
242 self.buf.reserve(len);
243
244 self.state = State::ReadBody {
245 compression: is_compressed,
246 len,
247 }
248 }
249
250 if let State::ReadBody { len, compression } = &self.state {
251 if self.buf.remaining() < *len || self.buf.len() < *len {
254 return Ok(None);
255 }
256
257 let decoding_result = if *compression {
258 #[cfg(feature = "compression")]
259 {
260 self.decompress_buf.clear();
261
262 if let Err(err) = decompress(
263 self.encoding.unwrap_or_else(|| {
264 unreachable!("message was compressed but `Streaming.encoding` was `None`. This is a bug in Tonic. Please file an issue")
266 }),
267 &mut self.buf,
268 &mut self.decompress_buf,
269 *len,
270 ) {
271 let message = if let Direction::Response(status) = self.direction {
272 format!(
273 "Error decompressing: {}, while receiving response with status: {}",
274 err, status
275 )
276 } else {
277 format!("Error decompressing: {}, while sending request", err)
278 };
279 return Err(Status::new(Code::Internal, message));
280 }
281 let decompressed_len = self.decompress_buf.len();
282 self.decoder.decode(&mut DecodeBuf::new(
283 &mut self.decompress_buf,
284 decompressed_len,
285 ))
286 }
287
288 #[cfg(not(feature = "compression"))]
289 unreachable!("should not take this branch if compression is disabled")
290 } else {
291 self.decoder
292 .decode(&mut DecodeBuf::new(&mut self.buf, *len))
293 };
294
295 return match decoding_result {
296 Ok(Some(msg)) => {
297 self.state = State::ReadHeader;
298 Ok(Some(msg))
299 }
300 Ok(None) => Ok(None),
301 Err(e) => Err(e),
302 };
303 }
304
305 Ok(None)
306 }
307}
308
309impl<T> Stream for Streaming<T> {
310 type Item = Result<T, Status>;
311
312 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
313 loop {
314 if let Some(item) = self.decode_chunk()? {
318 return Poll::Ready(Some(Ok(item)));
319 }
320
321 let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
322 Some(Ok(d)) => Some(d),
323 Some(Err(e)) => {
324 let err: crate::Error = e.into();
325 debug!("decoder inner stream error: {:?}", err);
326 let status = Status::from_error(err);
327 return Poll::Ready(Some(Err(status)));
328 }
329 None => None,
330 };
331
332 if let Some(data) = chunk {
333 self.buf.put(data);
334 } else {
335 if self.buf.has_remaining() {
337 trace!("unexpected EOF decoding stream");
338 return Poll::Ready(Some(Err(Status::new(
339 Code::Internal,
340 "Unexpected EOF decoding stream.".to_string(),
341 ))));
342 } else {
343 break;
344 }
345 }
346 }
347
348 if let Direction::Response(status) = self.direction {
349 match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
350 Ok(trailer) => {
351 if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) {
352 if let Some(e) = e {
353 return Some(Err(e)).into();
354 } else {
355 return Poll::Ready(None);
356 }
357 } else {
358 self.trailers = trailer.map(MetadataMap::from_headers);
359 }
360 }
361 Err(e) => {
362 let err: crate::Error = e.into();
363 debug!("decoder inner trailers error: {:?}", err);
364 let status = Status::from_error(err);
365 return Some(Err(status)).into();
366 }
367 }
368 }
369
370 Poll::Ready(None)
371 }
372}
373
374impl<T> fmt::Debug for Streaming<T> {
375 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376 f.debug_struct("Streaming").finish()
377 }
378}
379
// Compile-time guard (test builds only): `Streaming<()>` must implement `Send`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send);