tonic/codec/
prost.rs

1use super::{Codec, DecodeBuf, Decoder, Encoder};
2use crate::codec::EncodeBuf;
3use crate::{Code, Status};
4use prost1::Message;
5use std::marker::PhantomData;
6
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type that gets encoded and `U` is the message type that
/// gets decoded. The codec stores no data of its own; both type parameters are
/// carried by a [`PhantomData`] marker only.
#[derive(Debug)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

// `Clone` is written by hand because `#[derive(Clone)]` would add `T: Clone`
// and `U: Clone` bounds, even though only a `PhantomData` marker is copied.
impl<T, U> Clone for ProstCodec<T, U> {
    fn clone(&self) -> Self {
        Self { _pd: PhantomData }
    }
}
12
13impl<T, U> Default for ProstCodec<T, U> {
14    fn default() -> Self {
15        Self { _pd: PhantomData }
16    }
17}
18
19impl<T, U> Codec for ProstCodec<T, U>
20where
21    T: Message + Send + 'static,
22    U: Message + Default + Send + 'static,
23{
24    type Encode = T;
25    type Decode = U;
26
27    type Encoder = ProstEncoder<T>;
28    type Decoder = ProstDecoder<U>;
29
30    fn encoder(&mut self) -> Self::Encoder {
31        ProstEncoder(PhantomData)
32    }
33
34    fn decoder(&mut self) -> Self::Decoder {
35        ProstDecoder(PhantomData)
36    }
37}
38
/// An [`Encoder`] that knows how to encode `T`.
///
/// The encoder is stateless; `T` is carried by a [`PhantomData`] marker only.
#[derive(Debug)]
pub struct ProstEncoder<T>(PhantomData<T>);

// `Clone` and `Default` are implemented by hand: deriving them would require
// `T: Clone` / `T: Default`, although only a `PhantomData` marker is stored.
impl<T> Clone for ProstEncoder<T> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}

impl<T> Default for ProstEncoder<T> {
    fn default() -> Self {
        Self(PhantomData)
    }
}
42
43impl<T: Message> Encoder for ProstEncoder<T> {
44    type Item = T;
45    type Error = Status;
46
47    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
48        item.encode(buf)
49            .expect("Message only errors if not enough space");
50
51        Ok(())
52    }
53}
54
/// A [`Decoder`] that knows how to decode `U`.
///
/// The decoder is stateless; `U` is carried by a [`PhantomData`] marker only.
#[derive(Debug)]
pub struct ProstDecoder<U>(PhantomData<U>);

// `Clone` and `Default` are implemented by hand: deriving them would require
// `U: Clone` / `U: Default`, although only a `PhantomData` marker is stored.
impl<U> Clone for ProstDecoder<U> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}

impl<U> Default for ProstDecoder<U> {
    fn default() -> Self {
        Self(PhantomData)
    }
}
58
59impl<U: Message + Default> Decoder for ProstDecoder<U> {
60    type Item = U;
61    type Error = Status;
62
63    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
64        let item = Message::decode(buf)
65            .map(Option::Some)
66            .map_err(from_decode_error)?;
67
68        Ok(item)
69    }
70}
71
72fn from_decode_error(error: prost1::DecodeError) -> crate::Status {
73    // Map Protobuf parse errors to an INTERNAL status code, as per
74    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
75    Status::new(Code::Internal, error.to_string())
76}
77
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::Status;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;

    /// Payload size, in bytes, of the single message used by the `decode` test.
    const LEN: usize = 10000;

    /// Frames one `LEN`-byte message, feeds it through `Streaming`, and checks
    /// that exactly one message of the same length comes out.
    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];

        // Frame written below: one zero byte, the payload length as a `u32`,
        // then the payload itself.
        let mut buf = BytesMut::with_capacity(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(msg.len() as u32);
        buf.put(&msg[..]);

        // The first chunk handed out by the body is the entire frame.
        let body = body::MockBody::new(&buf[..], buf.len(), 0);

        let mut stream = Streaming::new_request(decoder, body, None);

        let mut seen = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            seen += 1;
        }
        assert_eq!(seen, 1);
    }

    /// Pushes 10000 messages of 1024 bytes each through `encode_server` and
    /// checks that every produced data frame is `Ok`.
    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; 1024];

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = futures_util::stream::iter(messages);

        let body = encode_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
        );

        futures_util::pin_mut!(body);

        while let Some(frame) = body.data().await {
            frame.unwrap();
        }
    }

    /// Encoder that copies the raw bytes of each item into the output buffer.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder;

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(item.as_slice());
            Ok(())
        }
    }

    /// Decoder that returns the current chunk as bytes and consumes `LEN` bytes.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder;

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            let out = buf.chunk().to_vec();
            buf.advance(LEN);
            Ok(Some(out))
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::Body;
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// Test body that hands out its data on every other poll and reports
        /// `Pending` on the polls in between.
        #[derive(Debug)]
        pub(super) struct MockBody {
            data: Bytes,

            // the size of the partial message to send
            partial_len: usize,

            // the number of times we've sent
            count: usize,
        }

        impl MockBody {
            /// Copies `b` into the body; `partial_len` is the size of the chunk
            /// returned while `count` is still zero.
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_data(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
                // Nothing left to send: the body is finished.
                if self.data.is_empty() {
                    return Poll::Ready(None);
                }

                // every other call to poll_data returns data
                let result = if self.count % 2 == 0 {
                    // While `count` is zero only `partial_len` bytes go out;
                    // afterwards everything that remains is sent.
                    let split_at = if self.count == 0 {
                        self.partial_len
                    } else {
                        self.data.len()
                    };
                    Poll::Ready(Some(Ok(self.data.split_to(split_at))))
                } else {
                    // Ask to be polled again right away.
                    cx.waker().wake_by_ref();
                    Poll::Pending
                };

                // make some fake progress
                self.count += 1;
                result
            }

            fn poll_trailers(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
                // This mock never produces trailers.
                Poll::Ready(Ok(None))
            }
        }
    }
}