// tonic/codec/prost.rs
use super::{Codec, DecodeBuf, Decoder, Encoder};
use crate::codec::EncodeBuf;
use crate::{Code, Status};
use prost1::Message;
use std::marker::PhantomData;

/// A [`Codec`] that encodes messages of type `T` and decodes messages of
/// type `U` through prost's [`Message`] trait.
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    // Zero-sized marker that ties the codec to its message types; no value
    // of `T` or `U` is ever stored.
    _pd: PhantomData<(T, U)>,
}

13impl<T, U> Default for ProstCodec<T, U> {
14 fn default() -> Self {
15 Self { _pd: PhantomData }
16 }
17}
18
19impl<T, U> Codec for ProstCodec<T, U>
20where
21 T: Message + Send + 'static,
22 U: Message + Default + Send + 'static,
23{
24 type Encode = T;
25 type Decode = U;
26
27 type Encoder = ProstEncoder<T>;
28 type Decoder = ProstDecoder<U>;
29
30 fn encoder(&mut self) -> Self::Encoder {
31 ProstEncoder(PhantomData)
32 }
33
34 fn decoder(&mut self) -> Self::Decoder {
35 ProstDecoder(PhantomData)
36 }
37}
38
/// An [`Encoder`] that serializes prost messages of type `T`.
#[derive(Debug, Clone, Default)]
pub struct ProstEncoder<T>(PhantomData<T>);

43impl<T: Message> Encoder for ProstEncoder<T> {
44 type Item = T;
45 type Error = Status;
46
47 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
48 item.encode(buf)
49 .expect("Message only errors if not enough space");
50
51 Ok(())
52 }
53}
54
/// A [`Decoder`] that deserializes prost messages of type `U`.
#[derive(Debug, Clone, Default)]
pub struct ProstDecoder<U>(PhantomData<U>);

59impl<U: Message + Default> Decoder for ProstDecoder<U> {
60 type Item = U;
61 type Error = Status;
62
63 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
64 let item = Message::decode(buf)
65 .map(Option::Some)
66 .map_err(from_decode_error)?;
67
68 Ok(item)
69 }
70}
71
72fn from_decode_error(error: prost1::DecodeError) -> crate::Status {
73 Status::new(Code::Internal, error.to_string())
76}
77
// These tests drive the generic framing code (`Streaming`, `encode_server`)
// with raw-byte mock codecs rather than prost messages.
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        encode_server, DecodeBuf, Decoder, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::Status;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;

    /// Payload size, in bytes, of the frame built by `decode`. `MockDecoder`
    /// advances its buffer by exactly this many bytes per message.
    const LEN: usize = 10000;

    /// A single frame delivered as one body chunk decodes to exactly one
    /// message with the original payload length.
    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];

        // Frame layout: a leading 0 byte, the payload length written with
        // `put_u32` (big-endian), then the payload itself.
        let mut buf = BytesMut::new();
        buf.reserve(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(msg.len() as u32);
        buf.put(&msg[..]);

        // The first chunk is the entire frame, so it arrives in one piece.
        let body = body::MockBody::new(&buf[..], buf.len(), 0);

        let mut stream = Streaming::new_request(decoder, body, None);

        let mut i = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            i += 1;
        }
        assert_eq!(i, 1);
    }

    /// Encoding 10,000 copies of a 1 KiB message through `encode_server`
    /// yields only successful body chunks.
    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = Vec::from(&[0u8; 1024][..]);

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = futures_util::stream::iter(messages);

        let body = encode_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
        );

        futures_util::pin_mut!(body);

        while let Some(r) = body.data().await {
            r.unwrap();
        }
    }

    /// Encoder that writes the raw bytes of each item unchanged.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder;

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(&item[..]);
            Ok(())
        }
    }

    /// Decoder that copies the currently available chunk and then consumes
    /// `LEN` bytes, i.e. it assumes every message is exactly `LEN` bytes.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder;

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            let out = Vec::from(buf.chunk());
            buf.advance(LEN);
            Ok(Some(out))
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::Body;
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// Test body that hands out `data` in chunks and alternates between
        /// sending (even `count`) and returning `Pending` (odd `count`).
        #[derive(Debug)]
        pub(super) struct MockBody {
            /// Bytes not yet handed out.
            data: Bytes,

            /// Length of the chunk sent when `count` is 0; later sends yield
            /// everything that remains.
            partial_len: usize,

            /// Number of `poll_data` calls made while data remained.
            count: usize,
        }

        impl MockBody {
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_data(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
                if self.data.is_empty() {
                    return Poll::Ready(None);
                }

                let count = self.count;
                self.count += 1;

                if count % 2 != 0 {
                    // Simulate a body that is not ready yet, but ask to be
                    // polled again immediately.
                    cx.waker().wake_by_ref();
                    return Poll::Pending;
                }

                // Panics (via `split_to`) if `partial_len` exceeds the data;
                // the tests always pass a length within bounds.
                let chunk_len = if count == 0 {
                    self.partial_len
                } else {
                    self.data.len()
                };
                Poll::Ready(Some(Ok(self.data.split_to(chunk_len))))
            }

            fn poll_trailers(
                self: Pin<&mut Self>,
                _cx: &mut Context<'_>,
            ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
                Poll::Ready(Ok(None))
            }
        }
    }
}