tonic/codec/
decode.rs

1#[cfg(feature = "compression")]
2use super::compression::{decompress, CompressionEncoding};
3use super::{DecodeBuf, Decoder, HEADER_SIZE};
4use crate::{body::BoxBody, metadata::MetadataMap, Code, Status};
5use bytes::{Buf, BufMut, BytesMut};
6use futures_core::Stream;
7use futures_util::{future, ready};
8use http::StatusCode;
9use http_body::Body;
10use std::{
11    fmt,
12    pin::Pin,
13    task::{Context, Poll},
14};
15use tracing::{debug, trace};
16
/// Initial capacity, in bytes, of the buffer in which incoming body chunks
/// are accumulated before being decoded (8 KiB).
const BUFFER_SIZE: usize = 8 * 1024;
18
19/// Streaming requests and responses.
20///
21/// This will wrap some inner [`Body`] and [`Decoder`] and provide an interface
22/// to fetch the message stream and trailing metadata
23pub struct Streaming<T> {
24    decoder: Box<dyn Decoder<Item = T, Error = Status> + Send + 'static>,
25    body: BoxBody,
26    state: State,
27    direction: Direction,
28    buf: BytesMut,
29    trailers: Option<MetadataMap>,
30    #[cfg(feature = "compression")]
31    decompress_buf: BytesMut,
32    #[cfg(feature = "compression")]
33    encoding: Option<CompressionEncoding>,
34}
35
36impl<T> Unpin for Streaming<T> {}
37
/// Position of the frame parser within the incoming byte stream.
#[derive(Debug)]
enum State {
    /// Waiting for a complete frame header (compression flag and length).
    ReadHeader,
    /// Header parsed; waiting for `len` payload bytes. `compression` records
    /// whether the header flagged the payload as compressed.
    ReadBody { compression: bool, len: usize },
}
43
44#[derive(Debug)]
45enum Direction {
46    Request,
47    Response(StatusCode),
48    EmptyResponse,
49}
50
51impl<T> Streaming<T> {
52    pub(crate) fn new_response<B, D>(
53        decoder: D,
54        body: B,
55        status_code: StatusCode,
56        #[cfg(feature = "compression")] encoding: Option<CompressionEncoding>,
57    ) -> Self
58    where
59        B: Body + Send + 'static,
60        B::Error: Into<crate::Error>,
61        D: Decoder<Item = T, Error = Status> + Send + 'static,
62    {
63        Self::new(
64            decoder,
65            body,
66            Direction::Response(status_code),
67            #[cfg(feature = "compression")]
68            encoding,
69        )
70    }
71
72    pub(crate) fn new_empty<B, D>(decoder: D, body: B) -> Self
73    where
74        B: Body + Send + 'static,
75        B::Error: Into<crate::Error>,
76        D: Decoder<Item = T, Error = Status> + Send + 'static,
77    {
78        Self::new(
79            decoder,
80            body,
81            Direction::EmptyResponse,
82            #[cfg(feature = "compression")]
83            None,
84        )
85    }
86
87    #[doc(hidden)]
88    pub fn new_request<B, D>(
89        decoder: D,
90        body: B,
91        #[cfg(feature = "compression")] encoding: Option<CompressionEncoding>,
92    ) -> Self
93    where
94        B: Body + Send + 'static,
95        B::Error: Into<crate::Error>,
96        D: Decoder<Item = T, Error = Status> + Send + 'static,
97    {
98        Self::new(
99            decoder,
100            body,
101            Direction::Request,
102            #[cfg(feature = "compression")]
103            encoding,
104        )
105    }
106
107    fn new<B, D>(
108        decoder: D,
109        body: B,
110        direction: Direction,
111        #[cfg(feature = "compression")] encoding: Option<CompressionEncoding>,
112    ) -> Self
113    where
114        B: Body + Send + 'static,
115        B::Error: Into<crate::Error>,
116        D: Decoder<Item = T, Error = Status> + Send + 'static,
117    {
118        Self {
119            decoder: Box::new(decoder),
120            body: body
121                .map_data(|mut buf| buf.copy_to_bytes(buf.remaining()))
122                .map_err(|err| Status::map_error(err.into()))
123                .boxed_unsync(),
124            state: State::ReadHeader,
125            direction,
126            buf: BytesMut::with_capacity(BUFFER_SIZE),
127            trailers: None,
128            #[cfg(feature = "compression")]
129            decompress_buf: BytesMut::new(),
130            #[cfg(feature = "compression")]
131            encoding,
132        }
133    }
134}
135
136impl<T> Streaming<T> {
137    /// Fetch the next message from this stream.
138    /// ```rust
139    /// # use tonic::{Streaming, Status, codec::Decoder};
140    /// # use std::fmt::Debug;
141    /// # async fn next_message_ex<T, D>(mut request: Streaming<T>) -> Result<(), Status>
142    /// # where T: Debug,
143    /// # D: Decoder<Item = T, Error = Status> + Send  + 'static,
144    /// # {
145    /// if let Some(next_message) = request.message().await? {
146    ///     println!("{:?}", next_message);
147    /// }
148    /// # Ok(())
149    /// # }
150    /// ```
151    pub async fn message(&mut self) -> Result<Option<T>, Status> {
152        match future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await {
153            Some(Ok(m)) => Ok(Some(m)),
154            Some(Err(e)) => Err(e),
155            None => Ok(None),
156        }
157    }
158
159    /// Fetch the trailing metadata.
160    ///
161    /// This will drain the stream of all its messages to receive the trailing
162    /// metadata. If [`Streaming::message`] returns `None` then this function
163    /// will not need to poll for trailers since the body was totally consumed.
164    ///
165    /// ```rust
166    /// # use tonic::{Streaming, Status};
167    /// # async fn trailers_ex<T>(mut request: Streaming<T>) -> Result<(), Status> {
168    /// if let Some(metadata) = request.trailers().await? {
169    ///     println!("{:?}", metadata);
170    /// }
171    /// # Ok(())
172    /// # }
173    /// ```
174    pub async fn trailers(&mut self) -> Result<Option<MetadataMap>, Status> {
175        // Shortcut to see if we already pulled the trailers in the stream step
176        // we need to do that so that the stream can error on trailing grpc-status
177        if let Some(trailers) = self.trailers.take() {
178            return Ok(Some(trailers));
179        }
180
181        // To fetch the trailers we must clear the body and drop it.
182        while self.message().await?.is_some() {}
183
184        // Since we call poll_trailers internally on poll_next we need to
185        // check if it got cached again.
186        if let Some(trailers) = self.trailers.take() {
187            return Ok(Some(trailers));
188        }
189
190        // Trailers were not caught during poll_next and thus lets poll for
191        // them manually.
192        let map = future::poll_fn(|cx| Pin::new(&mut self.body).poll_trailers(cx))
193            .await
194            .map_err(|e| Status::from_error(Box::new(e)));
195
196        map.map(|x| x.map(MetadataMap::from_headers))
197    }
198
199    fn decode_chunk(&mut self) -> Result<Option<T>, Status> {
200        if let State::ReadHeader = self.state {
201            if self.buf.remaining() < HEADER_SIZE {
202                return Ok(None);
203            }
204
205            let is_compressed = match self.buf.get_u8() {
206                0 => false,
207                1 => {
208                    #[cfg(feature = "compression")]
209                    {
210                        if self.encoding.is_some() {
211                            true
212                        } else {
213                            // https://grpc.github.io/grpc/core/md_doc_compression.html
214                            // An ill-constructed message with its Compressed-Flag bit set but lacking a grpc-encoding
215                            // entry different from identity in its metadata MUST fail with INTERNAL status,
216                            // its associated description indicating the invalid Compressed-Flag condition.
217                            return Err(Status::new(Code::Internal, "protocol error: received message with compressed-flag but no grpc-encoding was specified"));
218                        }
219                    }
220                    #[cfg(not(feature = "compression"))]
221                    {
222                        return Err(Status::new(
223                            Code::Unimplemented,
224                            "Message compressed, compression support not enabled.".to_string(),
225                        ));
226                    }
227                }
228                f => {
229                    trace!("unexpected compression flag");
230                    let message = if let Direction::Response(status) = self.direction {
231                        format!(
232                            "protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1) while receiving response with status: {}",
233                            f, status
234                        )
235                    } else {
236                        format!("protocol error: received message with invalid compression flag: {} (valid flags are 0 and 1), while sending request", f)
237                    };
238                    return Err(Status::new(Code::Internal, message));
239                }
240            };
241            let len = self.buf.get_u32() as usize;
242            self.buf.reserve(len);
243
244            self.state = State::ReadBody {
245                compression: is_compressed,
246                len,
247            }
248        }
249
250        if let State::ReadBody { len, compression } = &self.state {
251            // if we haven't read enough of the message then return and keep
252            // reading
253            if self.buf.remaining() < *len || self.buf.len() < *len {
254                return Ok(None);
255            }
256
257            let decoding_result = if *compression {
258                #[cfg(feature = "compression")]
259                {
260                    self.decompress_buf.clear();
261
262                    if let Err(err) = decompress(
263                        self.encoding.unwrap_or_else(|| {
264                            // SAFETY: The check while in State::ReadHeader would already have returned Code::Internal
265                            unreachable!("message was compressed but `Streaming.encoding` was `None`. This is a bug in Tonic. Please file an issue")
266                        }),
267                        &mut self.buf,
268                        &mut self.decompress_buf,
269                        *len,
270                    ) {
271                        let message = if let Direction::Response(status) = self.direction {
272                            format!(
273                                "Error decompressing: {}, while receiving response with status: {}",
274                                err, status
275                            )
276                        } else {
277                            format!("Error decompressing: {}, while sending request", err)
278                        };
279                        return Err(Status::new(Code::Internal, message));
280                    }
281                    let decompressed_len = self.decompress_buf.len();
282                    self.decoder.decode(&mut DecodeBuf::new(
283                        &mut self.decompress_buf,
284                        decompressed_len,
285                    ))
286                }
287
288                #[cfg(not(feature = "compression"))]
289                unreachable!("should not take this branch if compression is disabled")
290            } else {
291                self.decoder
292                    .decode(&mut DecodeBuf::new(&mut self.buf, *len))
293            };
294
295            return match decoding_result {
296                Ok(Some(msg)) => {
297                    self.state = State::ReadHeader;
298                    Ok(Some(msg))
299                }
300                Ok(None) => Ok(None),
301                Err(e) => Err(e),
302            };
303        }
304
305        Ok(None)
306    }
307}
308
309impl<T> Stream for Streaming<T> {
310    type Item = Result<T, Status>;
311
312    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
313        loop {
314            // FIXME: implement the ability to poll trailers when we _know_ that
315            // the consumer of this stream will only poll for the first message.
316            // This means we skip the poll_trailers step.
317            if let Some(item) = self.decode_chunk()? {
318                return Poll::Ready(Some(Ok(item)));
319            }
320
321            let chunk = match ready!(Pin::new(&mut self.body).poll_data(cx)) {
322                Some(Ok(d)) => Some(d),
323                Some(Err(e)) => {
324                    let err: crate::Error = e.into();
325                    debug!("decoder inner stream error: {:?}", err);
326                    let status = Status::from_error(err);
327                    return Poll::Ready(Some(Err(status)));
328                }
329                None => None,
330            };
331
332            if let Some(data) = chunk {
333                self.buf.put(data);
334            } else {
335                // FIXME: improve buf usage.
336                if self.buf.has_remaining() {
337                    trace!("unexpected EOF decoding stream");
338                    return Poll::Ready(Some(Err(Status::new(
339                        Code::Internal,
340                        "Unexpected EOF decoding stream.".to_string(),
341                    ))));
342                } else {
343                    break;
344                }
345            }
346        }
347
348        if let Direction::Response(status) = self.direction {
349            match ready!(Pin::new(&mut self.body).poll_trailers(cx)) {
350                Ok(trailer) => {
351                    if let Err(e) = crate::status::infer_grpc_status(trailer.as_ref(), status) {
352                        if let Some(e) = e {
353                            return Some(Err(e)).into();
354                        } else {
355                            return Poll::Ready(None);
356                        }
357                    } else {
358                        self.trailers = trailer.map(MetadataMap::from_headers);
359                    }
360                }
361                Err(e) => {
362                    let err: crate::Error = e.into();
363                    debug!("decoder inner trailers error: {:?}", err);
364                    let status = Status::from_error(err);
365                    return Some(Err(status)).into();
366                }
367            }
368        }
369
370        Poll::Ready(None)
371    }
372}
373
374impl<T> fmt::Debug for Streaming<T> {
375    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
376        f.debug_struct("Streaming").finish()
377    }
378}
379
// Compile-time check, in test builds only, that `Streaming` implements `Send`.
#[cfg(test)]
static_assertions::assert_impl_all!(Streaming<()>: Send);