prost/
encoding.rs

1//! Utility functions and types for encoding and decoding Protobuf types.
2//!
3//! Meant to be used only from `Message` implementations.
4
5#![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6
7use alloc::collections::BTreeMap;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::cmp::min;
12use core::convert::TryFrom;
13use core::mem;
14use core::str;
15use core::u32;
16use core::usize;
17
18use ::bytes::{Buf, BufMut, Bytes};
19
20use crate::DecodeError;
21use crate::Message;
22
23/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
24/// The buffer must have enough remaining space (maximum 10 bytes).
25#[inline]
26pub fn encode_varint<B>(mut value: u64, buf: &mut B)
27where
28    B: BufMut,
29{
30    loop {
31        if value < 0x80 {
32            buf.put_u8(value as u8);
33            break;
34        } else {
35            buf.put_u8(((value & 0x7F) | 0x80) as u8);
36            value >>= 7;
37        }
38    }
39}
40
41/// Decodes a LEB128-encoded variable length integer from the buffer.
42#[inline]
43pub fn decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError>
44where
45    B: Buf,
46{
47    let bytes = buf.chunk();
48    let len = bytes.len();
49    if len == 0 {
50        return Err(DecodeError::new("invalid varint"));
51    }
52
53    let byte = bytes[0];
54    if byte < 0x80 {
55        buf.advance(1);
56        Ok(u64::from(byte))
57    } else if len > 10 || bytes[len - 1] < 0x80 {
58        let (value, advance) = decode_varint_slice(bytes)?;
59        buf.advance(advance);
60        Ok(value)
61    } else {
62        decode_varint_slow(buf)
63    }
64}
65
/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
/// number of bytes read.
///
/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
/// [`ConsumeVarint`][2].
///
/// ## Safety
///
/// The caller must ensure that `bytes` is non-empty and either `bytes.len() > 10` or the last
/// element in `bytes` is < `0x80`. Both conditions are checked by the `assert!`s at the top of
/// the function, which panic if they are violated. Together they keep every `get_unchecked`
/// read below in bounds:
/// - a byte at index `i` is only read after bytes `0..i` all had their continuation bit set;
/// - reading stops at index 9 at the latest.
///
/// ## Errors
///
/// Returns an "invalid varint" error if the tenth byte is `>= 0x02`. Either that byte still
/// has its continuation bit set (the varint is longer than 10 bytes), or it carries bits that
/// do not fit in a `u64`.
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.

    // Use assertions to ensure memory safety, but it should always be optimized after inline.
    // SAFETY (for the `get_unchecked` calls below): these assertions establish the
    // precondition documented above.
    assert!(!bytes.is_empty());
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    // Bytes 0..=3 accumulate payload bits 0..28 into `part0`. Each byte is added whole; when it
    // turns out to be a continuation byte, its 0x80 flag is subtracted back out before the next
    // byte is added.
    let mut b: u8;
    let mut part0: u32;
    b = unsafe { *bytes.get_unchecked(0) };
    part0 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    };
    part0 -= 0x80;
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    };
    part0 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    };
    part0 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    };
    part0 -= 0x80 << 21;
    let value = u64::from(part0);

    // Bytes 4..=7 accumulate the next 28 payload bits into `part1`, placed at bit 28 of the
    // result.
    let mut part1: u32;
    b = unsafe { *bytes.get_unchecked(4) };
    part1 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    };
    part1 -= 0x80;
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    };
    part1 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    };
    part1 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    };
    part1 -= 0x80 << 21;
    let value = value + ((u64::from(part1)) << 28);

    // Bytes 8..=9 supply the remaining high bits in `part2`, placed at bit 56 of the result.
    let mut part2: u32;
    b = unsafe { *bytes.get_unchecked(8) };
    part2 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part2) << 56), 9));
    };
    part2 -= 0x80;
    b = unsafe { *bytes.get_unchecked(9) };
    part2 += u32::from(b) << 7;
    // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
    // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
    if b < 0x02 {
        return Ok((value + (u64::from(part2) << 56), 10));
    };

    // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
    // Assume the data is corrupt.
    Err(DecodeError::new("invalid varint"))
}
161
162/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
163/// necessary.
164///
165/// Contains a varint overflow check from [`ConsumeVarint`][1].
166///
167/// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
168#[inline(never)]
169#[cold]
170fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
171where
172    B: Buf,
173{
174    let mut value = 0;
175    for count in 0..min(10, buf.remaining()) {
176        let byte = buf.get_u8();
177        value |= u64::from(byte & 0x7F) << (count * 7);
178        if byte <= 0x7F {
179            // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
180            // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
181            if count == 9 && byte >= 0x02 {
182                return Err(DecodeError::new("invalid varint"));
183            } else {
184                return Ok(value);
185            }
186        }
187    }
188
189    Err(DecodeError::new("invalid varint"))
190}
191
/// Additional information passed to every decode/merge function.
///
/// Contexts are cheap to clone and are passed by value. Before decoding a nested
/// object, derive the child's context with `enter_recursion`.
#[derive(Debug, Clone)]
pub struct DecodeContext {
    /// Number of further nested decode levels allowed before the recursion limit is hit.
    ///
    /// The default value comes from `RECURSION_LIMIT` and cannot be customized. The field,
    /// and with it the limit, is compiled out when the crate is built with the
    /// `no-recursion-limit` feature.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
207
208impl Default for DecodeContext {
209    #[cfg(not(feature = "no-recursion-limit"))]
210    #[inline]
211    fn default() -> DecodeContext {
212        DecodeContext {
213            recurse_count: crate::RECURSION_LIMIT,
214        }
215    }
216
217    #[cfg(feature = "no-recursion-limit")]
218    #[inline]
219    fn default() -> DecodeContext {
220        DecodeContext {}
221    }
222}
223
224impl DecodeContext {
225    /// Call this function before recursively decoding.
226    ///
227    /// There is no `exit` function since this function creates a new `DecodeContext`
228    /// to be used at the next level of recursion. Continue to use the old context
229    // at the previous level of recursion.
230    #[cfg(not(feature = "no-recursion-limit"))]
231    #[inline]
232    pub(crate) fn enter_recursion(&self) -> DecodeContext {
233        DecodeContext {
234            recurse_count: self.recurse_count - 1,
235        }
236    }
237
238    #[cfg(feature = "no-recursion-limit")]
239    #[inline]
240    pub(crate) fn enter_recursion(&self) -> DecodeContext {
241        DecodeContext {}
242    }
243
244    /// Checks whether the recursion limit has been reached in the stack of
245    /// decodes described by the `DecodeContext` at `self.ctx`.
246    ///
247    /// Returns `Ok<()>` if it is ok to continue recursing.
248    /// Returns `Err<DecodeError>` if the recursion limit has been reached.
249    #[cfg(not(feature = "no-recursion-limit"))]
250    #[inline]
251    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
252        if self.recurse_count == 0 {
253            Err(DecodeError::new("recursion limit reached"))
254        } else {
255            Ok(())
256        }
257    }
258
259    #[cfg(feature = "no-recursion-limit")]
260    #[inline]
261    #[allow(clippy::unnecessary_wraps)] // needed in other features
262    pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
263        Ok(())
264    }
265}
266
/// Returns the encoded length of the value in LEB128 variable length format.
/// The returned value will be between 1 and 10, inclusive.
///
/// Each varint byte carries seven payload bits, so the length is the number of significant
/// bits rounded up to a multiple of seven. This matches protobuf's [VarintSize64][1].
///
/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // `| 1` makes zero count as one significant bit: it still occupies a single byte.
    let significant_bits = 64 - (value | 1).leading_zeros();
    ((significant_bits + 6) / 7) as usize
}
275
/// The wire type of an encoded Protobuf field, carried in the low three bits of a field key.
///
/// Equality is total, so the type also implements `Eq` and `Hash` and can be used as a
/// map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum WireType {
    /// A LEB128 variable-length integer.
    Varint = 0,
    /// A fixed 8-byte value.
    SixtyFourBit = 1,
    /// A varint length prefix followed by that many bytes (e.g. strings, bytes, packed
    /// repeated fields).
    LengthDelimited = 2,
    /// Opens a group; closed by a matching `EndGroup` key with the same tag.
    StartGroup = 3,
    /// Closes a group opened by `StartGroup`.
    EndGroup = 4,
    /// A fixed 4-byte value.
    ThirtyTwoBit = 5,
}
286
/// Smallest valid field tag; `decode_key` rejects a tag of 0.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag, `2^29 - 1`. Tags are limited to 29 bits so that the key,
/// `tag << 3` plus the wire type, fits in a `u32`.
pub const MAX_TAG: u32 = 0x1FFF_FFFF;
289
290impl TryFrom<u64> for WireType {
291    type Error = DecodeError;
292
293    #[inline]
294    fn try_from(value: u64) -> Result<Self, Self::Error> {
295        match value {
296            0 => Ok(WireType::Varint),
297            1 => Ok(WireType::SixtyFourBit),
298            2 => Ok(WireType::LengthDelimited),
299            3 => Ok(WireType::StartGroup),
300            4 => Ok(WireType::EndGroup),
301            5 => Ok(WireType::ThirtyTwoBit),
302            _ => Err(DecodeError::new(format!(
303                "invalid wire type value: {}",
304                value
305            ))),
306        }
307    }
308}
309
310/// Encodes a Protobuf field key, which consists of a wire type designator and
311/// the field tag.
312#[inline]
313pub fn encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B)
314where
315    B: BufMut,
316{
317    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
318    let key = (tag << 3) | wire_type as u32;
319    encode_varint(u64::from(key), buf);
320}
321
322/// Decodes a Protobuf field key, which consists of a wire type designator and
323/// the field tag.
324#[inline(always)]
325pub fn decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError>
326where
327    B: Buf,
328{
329    let key = decode_varint(buf)?;
330    if key > u64::from(u32::MAX) {
331        return Err(DecodeError::new(format!("invalid key value: {}", key)));
332    }
333    let wire_type = WireType::try_from(key & 0x07)?;
334    let tag = key as u32 >> 3;
335
336    if tag < MIN_TAG {
337        return Err(DecodeError::new("invalid tag value: 0"));
338    }
339
340    Ok((tag, wire_type))
341}
342
343/// Returns the width of an encoded Protobuf field key with the given tag.
344/// The returned width will be between 1 and 5 bytes (inclusive).
345#[inline]
346pub fn key_len(tag: u32) -> usize {
347    encoded_len_varint(u64::from(tag << 3))
348}
349
350/// Checks that the expected wire type matches the actual wire type,
351/// or returns an error result.
352#[inline]
353pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
354    if expected != actual {
355        return Err(DecodeError::new(format!(
356            "invalid wire type: {:?} (expected {:?})",
357            actual, expected
358        )));
359    }
360    Ok(())
361}
362
363/// Helper function which abstracts reading a length delimiter prefix followed
364/// by decoding values until the length of bytes is exhausted.
365pub fn merge_loop<T, M, B>(
366    value: &mut T,
367    buf: &mut B,
368    ctx: DecodeContext,
369    mut merge: M,
370) -> Result<(), DecodeError>
371where
372    M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
373    B: Buf,
374{
375    let len = decode_varint(buf)?;
376    let remaining = buf.remaining();
377    if len > remaining as u64 {
378        return Err(DecodeError::new("buffer underflow"));
379    }
380
381    let limit = remaining - len as usize;
382    while buf.remaining() > limit {
383        merge(value, buf, ctx.clone())?;
384    }
385
386    if buf.remaining() != limit {
387        return Err(DecodeError::new("delimited length exceeded"));
388    }
389    Ok(())
390}
391
392pub fn skip_field<B>(
393    wire_type: WireType,
394    tag: u32,
395    buf: &mut B,
396    ctx: DecodeContext,
397) -> Result<(), DecodeError>
398where
399    B: Buf,
400{
401    ctx.limit_reached()?;
402    let len = match wire_type {
403        WireType::Varint => decode_varint(buf).map(|_| 0)?,
404        WireType::ThirtyTwoBit => 4,
405        WireType::SixtyFourBit => 8,
406        WireType::LengthDelimited => decode_varint(buf)?,
407        WireType::StartGroup => loop {
408            let (inner_tag, inner_wire_type) = decode_key(buf)?;
409            match inner_wire_type {
410                WireType::EndGroup => {
411                    if inner_tag != tag {
412                        return Err(DecodeError::new("unexpected end group tag"));
413                    }
414                    break 0;
415                }
416                _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
417            }
418        },
419        WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
420    };
421
422    if len > buf.remaining() as u64 {
423        return Err(DecodeError::new("buffer underflow"));
424    }
425
426    buf.advance(len as usize);
427    Ok(())
428}
429
/// Helper macro which emits an `encode_repeated` function for the type.
///
/// The generated function writes every element of `values` as a separate field with the same
/// `tag`, using the `encode` function in scope at the expansion site.
macro_rules! encode_repeated {
    ($ty:ty) => {
        pub fn encode_repeated<B>(tag: u32, values: &[$ty], buf: &mut B)
        where
            B: BufMut,
        {
            values.iter().for_each(|value| encode(tag, value, buf));
        }
    };
}
443
/// Helper macro which emits a `merge_repeated` function for the numeric type.
///
/// The generated function accepts both encodings of a repeated numeric field and appends
/// every decoded element to `values`:
/// - a packed run: one `LengthDelimited` record holding back-to-back values of `$wire_type`;
/// - a single unpacked element: it must carry `$wire_type`.
///
/// `$merge` is the single-element merge function of the numeric type.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        pub fn $merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: exactly one element, encoded with the element's own wire type.
                check_wire_type($wire_type, wire_type)?;
                let mut element = Default::default();
                $merge(wire_type, &mut element, buf, ctx)?;
                values.push(element);
                return Ok(());
            }

            // Packed: decode elements until the length-delimited region is consumed.
            merge_loop(values, buf, ctx, |values, buf, ctx| {
                let mut element = Default::default();
                $merge($wire_type, &mut element, buf, ctx)?;
                values.push(element);
                Ok(())
            })
        }
    };
}
478
/// Macro which emits a module containing a set of encoding functions for a
/// variable width numeric type.
///
/// Two forms are accepted:
/// - `varint!(ty, proto_ty)` converts to and from `u64` with plain `as` casts.
/// - `varint!(ty, proto_ty, to_uint64(x) expr, from_uint64(y) expr)` supplies custom
///   conversions. In `to_uint64`, `x` is bound to a `&ty` element and the expression must
///   produce the `u64` that is written as a varint. In `from_uint64`, `y` is bound to the
///   decoded `u64` and the expression must produce a `ty`.
///
/// The generated module `proto_ty` provides `encode`, `merge`, `encode_repeated`,
/// `encode_packed`, `merge_repeated`, `encoded_len`, `encoded_len_repeated` and
/// `encoded_len_packed`, plus proptest round-trip tests under `cfg(test)`.
macro_rules! varint {
    // Default conversions: `as` casts in both directions.
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

         pub mod $proto_ty {
            use crate::encoding::*;

            // Writes the key with `WireType::Varint`, then the converted value.
            pub fn encode<B>(tag: u32, $to_uint64_value: &$ty, buf: &mut B) where B: BufMut {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            // Requires `WireType::Varint`; overwrites `*value` with the converted decoded value.
            pub fn merge<B>(wire_type: WireType, value: &mut $ty, buf: &mut B, _ctx: DecodeContext) -> Result<(), DecodeError> where B: Buf {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            // Packed form: one length-delimited record whose payload is the concatenated
            // varints. Nothing is written for an empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B) where B: BufMut {
                if values.is_empty() { return; }

                encode_key(tag, WireType::LengthDelimited, buf);
                // The payload length must be known up front to write the length prefix.
                let len: usize = values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum();
                encode_varint(len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            // Unpacked repeated form: one key per element.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                key_len(tag) * values.len() + values.iter().map(|$to_uint64_value| {
                    encoded_len_varint($to_uint64)
                }).sum::<usize>()
            }

            // Packed form: key + length prefix + payload, or 0 for an empty slice
            // (matching `encode_packed`, which writes nothing in that case).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = values.iter()
                                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                                    .sum::<usize>();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
         }

    );
}
// `bool` is written as varint 1 (true) or 0 (false); any non-zero decoded value is `true`.
varint!(bool, bool,
        to_uint64(value) if *value { 1u64 } else { 0u64 },
        from_uint64(value) value != 0);
// Plain integer types use `as` casts in both directions. A negative `i32` is sign-extended to
// 64 bits before encoding, and decoding into `i32`/`u32` keeps only the low 32 bits.
varint!(i32, int32);
varint!(i64, int64);
varint!(u32, uint32);
varint!(u64, uint64);
// ZigZag encoding maps signed values onto unsigned ones so that values close to zero get short
// varints: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
varint!(i32, sint32,
to_uint64(value) {
    ((value << 1) ^ (value >> 31)) as u32 as u64
},
from_uint64(value) {
    // Only the low 32 bits of the decoded value are used.
    let value = value as u32;
    ((value >> 1) as i32) ^ (-((value & 1) as i32))
});
varint!(i64, sint64,
to_uint64(value) {
    ((value << 1) ^ (value >> 63)) as u64
},
from_uint64(value) {
    ((value >> 1) as i64) ^ (-((value & 1) as i64))
});
608
/// Macro which emits a module containing a set of encoding functions for a
/// fixed width numeric type.
///
/// Arguments:
/// - `$ty`: the Rust type.
/// - `$width`: its encoded size in bytes. Presumably this must match what `$put`/`$get`
///   write and read; that is not checked here.
/// - `$wire_type`: the field's wire type.
/// - `$proto_ty`: the name of the generated module.
/// - `$put`/`$get`: the `BufMut`/`Buf` methods used to write and read one value.
///
/// The generated module mirrors the one produced by `varint!`: `encode`, `merge`,
/// `encode_repeated`, `encode_packed`, `merge_repeated`, `encoded_len`,
/// `encoded_len_repeated`, `encoded_len_packed`, plus proptest round-trip tests.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            pub fn encode<B>(tag: u32, value: &$ty, buf: &mut B)
            where
                B: BufMut,
            {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            pub fn merge<B>(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut B,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError>
            where
                B: Buf,
            {
                check_wire_type($wire_type, wire_type)?;
                // Check the length first so that `$get` is never called on a short buffer.
                if buf.remaining() < $width {
                    return Err(DecodeError::new("buffer underflow"));
                }
                *value = buf.$get();
                Ok(())
            }

            encode_repeated!($ty);

            // Packed form: key, byte length (element count * width), then the raw values.
            // Nothing is written for an empty slice.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B)
            where
                B: BufMut,
            {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                let len = values.len() as u64 * $width;
                encode_varint(len as u64, buf);

                for value in values {
                    buf.$put(*value);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                (key_len(tag) + $width) * values.len()
            }

            // 0 for an empty slice, matching `encode_packed`.
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    0
                } else {
                    let len = $width * values.len();
                    key_len(tag) + encoded_len_varint(len as u64) + len
                }
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
// All fixed-width types are written and read in little-endian byte order (`*_le` methods).
// 4-byte types use `ThirtyTwoBit`; 8-byte types use `SixtyFourBit`.
// `float` / `double`: 32- and 64-bit floating point.
fixed_width!(
    f32,
    4,
    WireType::ThirtyTwoBit,
    float,
    put_f32_le,
    get_f32_le
);
fixed_width!(
    f64,
    8,
    WireType::SixtyFourBit,
    double,
    put_f64_le,
    get_f64_le
);
// `fixed32` / `fixed64`: unsigned integers.
fixed_width!(
    u32,
    4,
    WireType::ThirtyTwoBit,
    fixed32,
    put_u32_le,
    get_u32_le
);
fixed_width!(
    u64,
    8,
    WireType::SixtyFourBit,
    fixed64,
    put_u64_le,
    get_u64_le
);
// `sfixed32` / `sfixed64`: signed integers.
fixed_width!(
    i32,
    4,
    WireType::ThirtyTwoBit,
    sfixed32,
    put_i32_le,
    get_i32_le
);
fixed_width!(
    i64,
    8,
    WireType::SixtyFourBit,
    sfixed64,
    put_i64_le,
    get_i64_le
);
765
/// Macro which emits encoding functions for a length-delimited type.
///
/// Generates `encode_repeated`, `merge_repeated`, `encoded_len` and `encoded_len_repeated`.
/// It expects `encode` and `merge` functions for `$ty` in scope at the expansion site, and
/// `$ty` must provide a `len()` method giving its payload size in bytes.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        pub fn merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let body_len = value.len();
            key_len(tag) + encoded_len_varint(body_len as u64) + body_len
        }

        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys = key_len(tag) * values.len();
            let bodies: usize = values
                .iter()
                .map(|value| {
                    let body_len = value.len();
                    encoded_len_varint(body_len as u64) + body_len
                })
                .sum();
            keys + bodies
        }
    };
}
802
803pub mod string {
804    use super::*;
805
806    pub fn encode<B>(tag: u32, value: &String, buf: &mut B)
807    where
808        B: BufMut,
809    {
810        encode_key(tag, WireType::LengthDelimited, buf);
811        encode_varint(value.len() as u64, buf);
812        buf.put_slice(value.as_bytes());
813    }
814    pub fn merge<B>(
815        wire_type: WireType,
816        value: &mut String,
817        buf: &mut B,
818        ctx: DecodeContext,
819    ) -> Result<(), DecodeError>
820    where
821        B: Buf,
822    {
823        // ## Unsafety
824        //
825        // `string::merge` reuses `bytes::merge`, with an additional check of utf-8
826        // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the
827        // string is cleared, so as to avoid leaking a string field with invalid data.
828        //
829        // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe
830        // alternative of temporarily swapping an empty `String` into the field, because it results
831        // in up to 10% better performance on the protobuf message decoding benchmarks.
832        //
833        // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into
834        // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or
835        // in the buf implementation, a drop guard is used.
836        unsafe {
837            struct DropGuard<'a>(&'a mut Vec<u8>);
838            impl<'a> Drop for DropGuard<'a> {
839                #[inline]
840                fn drop(&mut self) {
841                    self.0.clear();
842                }
843            }
844
845            let drop_guard = DropGuard(value.as_mut_vec());
846            bytes::merge(wire_type, drop_guard.0, buf, ctx)?;
847            match str::from_utf8(drop_guard.0) {
848                Ok(_) => {
849                    // Success; do not clear the bytes.
850                    mem::forget(drop_guard);
851                    Ok(())
852                }
853                Err(_) => Err(DecodeError::new(
854                    "invalid string value: data is not UTF-8 encoded",
855                )),
856            }
857        }
858    }
859
860    length_delimited!(String);
861
862    #[cfg(test)]
863    mod test {
864        use proptest::prelude::*;
865
866        use super::super::test::{check_collection_type, check_type};
867        use super::*;
868
869        proptest! {
870            #[test]
871            fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
872                super::test::check_type(value, tag, WireType::LengthDelimited,
873                                        encode, merge, encoded_len)?;
874            }
875            #[test]
876            fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
877                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
878                                                   encode_repeated, merge_repeated,
879                                                   encoded_len_repeated)?;
880            }
881        }
882    }
883}
884
885pub trait BytesAdapter: sealed::BytesAdapter {}
886
887mod sealed {
888    use super::{Buf, BufMut};
889
890    pub trait BytesAdapter: Default + Sized + 'static {
891        fn len(&self) -> usize;
892
893        /// Replace contents of this buffer with the contents of another buffer.
894        fn replace_with<B>(&mut self, buf: B)
895        where
896            B: Buf;
897
898        /// Appends this buffer to the (contents of) other buffer.
899        fn append_to<B>(&self, buf: &mut B)
900        where
901            B: BufMut;
902
903        fn is_empty(&self) -> bool {
904            self.len() == 0
905        }
906    }
907}
908
909impl BytesAdapter for Bytes {}
910
911impl sealed::BytesAdapter for Bytes {
912    fn len(&self) -> usize {
913        Buf::remaining(self)
914    }
915
916    fn replace_with<B>(&mut self, mut buf: B)
917    where
918        B: Buf,
919    {
920        *self = buf.copy_to_bytes(buf.remaining());
921    }
922
923    fn append_to<B>(&self, buf: &mut B)
924    where
925        B: BufMut,
926    {
927        buf.put(self.clone())
928    }
929}
930
931impl BytesAdapter for Vec<u8> {}
932
933impl sealed::BytesAdapter for Vec<u8> {
934    fn len(&self) -> usize {
935        Vec::len(self)
936    }
937
938    fn replace_with<B>(&mut self, buf: B)
939    where
940        B: Buf,
941    {
942        self.clear();
943        self.reserve(buf.remaining());
944        self.put(buf);
945    }
946
947    fn append_to<B>(&self, buf: &mut B)
948    where
949        B: BufMut,
950    {
951        buf.put(self.as_slice())
952    }
953}
954
955pub mod bytes {
956    use super::*;
957
958    pub fn encode<A, B>(tag: u32, value: &A, buf: &mut B)
959    where
960        A: BytesAdapter,
961        B: BufMut,
962    {
963        encode_key(tag, WireType::LengthDelimited, buf);
964        encode_varint(value.len() as u64, buf);
965        value.append_to(buf);
966    }
967
968    pub fn merge<A, B>(
969        wire_type: WireType,
970        value: &mut A,
971        buf: &mut B,
972        _ctx: DecodeContext,
973    ) -> Result<(), DecodeError>
974    where
975        A: BytesAdapter,
976        B: Buf,
977    {
978        check_wire_type(WireType::LengthDelimited, wire_type)?;
979        let len = decode_varint(buf)?;
980        if len > buf.remaining() as u64 {
981            return Err(DecodeError::new("buffer underflow"));
982        }
983        let len = len as usize;
984
985        // Clear the existing value. This follows from the following rule in the encoding guide[1]:
986        //
987        // > Normally, an encoded message would never have more than one instance of a non-repeated
988        // > field. However, parsers are expected to handle the case in which they do. For numeric
989        // > types and strings, if the same field appears multiple times, the parser accepts the
990        // > last value it sees.
991        //
992        // [1]: https://developers.google.com/protocol-buffers/docs/encoding#optional
993        value.replace_with(buf.copy_to_bytes(len));
994        Ok(())
995    }
996
997    length_delimited!(impl BytesAdapter);
998
999    #[cfg(test)]
1000    mod test {
1001        use proptest::prelude::*;
1002
1003        use super::super::test::{check_collection_type, check_type};
1004        use super::*;
1005
1006        proptest! {
1007            #[test]
1008            fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1009                super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
1010                                                            encode, merge, encoded_len)?;
1011            }
1012
1013            #[test]
1014            fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1015                let value = Bytes::from(value);
1016                super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
1017                                                        encode, merge, encoded_len)?;
1018            }
1019
1020            #[test]
1021            fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1022                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1023                                                   encode_repeated, merge_repeated,
1024                                                   encoded_len_repeated)?;
1025            }
1026
1027            #[test]
1028            fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1029                let value = value.into_iter().map(Bytes::from).collect();
1030                super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1031                                                   encode_repeated, merge_repeated,
1032                                                   encoded_len_repeated)?;
1033            }
1034        }
1035    }
1036}
1037
1038pub mod message {
1039    use super::*;
1040
1041    pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1042    where
1043        M: Message,
1044        B: BufMut,
1045    {
1046        encode_key(tag, WireType::LengthDelimited, buf);
1047        encode_varint(msg.encoded_len() as u64, buf);
1048        msg.encode_raw(buf);
1049    }
1050
1051    pub fn merge<M, B>(
1052        wire_type: WireType,
1053        msg: &mut M,
1054        buf: &mut B,
1055        ctx: DecodeContext,
1056    ) -> Result<(), DecodeError>
1057    where
1058        M: Message,
1059        B: Buf,
1060    {
1061        check_wire_type(WireType::LengthDelimited, wire_type)?;
1062        ctx.limit_reached()?;
1063        merge_loop(
1064            msg,
1065            buf,
1066            ctx.enter_recursion(),
1067            |msg: &mut M, buf: &mut B, ctx| {
1068                let (tag, wire_type) = decode_key(buf)?;
1069                msg.merge_field(tag, wire_type, buf, ctx)
1070            },
1071        )
1072    }
1073
1074    pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1075    where
1076        M: Message,
1077        B: BufMut,
1078    {
1079        for msg in messages {
1080            encode(tag, msg, buf);
1081        }
1082    }
1083
1084    pub fn merge_repeated<M, B>(
1085        wire_type: WireType,
1086        messages: &mut Vec<M>,
1087        buf: &mut B,
1088        ctx: DecodeContext,
1089    ) -> Result<(), DecodeError>
1090    where
1091        M: Message + Default,
1092        B: Buf,
1093    {
1094        check_wire_type(WireType::LengthDelimited, wire_type)?;
1095        let mut msg = M::default();
1096        merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
1097        messages.push(msg);
1098        Ok(())
1099    }
1100
1101    #[inline]
1102    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1103    where
1104        M: Message,
1105    {
1106        let len = msg.encoded_len();
1107        key_len(tag) + encoded_len_varint(len as u64) + len
1108    }
1109
1110    #[inline]
1111    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1112    where
1113        M: Message,
1114    {
1115        key_len(tag) * messages.len()
1116            + messages
1117                .iter()
1118                .map(Message::encoded_len)
1119                .map(|len| len + encoded_len_varint(len as u64))
1120                .sum::<usize>()
1121    }
1122}
1123
1124pub mod group {
1125    use super::*;
1126
1127    pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1128    where
1129        M: Message,
1130        B: BufMut,
1131    {
1132        encode_key(tag, WireType::StartGroup, buf);
1133        msg.encode_raw(buf);
1134        encode_key(tag, WireType::EndGroup, buf);
1135    }
1136
1137    pub fn merge<M, B>(
1138        tag: u32,
1139        wire_type: WireType,
1140        msg: &mut M,
1141        buf: &mut B,
1142        ctx: DecodeContext,
1143    ) -> Result<(), DecodeError>
1144    where
1145        M: Message,
1146        B: Buf,
1147    {
1148        check_wire_type(WireType::StartGroup, wire_type)?;
1149
1150        ctx.limit_reached()?;
1151        loop {
1152            let (field_tag, field_wire_type) = decode_key(buf)?;
1153            if field_wire_type == WireType::EndGroup {
1154                if field_tag != tag {
1155                    return Err(DecodeError::new("unexpected end group tag"));
1156                }
1157                return Ok(());
1158            }
1159
1160            M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
1161        }
1162    }
1163
1164    pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1165    where
1166        M: Message,
1167        B: BufMut,
1168    {
1169        for msg in messages {
1170            encode(tag, msg, buf);
1171        }
1172    }
1173
1174    pub fn merge_repeated<M, B>(
1175        tag: u32,
1176        wire_type: WireType,
1177        messages: &mut Vec<M>,
1178        buf: &mut B,
1179        ctx: DecodeContext,
1180    ) -> Result<(), DecodeError>
1181    where
1182        M: Message + Default,
1183        B: Buf,
1184    {
1185        check_wire_type(WireType::StartGroup, wire_type)?;
1186        let mut msg = M::default();
1187        merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
1188        messages.push(msg);
1189        Ok(())
1190    }
1191
1192    #[inline]
1193    pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1194    where
1195        M: Message,
1196    {
1197        2 * key_len(tag) + msg.encoded_len()
1198    }
1199
1200    #[inline]
1201    pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1202    where
1203        M: Message,
1204    {
1205        2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
1206    }
1207}
1208
1209/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
1210/// generic over `HashMap` and `BTreeMap`.
/// Rust doesn't have a `Map` trait, so macros are currently the best way to be
/// generic over `HashMap` and `BTreeMap`.
///
/// `map!(MapType)` expands to the encoding functions for `MapType<K, V>` fields. Each map entry
/// is written as its own length-delimited field under the map's tag. Inside the entry, the key is
/// field 1 and the value is field 2, and either one is left out when it equals its default.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Generic protobuf map encode function.
        ///
        /// Uses `V::default()` as the value default; see `encode_with_default`.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &V::default(),
                tag,
                values,
                buf,
            )
        }

        /// Generic protobuf map merge function.
        ///
        /// Uses `V::default()` as the value default; see `merge_with_default`.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            merge_with_default(key_merge, val_merge, V::default(), values, buf, ctx)
        }

        /// Generic protobuf map encoded length function.
        ///
        /// Returns the number of bytes `encode` writes for `values` under `tag`.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            encoded_len_with_default(key_encoded_len, val_encoded_len, &V::default(), tag, values)
        }

        /// Generic protobuf map encode function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// For each entry, this writes the map key `tag`, the entry length, then field 1 (the
        /// key) and field 2 (the value). A key equal to `K::default()` and a value equal to
        /// `val_default` are omitted.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            // The key default is the same for every entry, so build it once.
            let key_default = K::default();
            for (key, val) in values.iter() {
                let skip_key = key == &key_default;
                let skip_val = val == val_default;

                let len = (if skip_key { 0 } else { key_encoded_len(1, key) })
                    + (if skip_val { 0 } else { val_encoded_len(2, val) });

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(len as u64, buf);
                if !skip_key {
                    key_encode(1, key, buf);
                }
                if !skip_val {
                    val_encode(2, val, buf);
                }
            }
        }

        /// Generic protobuf map merge function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Decodes one map entry from `buf` and inserts it into `values`. Field 1 is merged into
        /// the key and field 2 into the value; any other field number is skipped. A key or value
        /// absent from the entry keeps `K::default()` / `val_default` respectively.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut key = Default::default();
            let mut val = val_default;
            ctx.limit_reached()?;
            merge_loop(
                &mut (&mut key, &mut val),
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut key, ref mut val), buf, ctx| {
                    let (tag, wire_type) = decode_key(buf)?;
                    match tag {
                        1 => key_merge(wire_type, key, buf, ctx),
                        2 => val_merge(wire_type, val, buf, ctx),
                        _ => skip_field(wire_type, tag, buf, ctx),
                    }
                },
            )?;
            // `insert` replaces any existing value for the same key.
            values.insert(key, val);

            Ok(())
        }

        /// Generic protobuf map encoded length function with an overridden value default.
        ///
        /// This is necessary because enumeration values can have a default value other
        /// than 0 in proto2.
        ///
        /// Returns the number of bytes `encode_with_default` writes for the same arguments.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            // The key default is the same for every entry, so build it once.
            let key_default = K::default();
            key_len(tag) * values.len()
                + values
                    .iter()
                    .map(|(key, val)| {
                        let len = (if key == &key_default {
                            0
                        } else {
                            key_encoded_len(1, key)
                        }) + (if val == val_default {
                            0
                        } else {
                            val_encoded_len(2, val)
                        });
                        encoded_len_varint(len as u64) + len
                    })
                    .sum::<usize>()
        }
    };
}
1395
1396#[cfg(feature = "std")]
1397pub mod hash_map {
1398    use std::collections::HashMap;
1399    map!(HashMap);
1400}
1401
1402pub mod btree_map {
1403    map!(BTreeMap);
1404}
1405
#[cfg(test)]
mod test {
    use alloc::string::ToString;
    use core::borrow::Borrow;
    use core::fmt::Debug;
    use core::u64;

    use ::bytes::{Bytes, BytesMut};
    use proptest::{prelude::*, test_runner::TestCaseResult};

    use crate::encoding::*;

    /// Round-trips a single field value through `encode` and `merge` and checks the result.
    ///
    /// Cases with `tag` outside `MIN_TAG..=MAX_TAG` are discarded. The helper then checks:
    /// - `encoded_len` equals the number of bytes `encode` wrote;
    /// - the decoded key carries `tag` and `wire_type`;
    /// - fixed-width wire types leave exactly 8 (64-bit) or 4 (32-bit) payload bytes;
    /// - `merge` consumes the rest of the buffer;
    /// - the decoded value equals `value`.
    ///
    /// An empty encoding, such as an empty packed value, is accepted without decoding.
    pub fn check_type<T, B>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: fn(u32, &B, &mut BytesMut),
        merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        encoded_len: fn(u32, &B) -> usize,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
    {
        prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG);

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        if !buf.has_remaining() {
            // Short circuit for empty packed values.
            return Ok(());
        }

        let (decoded_tag, decoded_wire_type) =
            decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
        prop_assert_eq!(
            tag,
            decoded_tag,
            "decoded tag does not match; expected: {}, actual: {}",
            tag,
            decoded_tag
        );

        prop_assert_eq!(
            wire_type,
            decoded_wire_type,
            "decoded wire type does not match; expected: {:?}, actual: {:?}",
            wire_type,
            decoded_wire_type,
        );

        match wire_type {
            WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
                "64bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
                "32bit wire type illegal remaining: {}, tag: {}",
                buf.remaining(),
                tag
            ))),
            _ => Ok(()),
        }?;

        let mut roundtrip_value = T::default();
        merge(
            wire_type,
            &mut roundtrip_value,
            &mut buf,
            DecodeContext::default(),
        )
        .map_err(|error| TestCaseError::fail(error.to_string()))?;

        prop_assert!(
            !buf.has_remaining(),
            "expected buffer to be empty, remaining: {}",
            buf.remaining()
        );

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    /// Round-trips a collection value through `encode` and repeated calls to `merge`.
    ///
    /// Cases with `tag` outside `MIN_TAG..=MAX_TAG` are discarded. The helper checks that
    /// `encoded_len` equals the bytes written. The encoded buffer may hold several fields; for
    /// each one it checks that the key carries `tag` and `wire_type`, then merges the field into
    /// one accumulated value. At the end, that accumulated value must equal `value`.
    pub fn check_collection_type<T, B, E, M, L>(
        value: T,
        tag: u32,
        wire_type: WireType,
        encode: E,
        mut merge: M,
        encoded_len: L,
    ) -> TestCaseResult
    where
        T: Debug + Default + PartialEq + Borrow<B>,
        B: ?Sized,
        E: FnOnce(u32, &B, &mut BytesMut),
        M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
        L: FnOnce(u32, &B) -> usize,
    {
        prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG);

        let expected_len = encoded_len(tag, value.borrow());

        let mut buf = BytesMut::with_capacity(expected_len);
        encode(tag, value.borrow(), &mut buf);

        let mut buf = buf.freeze();

        prop_assert_eq!(
            buf.remaining(),
            expected_len,
            "encoded_len wrong; expected: {}, actual: {}",
            expected_len,
            buf.remaining()
        );

        let mut roundtrip_value = Default::default();
        while buf.has_remaining() {
            let (decoded_tag, decoded_wire_type) =
                decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;

            prop_assert_eq!(
                tag,
                decoded_tag,
                "decoded tag does not match; expected: {}, actual: {}",
                tag,
                decoded_tag
            );

            prop_assert_eq!(
                wire_type,
                decoded_wire_type,
                "decoded wire type does not match; expected: {:?}, actual: {:?}",
                wire_type,
                decoded_wire_type
            );

            merge(
                wire_type,
                &mut roundtrip_value,
                &mut buf,
                DecodeContext::default(),
            )
            .map_err(|error| TestCaseError::fail(error.to_string()))?;
        }

        prop_assert_eq!(value, roundtrip_value);

        Ok(())
    }

    #[test]
    fn string_merge_invalid_utf8() {
        let mut s = String::new();
        let buf = b"\x02\x80\x80";

        let r = string::merge(
            WireType::LengthDelimited,
            &mut s,
            &mut &buf[..],
            DecodeContext::default(),
        );
        r.expect_err("must be an error");
        assert!(s.is_empty());
    }

    #[test]
    fn varint() {
        fn check(value: u64, encoded: &[u8]) {
            // Small buffer.
            let mut buf = Vec::with_capacity(1);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            // Large buffer.
            let mut buf = Vec::with_capacity(100);
            encode_varint(value, &mut buf);
            assert_eq!(buf, encoded);

            assert_eq!(encoded_len_varint(value), encoded.len());

            // Each decoder gets its own copy of the slice cursor (`&[u8]` is `Copy`), so both
            // see the full input and no clone of a reference is needed.
            let mut fast = encoded;
            let roundtrip_value = decode_varint(&mut fast).expect("decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(fast.is_empty(), "decode_varint left unread bytes");

            let mut slow = encoded;
            let roundtrip_value = decode_varint_slow(&mut slow).expect("slow decoding failed");
            assert_eq!(value, roundtrip_value);
            assert!(slow.is_empty(), "decode_varint_slow left unread bytes");
        }

        check(2u64.pow(0) - 1, &[0x00]);
        check(2u64.pow(0), &[0x01]);

        check(2u64.pow(7) - 1, &[0x7F]);
        check(2u64.pow(7), &[0x80, 0x01]);
        check(300, &[0xAC, 0x02]);

        check(2u64.pow(14) - 1, &[0xFF, 0x7F]);
        check(2u64.pow(14), &[0x80, 0x80, 0x01]);

        check(2u64.pow(21) - 1, &[0xFF, 0xFF, 0x7F]);
        check(2u64.pow(21), &[0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(28), &[0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(35), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(2u64.pow(42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        check(2u64.pow(42), &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]);

        check(
            2u64.pow(49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(49),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(56),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            2u64.pow(63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        );
        check(
            2u64.pow(63),
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        );

        check(
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        );
    }

    #[test]
    fn varint_overflow() {
        let u64_max_plus_one: &[u8] =
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

        // Give each decoder its own cursor. Otherwise a fast decoder that advanced before failing
        // would hand the slow decoder a shortened input, and the slow check could pass for the
        // wrong reason.
        let mut fast = u64_max_plus_one;
        decode_varint(&mut fast).expect_err("decoding u64::MAX + 1 succeeded");
        let mut slow = u64_max_plus_one;
        decode_varint_slow(&mut slow).expect_err("slow decoding u64::MAX + 1 succeeded");
    }

    /// This big bowl o' macro soup generates an encoding property test for each combination of map
    /// type, scalar map key, and value type.
    /// TODO: these tests take a long time to compile, can this be improved?
    #[cfg(feature = "std")]
    macro_rules! map_tests {
        (keys: $keys:tt,
         vals: $vals:tt) => {
            mod hash_map {
                map_tests!(@private HashMap, hash_map, $keys, $vals);
            }
            mod btree_map {
                map_tests!(@private BTreeMap, btree_map, $keys, $vals);
            }
        };

        (@private $map_type:ident,
                  $mod_name:ident,
                  [$(($key_ty:ty, $key_proto:ident)),*],
                  $vals:tt) => {
            $(
                mod $key_proto {
                    use std::collections::$map_type;

                    use proptest::prelude::*;

                    use crate::encoding::*;
                    use crate::encoding::test::check_collection_type;

                    map_tests!(@private $map_type, $mod_name, ($key_ty, $key_proto), $vals);
                }
            )*
        };

        (@private $map_type:ident,
                  $mod_name:ident,
                  ($key_ty:ty, $key_proto:ident),
                  [$(($val_ty:ty, $val_proto:ident)),*]) => {
            $(
                proptest! {
                    #[test]
                    fn $val_proto(values: $map_type<$key_ty, $val_ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(values, tag, WireType::LengthDelimited,
                                              |tag, values, buf| {
                                                  $mod_name::encode($key_proto::encode,
                                                                    $key_proto::encoded_len,
                                                                    $val_proto::encode,
                                                                    $val_proto::encoded_len,
                                                                    tag,
                                                                    values,
                                                                    buf)
                                              },
                                              |wire_type, values, buf, ctx| {
                                                  check_wire_type(WireType::LengthDelimited, wire_type)?;
                                                  $mod_name::merge($key_proto::merge,
                                                                   $val_proto::merge,
                                                                   values,
                                                                   buf,
                                                                   ctx)
                                              },
                                              |tag, values| {
                                                  $mod_name::encoded_len($key_proto::encoded_len,
                                                                         $val_proto::encoded_len,
                                                                         tag,
                                                                         values)
                                              })?;
                    }
                }
             )*
        };
    }

    #[cfg(feature = "std")]
    map_tests!(keys: [
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float),
        (f64, double),
        (i32, int32),
        (i64, int64),
        (u32, uint32),
        (u64, uint64),
        (i32, sint32),
        (i64, sint64),
        (u32, fixed32),
        (u64, fixed64),
        (i32, sfixed32),
        (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]);
}