tonic/codec/encode.rs

1#[cfg(feature = "compression")]
2use super::compression::{compress, CompressionEncoding, SingleMessageCompressionOverride};
3use super::{EncodeBuf, Encoder, HEADER_SIZE};
4use crate::{Code, Status};
5use bytes::{BufMut, Bytes, BytesMut};
6use futures_core::{Stream, TryStream};
7use futures_util::{ready, StreamExt, TryStreamExt};
8use http::HeaderMap;
9use http_body::Body;
10use pin_project::pin_project;
11use std::{
12    pin::Pin,
13    task::{Context, Poll},
14};
15
16pub(super) const BUFFER_SIZE: usize = 8 * 1024;
17
18pub(crate) fn encode_server<T, U>(
19    encoder: T,
20    source: U,
21    #[cfg(feature = "compression")] compression_encoding: Option<CompressionEncoding>,
22    #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
23) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
24where
25    T: Encoder<Error = Status>,
26    U: Stream<Item = Result<T::Item, Status>>,
27{
28    let stream = encode(
29        encoder,
30        source,
31        #[cfg(feature = "compression")]
32        compression_encoding,
33        #[cfg(feature = "compression")]
34        compression_override,
35    )
36    .into_stream();
37
38    EncodeBody::new_server(stream)
39}
40
41pub(crate) fn encode_client<T, U>(
42    encoder: T,
43    source: U,
44    #[cfg(feature = "compression")] compression_encoding: Option<CompressionEncoding>,
45) -> EncodeBody<impl Stream<Item = Result<Bytes, Status>>>
46where
47    T: Encoder<Error = Status>,
48    U: Stream<Item = T::Item>,
49{
50    let stream = encode(
51        encoder,
52        source.map(Ok),
53        #[cfg(feature = "compression")]
54        compression_encoding,
55        #[cfg(feature = "compression")]
56        SingleMessageCompressionOverride::default(),
57    )
58    .into_stream();
59    EncodeBody::new_client(stream)
60}
61
62fn encode<T, U>(
63    mut encoder: T,
64    source: U,
65    #[cfg(feature = "compression")] compression_encoding: Option<CompressionEncoding>,
66    #[cfg(feature = "compression")] compression_override: SingleMessageCompressionOverride,
67) -> impl TryStream<Ok = Bytes, Error = Status>
68where
69    T: Encoder<Error = Status>,
70    U: Stream<Item = Result<T::Item, Status>>,
71{
72    async_stream::stream! {
73        let mut buf = BytesMut::with_capacity(BUFFER_SIZE);
74
75        #[cfg(feature = "compression")]
76        let (compression_enabled_for_stream, mut uncompression_buf) = match compression_encoding {
77            Some(CompressionEncoding::Gzip) => (true, BytesMut::with_capacity(BUFFER_SIZE)),
78            None => (false, BytesMut::new()),
79        };
80
81        #[cfg(feature = "compression")]
82        let compress_item = compression_enabled_for_stream && compression_override == SingleMessageCompressionOverride::Inherit;
83
84        #[cfg(not(feature = "compression"))]
85        let compress_item = false;
86
87        futures_util::pin_mut!(source);
88
89        loop {
90            match source.next().await {
91                Some(Ok(item)) => {
92                    buf.reserve(HEADER_SIZE);
93                    unsafe {
94                        buf.advance_mut(HEADER_SIZE);
95                    }
96
97                    if compress_item {
98                        #[cfg(feature = "compression")]
99                        {
100                            uncompression_buf.clear();
101
102                            encoder.encode(item, &mut EncodeBuf::new(&mut uncompression_buf))
103                                .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
104
105                            let uncompressed_len = uncompression_buf.len();
106
107                            compress(
108                                compression_encoding.unwrap(),
109                                &mut uncompression_buf,
110                                &mut buf,
111                                uncompressed_len,
112                            ).map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
113                        }
114
115                        #[cfg(not(feature = "compression"))]
116                        unreachable!("compression disabled, should not take this branch");
117                    } else {
118                        encoder.encode(item, &mut EncodeBuf::new(&mut buf))
119                            .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
120                    }
121
122                    // now that we know length, we can write the header
123                    let len = buf.len() - HEADER_SIZE;
124                    assert!(len <= std::u32::MAX as usize);
125                    {
126                        let mut buf = &mut buf[..HEADER_SIZE];
127                        buf.put_u8(compress_item as u8);
128                        buf.put_u32(len as u32);
129                    }
130
131                    yield Ok(buf.split_to(len + HEADER_SIZE).freeze());
132                },
133                Some(Err(status)) => yield Err(status),
134                None => break,
135            }
136        }
137    }
138}
139
/// Which side of a call an `EncodeBody` encodes for. This decides how errors
/// from the underlying frame stream are surfaced.
#[derive(Debug)]
enum Role {
    /// Stream errors are returned directly from `poll_data`.
    Client,
    /// Stream errors end the data and are delivered through the trailers.
    Server,
}
145
146#[pin_project]
147#[derive(Debug)]
148pub(crate) struct EncodeBody<S> {
149    #[pin]
150    inner: S,
151    error: Option<Status>,
152    role: Role,
153    is_end_stream: bool,
154}
155
156impl<S> EncodeBody<S>
157where
158    S: Stream<Item = Result<Bytes, Status>>,
159{
160    pub(crate) fn new_client(inner: S) -> Self {
161        Self {
162            inner,
163            error: None,
164            role: Role::Client,
165            is_end_stream: false,
166        }
167    }
168
169    pub(crate) fn new_server(inner: S) -> Self {
170        Self {
171            inner,
172            error: None,
173            role: Role::Server,
174            is_end_stream: false,
175        }
176    }
177}
178
179impl<S> Body for EncodeBody<S>
180where
181    S: Stream<Item = Result<Bytes, Status>>,
182{
183    type Data = Bytes;
184    type Error = Status;
185
186    fn is_end_stream(&self) -> bool {
187        self.is_end_stream
188    }
189
190    fn poll_data(
191        self: Pin<&mut Self>,
192        cx: &mut Context<'_>,
193    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
194        let mut self_proj = self.project();
195        match ready!(self_proj.inner.try_poll_next_unpin(cx)) {
196            Some(Ok(d)) => Some(Ok(d)).into(),
197            Some(Err(status)) => match self_proj.role {
198                Role::Client => Some(Err(status)).into(),
199                Role::Server => {
200                    *self_proj.error = Some(status);
201                    None.into()
202                }
203            },
204            None => None.into(),
205        }
206    }
207
208    fn poll_trailers(
209        self: Pin<&mut Self>,
210        _cx: &mut Context<'_>,
211    ) -> Poll<Result<Option<HeaderMap>, Status>> {
212        match self.role {
213            Role::Client => Poll::Ready(Ok(None)),
214            Role::Server => {
215                let self_proj = self.project();
216
217                if *self_proj.is_end_stream {
218                    return Poll::Ready(Ok(None));
219                }
220
221                let status = if let Some(status) = self_proj.error.take() {
222                    *self_proj.is_end_stream = true;
223                    status
224                } else {
225                    Status::new(Code::Ok, "")
226                };
227
228                Poll::Ready(Ok(Some(status.to_header_map()?)))
229            }
230        }
231    }
232}