1#![allow(clippy::implicit_hasher, clippy::ptr_arg)]
6
7use alloc::collections::BTreeMap;
8use alloc::format;
9use alloc::string::String;
10use alloc::vec::Vec;
11use core::cmp::min;
12use core::convert::TryFrom;
13use core::mem;
14use core::str;
15use core::u32;
16use core::usize;
17
18use ::bytes::{Buf, BufMut, Bytes};
19
20use crate::DecodeError;
21use crate::Message;
22
23#[inline]
26pub fn encode_varint<B>(mut value: u64, buf: &mut B)
27where
28 B: BufMut,
29{
30 loop {
31 if value < 0x80 {
32 buf.put_u8(value as u8);
33 break;
34 } else {
35 buf.put_u8(((value & 0x7F) | 0x80) as u8);
36 value >>= 7;
37 }
38 }
39}
40
41#[inline]
43pub fn decode_varint<B>(buf: &mut B) -> Result<u64, DecodeError>
44where
45 B: Buf,
46{
47 let bytes = buf.chunk();
48 let len = bytes.len();
49 if len == 0 {
50 return Err(DecodeError::new("invalid varint"));
51 }
52
53 let byte = bytes[0];
54 if byte < 0x80 {
55 buf.advance(1);
56 Ok(u64::from(byte))
57 } else if len > 10 || bytes[len - 1] < 0x80 {
58 let (value, advance) = decode_varint_slice(bytes)?;
59 buf.advance(advance);
60 Ok(value)
61 } else {
62 decode_varint_slow(buf)
63 }
64}
65
66#[inline]
80fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
81 assert!(!bytes.is_empty());
85 assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);
86
87 let mut b: u8;
88 let mut part0: u32;
89 b = unsafe { *bytes.get_unchecked(0) };
90 part0 = u32::from(b);
91 if b < 0x80 {
92 return Ok((u64::from(part0), 1));
93 };
94 part0 -= 0x80;
95 b = unsafe { *bytes.get_unchecked(1) };
96 part0 += u32::from(b) << 7;
97 if b < 0x80 {
98 return Ok((u64::from(part0), 2));
99 };
100 part0 -= 0x80 << 7;
101 b = unsafe { *bytes.get_unchecked(2) };
102 part0 += u32::from(b) << 14;
103 if b < 0x80 {
104 return Ok((u64::from(part0), 3));
105 };
106 part0 -= 0x80 << 14;
107 b = unsafe { *bytes.get_unchecked(3) };
108 part0 += u32::from(b) << 21;
109 if b < 0x80 {
110 return Ok((u64::from(part0), 4));
111 };
112 part0 -= 0x80 << 21;
113 let value = u64::from(part0);
114
115 let mut part1: u32;
116 b = unsafe { *bytes.get_unchecked(4) };
117 part1 = u32::from(b);
118 if b < 0x80 {
119 return Ok((value + (u64::from(part1) << 28), 5));
120 };
121 part1 -= 0x80;
122 b = unsafe { *bytes.get_unchecked(5) };
123 part1 += u32::from(b) << 7;
124 if b < 0x80 {
125 return Ok((value + (u64::from(part1) << 28), 6));
126 };
127 part1 -= 0x80 << 7;
128 b = unsafe { *bytes.get_unchecked(6) };
129 part1 += u32::from(b) << 14;
130 if b < 0x80 {
131 return Ok((value + (u64::from(part1) << 28), 7));
132 };
133 part1 -= 0x80 << 14;
134 b = unsafe { *bytes.get_unchecked(7) };
135 part1 += u32::from(b) << 21;
136 if b < 0x80 {
137 return Ok((value + (u64::from(part1) << 28), 8));
138 };
139 part1 -= 0x80 << 21;
140 let value = value + ((u64::from(part1)) << 28);
141
142 let mut part2: u32;
143 b = unsafe { *bytes.get_unchecked(8) };
144 part2 = u32::from(b);
145 if b < 0x80 {
146 return Ok((value + (u64::from(part2) << 56), 9));
147 };
148 part2 -= 0x80;
149 b = unsafe { *bytes.get_unchecked(9) };
150 part2 += u32::from(b) << 7;
151 if b < 0x02 {
154 return Ok((value + (u64::from(part2) << 56), 10));
155 };
156
157 Err(DecodeError::new("invalid varint"))
160}
161
162#[inline(never)]
169#[cold]
170fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
171where
172 B: Buf,
173{
174 let mut value = 0;
175 for count in 0..min(10, buf.remaining()) {
176 let byte = buf.get_u8();
177 value |= u64::from(byte & 0x7F) << (count * 7);
178 if byte <= 0x7F {
179 if count == 9 && byte >= 0x02 {
182 return Err(DecodeError::new("invalid varint"));
183 } else {
184 return Ok(value);
185 }
186 }
187 }
188
189 Err(DecodeError::new("invalid varint"))
190}
191
/// State threaded through the decode/merge functions of this module.
///
/// The context is cheap to clone. Unless the `no-recursion-limit` feature is enabled it carries
/// a budget of remaining nesting levels; with the feature enabled it has no fields.
#[derive(Debug, Clone)]
pub struct DecodeContext {
    /// Number of further nesting levels that may be entered before `limit_reached` reports an
    /// error.
    #[cfg(not(feature = "no-recursion-limit"))]
    recurse_count: u32,
}
207
208impl Default for DecodeContext {
209 #[cfg(not(feature = "no-recursion-limit"))]
210 #[inline]
211 fn default() -> DecodeContext {
212 DecodeContext {
213 recurse_count: crate::RECURSION_LIMIT,
214 }
215 }
216
217 #[cfg(feature = "no-recursion-limit")]
218 #[inline]
219 fn default() -> DecodeContext {
220 DecodeContext {}
221 }
222}
223
224impl DecodeContext {
225 #[cfg(not(feature = "no-recursion-limit"))]
231 #[inline]
232 pub(crate) fn enter_recursion(&self) -> DecodeContext {
233 DecodeContext {
234 recurse_count: self.recurse_count - 1,
235 }
236 }
237
238 #[cfg(feature = "no-recursion-limit")]
239 #[inline]
240 pub(crate) fn enter_recursion(&self) -> DecodeContext {
241 DecodeContext {}
242 }
243
244 #[cfg(not(feature = "no-recursion-limit"))]
250 #[inline]
251 pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
252 if self.recurse_count == 0 {
253 Err(DecodeError::new("recursion limit reached"))
254 } else {
255 Ok(())
256 }
257 }
258
259 #[cfg(feature = "no-recursion-limit")]
260 #[inline]
261 #[allow(clippy::unnecessary_wraps)] pub(crate) fn limit_reached(&self) -> Result<(), DecodeError> {
263 Ok(())
264 }
265}
266
/// Returns the number of bytes `encode_varint` writes for `value`: between 1 and 10 inclusive.
///
/// This is a `const fn`, so the length can also be computed in constant contexts.
#[inline]
pub const fn encoded_len_varint(value: u64) -> usize {
    // `(value | 1).leading_zeros() ^ 63` is the index of the highest set bit (0 for both 0 and
    // 1). Multiplying by 9, adding 73 and dividing by 64 is a branch-free way of computing
    // `index / 7 + 1` over the range 0..=63.
    ((((value | 1).leading_zeros() ^ 63) * 9 + 73) / 64) as usize
}
275
/// The wire type stored in the low three bits of every field key.
///
/// The discriminants are the numeric values used on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum WireType {
    /// A variable-length integer.
    Varint = 0,
    /// A fixed eight-byte value.
    SixtyFourBit = 1,
    /// A varint length followed by that many bytes.
    LengthDelimited = 2,
    /// The start marker of a group.
    StartGroup = 3,
    /// The end marker of a group.
    EndGroup = 4,
    /// A fixed four-byte value.
    ThirtyTwoBit = 5,
}
286
/// Smallest valid field tag.
pub const MIN_TAG: u32 = 1;
/// Largest valid field tag: a key is a `u32` whose low three bits hold the wire type, which
/// leaves 29 bits for the tag.
pub const MAX_TAG: u32 = u32::MAX >> 3;
289
290impl TryFrom<u64> for WireType {
291 type Error = DecodeError;
292
293 #[inline]
294 fn try_from(value: u64) -> Result<Self, Self::Error> {
295 match value {
296 0 => Ok(WireType::Varint),
297 1 => Ok(WireType::SixtyFourBit),
298 2 => Ok(WireType::LengthDelimited),
299 3 => Ok(WireType::StartGroup),
300 4 => Ok(WireType::EndGroup),
301 5 => Ok(WireType::ThirtyTwoBit),
302 _ => Err(DecodeError::new(format!(
303 "invalid wire type value: {}",
304 value
305 ))),
306 }
307 }
308}
309
310#[inline]
313pub fn encode_key<B>(tag: u32, wire_type: WireType, buf: &mut B)
314where
315 B: BufMut,
316{
317 debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
318 let key = (tag << 3) | wire_type as u32;
319 encode_varint(u64::from(key), buf);
320}
321
322#[inline(always)]
325pub fn decode_key<B>(buf: &mut B) -> Result<(u32, WireType), DecodeError>
326where
327 B: Buf,
328{
329 let key = decode_varint(buf)?;
330 if key > u64::from(u32::MAX) {
331 return Err(DecodeError::new(format!("invalid key value: {}", key)));
332 }
333 let wire_type = WireType::try_from(key & 0x07)?;
334 let tag = key as u32 >> 3;
335
336 if tag < MIN_TAG {
337 return Err(DecodeError::new("invalid tag value: 0"));
338 }
339
340 Ok((tag, wire_type))
341}
342
343#[inline]
346pub fn key_len(tag: u32) -> usize {
347 encoded_len_varint(u64::from(tag << 3))
348}
349
350#[inline]
353pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
354 if expected != actual {
355 return Err(DecodeError::new(format!(
356 "invalid wire type: {:?} (expected {:?})",
357 actual, expected
358 )));
359 }
360 Ok(())
361}
362
363pub fn merge_loop<T, M, B>(
366 value: &mut T,
367 buf: &mut B,
368 ctx: DecodeContext,
369 mut merge: M,
370) -> Result<(), DecodeError>
371where
372 M: FnMut(&mut T, &mut B, DecodeContext) -> Result<(), DecodeError>,
373 B: Buf,
374{
375 let len = decode_varint(buf)?;
376 let remaining = buf.remaining();
377 if len > remaining as u64 {
378 return Err(DecodeError::new("buffer underflow"));
379 }
380
381 let limit = remaining - len as usize;
382 while buf.remaining() > limit {
383 merge(value, buf, ctx.clone())?;
384 }
385
386 if buf.remaining() != limit {
387 return Err(DecodeError::new("delimited length exceeded"));
388 }
389 Ok(())
390}
391
392pub fn skip_field<B>(
393 wire_type: WireType,
394 tag: u32,
395 buf: &mut B,
396 ctx: DecodeContext,
397) -> Result<(), DecodeError>
398where
399 B: Buf,
400{
401 ctx.limit_reached()?;
402 let len = match wire_type {
403 WireType::Varint => decode_varint(buf).map(|_| 0)?,
404 WireType::ThirtyTwoBit => 4,
405 WireType::SixtyFourBit => 8,
406 WireType::LengthDelimited => decode_varint(buf)?,
407 WireType::StartGroup => loop {
408 let (inner_tag, inner_wire_type) = decode_key(buf)?;
409 match inner_wire_type {
410 WireType::EndGroup => {
411 if inner_tag != tag {
412 return Err(DecodeError::new("unexpected end group tag"));
413 }
414 break 0;
415 }
416 _ => skip_field(inner_wire_type, inner_tag, buf, ctx.enter_recursion())?,
417 }
418 },
419 WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
420 };
421
422 if len > buf.remaining() as u64 {
423 return Err(DecodeError::new("buffer underflow"));
424 }
425
426 buf.advance(len as usize);
427 Ok(())
428}
429
/// Emits an `encode_repeated` function for `$ty`, built on the `encode` function that is in
/// scope at the expansion site.
macro_rules! encode_repeated {
    ($ty:ty) => {
        /// Encodes each element of `values` as a separate field carrying `tag`.
        pub fn encode_repeated<B>(tag: u32, values: &[$ty], buf: &mut B)
        where
            B: BufMut,
        {
            for element in values.iter() {
                encode(tag, element, buf);
            }
        }
    };
}
443
/// Emits a `$merge_repeated` function for a numeric type, accepting both the packed and the
/// unpacked representation.
///
/// `$wire_type` is the wire type of a single element and `$merge` is the function that decodes
/// one element.
macro_rules! merge_repeated_numeric {
    ($ty:ty,
     $wire_type:expr,
     $merge:ident,
     $merge_repeated:ident) => {
        /// Decodes one occurrence of a repeated numeric field and appends the decoded
        /// element(s) to `values`.
        pub fn $merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            if wire_type != WireType::LengthDelimited {
                // Unpacked: the field holds exactly one element.
                check_wire_type($wire_type, wire_type)?;
                let mut element = Default::default();
                $merge(wire_type, &mut element, buf, ctx)?;
                values.push(element);
                return Ok(());
            }

            // Packed: a length-delimited run of elements with no keys in between.
            merge_loop(values, buf, ctx, |values, buf, ctx| {
                let mut element = Default::default();
                $merge($wire_type, &mut element, buf, ctx)?;
                values.push(element);
                Ok(())
            })
        }
    };
}
478
/// Emits a module of encoding functions for a type that travels on the wire as a varint.
///
/// The short form converts with plain `as` casts. The long form takes a pair of expressions:
/// `to_uint64(name) expr` maps a `&$ty` bound to `name` to a `u64`, and `from_uint64(name) expr`
/// maps a `u64` bound to `name` back to `$ty`.
macro_rules! varint {
    ($ty:ty,
     $proto_ty:ident) => (
        varint!($ty,
                $proto_ty,
                to_uint64(value) { *value as u64 },
                from_uint64(value) { value as $ty });
    );

    ($ty:ty,
     $proto_ty:ident,
     to_uint64($to_uint64_value:ident) $to_uint64:expr,
     from_uint64($from_uint64_value:ident) $from_uint64:expr) => (

        pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the field key followed by the value converted to a varint.
            pub fn encode<B>(tag: u32, $to_uint64_value: &$ty, buf: &mut B)
            where
                B: BufMut,
            {
                encode_key(tag, WireType::Varint, buf);
                encode_varint($to_uint64, buf);
            }

            /// Reads a varint and stores the converted result in `value`.
            ///
            /// Fails if `wire_type` is not `Varint` or the varint cannot be decoded.
            pub fn merge<B>(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut B,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError>
            where
                B: Buf,
            {
                check_wire_type(WireType::Varint, wire_type)?;
                let $from_uint64_value = decode_varint(buf)?;
                *value = $from_uint64;
                Ok(())
            }

            encode_repeated!($ty);

            /// Writes `values` as one packed, length-delimited field; writes nothing when
            /// `values` is empty.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B)
            where
                B: BufMut,
            {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                // The length prefix is the sum of the individual varint widths.
                let payload_len: usize = values
                    .iter()
                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                    .sum();
                encode_varint(payload_len as u64, buf);

                for $to_uint64_value in values {
                    encode_varint($to_uint64, buf);
                }
            }

            merge_repeated_numeric!($ty, WireType::Varint, merge, merge_repeated);

            /// Number of bytes `encode` writes for this value.
            #[inline]
            pub fn encoded_len(tag: u32, $to_uint64_value: &$ty) -> usize {
                key_len(tag) + encoded_len_varint($to_uint64)
            }

            /// Number of bytes `encode_repeated` writes for `values`.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                let keys = key_len(tag) * values.len();
                let payload = values
                    .iter()
                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                    .sum::<usize>();
                keys + payload
            }

            /// Number of bytes `encode_packed` writes for `values` (zero when empty).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    return 0;
                }
                let payload = values
                    .iter()
                    .map(|$to_uint64_value| encoded_len_varint($to_uint64))
                    .sum::<usize>();
                key_len(tag) + encoded_len_varint(payload as u64) + payload
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use crate::encoding::$proto_ty::*;
                use crate::encoding::test::{
                    check_collection_type,
                    check_type,
                };

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::Varint,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, WireType::Varint,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }

    );
}
586varint!(bool, bool,
587 to_uint64(value) if *value { 1u64 } else { 0u64 },
588 from_uint64(value) value != 0);
589varint!(i32, int32);
590varint!(i64, int64);
591varint!(u32, uint32);
592varint!(u64, uint64);
593varint!(i32, sint32,
594to_uint64(value) {
595 ((value << 1) ^ (value >> 31)) as u32 as u64
596},
597from_uint64(value) {
598 let value = value as u32;
599 ((value >> 1) as i32) ^ (-((value & 1) as i32))
600});
601varint!(i64, sint64,
602to_uint64(value) {
603 ((value << 1) ^ (value >> 63)) as u64
604},
605from_uint64(value) {
606 ((value >> 1) as i64) ^ (-((value & 1) as i64))
607});
608
/// Emits a module of encoding functions for a fixed-width numeric type.
///
/// `$width` is the encoded size in bytes, `$wire_type` the matching wire type, and `$put` /
/// `$get` the buffer methods that write and read one value.
macro_rules! fixed_width {
    ($ty:ty,
     $width:expr,
     $wire_type:expr,
     $proto_ty:ident,
     $put:ident,
     $get:ident) => {
        pub mod $proto_ty {
            use crate::encoding::*;

            /// Writes the field key followed by the fixed-width value.
            pub fn encode<B>(tag: u32, value: &$ty, buf: &mut B)
            where
                B: BufMut,
            {
                encode_key(tag, $wire_type, buf);
                buf.$put(*value);
            }

            /// Reads one fixed-width value into `value`.
            ///
            /// Fails on a wire-type mismatch, or with "buffer underflow" when fewer than the
            /// required number of bytes remain.
            pub fn merge<B>(
                wire_type: WireType,
                value: &mut $ty,
                buf: &mut B,
                _ctx: DecodeContext,
            ) -> Result<(), DecodeError>
            where
                B: Buf,
            {
                check_wire_type($wire_type, wire_type)?;
                // Checked up front so the read below is never attempted on a short buffer.
                if buf.remaining() < $width {
                    return Err(DecodeError::new("buffer underflow"));
                }
                *value = buf.$get();
                Ok(())
            }

            encode_repeated!($ty);

            /// Writes `values` as one packed, length-delimited field; writes nothing when
            /// `values` is empty.
            pub fn encode_packed<B>(tag: u32, values: &[$ty], buf: &mut B)
            where
                B: BufMut,
            {
                if values.is_empty() {
                    return;
                }

                encode_key(tag, WireType::LengthDelimited, buf);
                let payload_len = values.len() as u64 * $width;
                encode_varint(payload_len, buf);

                for element in values.iter() {
                    buf.$put(*element);
                }
            }

            merge_repeated_numeric!($ty, $wire_type, merge, merge_repeated);

            /// Number of bytes `encode` writes: the key plus the fixed width.
            #[inline]
            pub fn encoded_len(tag: u32, _: &$ty) -> usize {
                key_len(tag) + $width
            }

            /// Number of bytes `encode_repeated` writes for `values`.
            #[inline]
            pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
                let per_element = key_len(tag) + $width;
                per_element * values.len()
            }

            /// Number of bytes `encode_packed` writes for `values` (zero when empty).
            #[inline]
            pub fn encoded_len_packed(tag: u32, values: &[$ty]) -> usize {
                if values.is_empty() {
                    return 0;
                }
                let payload = $width * values.len();
                key_len(tag) + encoded_len_varint(payload as u64) + payload
            }

            #[cfg(test)]
            mod test {
                use proptest::prelude::*;

                use super::super::test::{check_collection_type, check_type};
                use super::*;

                proptest! {
                    #[test]
                    fn check(value: $ty, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, $wire_type,
                                   encode, merge, encoded_len)?;
                    }
                    #[test]
                    fn check_repeated(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_collection_type(value, tag, $wire_type,
                                              encode_repeated, merge_repeated,
                                              encoded_len_repeated)?;
                    }
                    #[test]
                    fn check_packed(value: Vec<$ty>, tag in MIN_TAG..=MAX_TAG) {
                        check_type(value, tag, WireType::LengthDelimited,
                                   encode_packed, merge_repeated,
                                   encoded_len_packed)?;
                    }
                }
            }
        }
    };
}
717fixed_width!(
718 f32,
719 4,
720 WireType::ThirtyTwoBit,
721 float,
722 put_f32_le,
723 get_f32_le
724);
725fixed_width!(
726 f64,
727 8,
728 WireType::SixtyFourBit,
729 double,
730 put_f64_le,
731 get_f64_le
732);
733fixed_width!(
734 u32,
735 4,
736 WireType::ThirtyTwoBit,
737 fixed32,
738 put_u32_le,
739 get_u32_le
740);
741fixed_width!(
742 u64,
743 8,
744 WireType::SixtyFourBit,
745 fixed64,
746 put_u64_le,
747 get_u64_le
748);
749fixed_width!(
750 i32,
751 4,
752 WireType::ThirtyTwoBit,
753 sfixed32,
754 put_i32_le,
755 get_i32_le
756);
757fixed_width!(
758 i64,
759 8,
760 WireType::SixtyFourBit,
761 sfixed64,
762 put_i64_le,
763 get_i64_le
764);
765
/// Emits the repeated-field and length helpers shared by length-delimited types.
///
/// The expansion site must provide `encode` and `merge` functions for `$ty`, and `$ty` must
/// expose a `len()` method.
macro_rules! length_delimited {
    ($ty:ty) => {
        encode_repeated!($ty);

        /// Decodes one length-delimited element and appends it to `values`.
        ///
        /// Fails if `wire_type` is not `LengthDelimited` or the element cannot be decoded.
        pub fn merge_repeated<B>(
            wire_type: WireType,
            values: &mut Vec<$ty>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            B: Buf,
        {
            check_wire_type(WireType::LengthDelimited, wire_type)?;
            let mut element = Default::default();
            merge(wire_type, &mut element, buf, ctx)?;
            values.push(element);
            Ok(())
        }

        /// Number of bytes `encode` writes: key, length prefix and payload.
        #[inline]
        pub fn encoded_len(tag: u32, value: &$ty) -> usize {
            let payload = value.len();
            key_len(tag) + encoded_len_varint(payload as u64) + payload
        }

        /// Number of bytes `encode_repeated` writes for `values`.
        #[inline]
        pub fn encoded_len_repeated(tag: u32, values: &[$ty]) -> usize {
            let keys = key_len(tag) * values.len();
            let payloads = values
                .iter()
                .map(|value| encoded_len_varint(value.len() as u64) + value.len())
                .sum::<usize>();
            keys + payloads
        }
    };
}
802
803pub mod string {
804 use super::*;
805
806 pub fn encode<B>(tag: u32, value: &String, buf: &mut B)
807 where
808 B: BufMut,
809 {
810 encode_key(tag, WireType::LengthDelimited, buf);
811 encode_varint(value.len() as u64, buf);
812 buf.put_slice(value.as_bytes());
813 }
814 pub fn merge<B>(
815 wire_type: WireType,
816 value: &mut String,
817 buf: &mut B,
818 ctx: DecodeContext,
819 ) -> Result<(), DecodeError>
820 where
821 B: Buf,
822 {
823 unsafe {
837 struct DropGuard<'a>(&'a mut Vec<u8>);
838 impl<'a> Drop for DropGuard<'a> {
839 #[inline]
840 fn drop(&mut self) {
841 self.0.clear();
842 }
843 }
844
845 let drop_guard = DropGuard(value.as_mut_vec());
846 bytes::merge(wire_type, drop_guard.0, buf, ctx)?;
847 match str::from_utf8(drop_guard.0) {
848 Ok(_) => {
849 mem::forget(drop_guard);
851 Ok(())
852 }
853 Err(_) => Err(DecodeError::new(
854 "invalid string value: data is not UTF-8 encoded",
855 )),
856 }
857 }
858 }
859
860 length_delimited!(String);
861
862 #[cfg(test)]
863 mod test {
864 use proptest::prelude::*;
865
866 use super::super::test::{check_collection_type, check_type};
867 use super::*;
868
869 proptest! {
870 #[test]
871 fn check(value: String, tag in MIN_TAG..=MAX_TAG) {
872 super::test::check_type(value, tag, WireType::LengthDelimited,
873 encode, merge, encoded_len)?;
874 }
875 #[test]
876 fn check_repeated(value: Vec<String>, tag in MIN_TAG..=MAX_TAG) {
877 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
878 encode_repeated, merge_repeated,
879 encoded_len_repeated)?;
880 }
881 }
882 }
883}
884
885pub trait BytesAdapter: sealed::BytesAdapter {}
886
887mod sealed {
888 use super::{Buf, BufMut};
889
890 pub trait BytesAdapter: Default + Sized + 'static {
891 fn len(&self) -> usize;
892
893 fn replace_with<B>(&mut self, buf: B)
895 where
896 B: Buf;
897
898 fn append_to<B>(&self, buf: &mut B)
900 where
901 B: BufMut;
902
903 fn is_empty(&self) -> bool {
904 self.len() == 0
905 }
906 }
907}
908
909impl BytesAdapter for Bytes {}
910
911impl sealed::BytesAdapter for Bytes {
912 fn len(&self) -> usize {
913 Buf::remaining(self)
914 }
915
916 fn replace_with<B>(&mut self, mut buf: B)
917 where
918 B: Buf,
919 {
920 *self = buf.copy_to_bytes(buf.remaining());
921 }
922
923 fn append_to<B>(&self, buf: &mut B)
924 where
925 B: BufMut,
926 {
927 buf.put(self.clone())
928 }
929}
930
931impl BytesAdapter for Vec<u8> {}
932
933impl sealed::BytesAdapter for Vec<u8> {
934 fn len(&self) -> usize {
935 Vec::len(self)
936 }
937
938 fn replace_with<B>(&mut self, buf: B)
939 where
940 B: Buf,
941 {
942 self.clear();
943 self.reserve(buf.remaining());
944 self.put(buf);
945 }
946
947 fn append_to<B>(&self, buf: &mut B)
948 where
949 B: BufMut,
950 {
951 buf.put(self.as_slice())
952 }
953}
954
955pub mod bytes {
956 use super::*;
957
958 pub fn encode<A, B>(tag: u32, value: &A, buf: &mut B)
959 where
960 A: BytesAdapter,
961 B: BufMut,
962 {
963 encode_key(tag, WireType::LengthDelimited, buf);
964 encode_varint(value.len() as u64, buf);
965 value.append_to(buf);
966 }
967
968 pub fn merge<A, B>(
969 wire_type: WireType,
970 value: &mut A,
971 buf: &mut B,
972 _ctx: DecodeContext,
973 ) -> Result<(), DecodeError>
974 where
975 A: BytesAdapter,
976 B: Buf,
977 {
978 check_wire_type(WireType::LengthDelimited, wire_type)?;
979 let len = decode_varint(buf)?;
980 if len > buf.remaining() as u64 {
981 return Err(DecodeError::new("buffer underflow"));
982 }
983 let len = len as usize;
984
985 value.replace_with(buf.copy_to_bytes(len));
994 Ok(())
995 }
996
997 length_delimited!(impl BytesAdapter);
998
999 #[cfg(test)]
1000 mod test {
1001 use proptest::prelude::*;
1002
1003 use super::super::test::{check_collection_type, check_type};
1004 use super::*;
1005
1006 proptest! {
1007 #[test]
1008 fn check_vec(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1009 super::test::check_type::<Vec<u8>, Vec<u8>>(value, tag, WireType::LengthDelimited,
1010 encode, merge, encoded_len)?;
1011 }
1012
1013 #[test]
1014 fn check_bytes(value: Vec<u8>, tag in MIN_TAG..=MAX_TAG) {
1015 let value = Bytes::from(value);
1016 super::test::check_type::<Bytes, Bytes>(value, tag, WireType::LengthDelimited,
1017 encode, merge, encoded_len)?;
1018 }
1019
1020 #[test]
1021 fn check_repeated_vec(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1022 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1023 encode_repeated, merge_repeated,
1024 encoded_len_repeated)?;
1025 }
1026
1027 #[test]
1028 fn check_repeated_bytes(value: Vec<Vec<u8>>, tag in MIN_TAG..=MAX_TAG) {
1029 let value = value.into_iter().map(Bytes::from).collect();
1030 super::test::check_collection_type(value, tag, WireType::LengthDelimited,
1031 encode_repeated, merge_repeated,
1032 encoded_len_repeated)?;
1033 }
1034 }
1035 }
1036}
1037
1038pub mod message {
1039 use super::*;
1040
1041 pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1042 where
1043 M: Message,
1044 B: BufMut,
1045 {
1046 encode_key(tag, WireType::LengthDelimited, buf);
1047 encode_varint(msg.encoded_len() as u64, buf);
1048 msg.encode_raw(buf);
1049 }
1050
1051 pub fn merge<M, B>(
1052 wire_type: WireType,
1053 msg: &mut M,
1054 buf: &mut B,
1055 ctx: DecodeContext,
1056 ) -> Result<(), DecodeError>
1057 where
1058 M: Message,
1059 B: Buf,
1060 {
1061 check_wire_type(WireType::LengthDelimited, wire_type)?;
1062 ctx.limit_reached()?;
1063 merge_loop(
1064 msg,
1065 buf,
1066 ctx.enter_recursion(),
1067 |msg: &mut M, buf: &mut B, ctx| {
1068 let (tag, wire_type) = decode_key(buf)?;
1069 msg.merge_field(tag, wire_type, buf, ctx)
1070 },
1071 )
1072 }
1073
1074 pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1075 where
1076 M: Message,
1077 B: BufMut,
1078 {
1079 for msg in messages {
1080 encode(tag, msg, buf);
1081 }
1082 }
1083
1084 pub fn merge_repeated<M, B>(
1085 wire_type: WireType,
1086 messages: &mut Vec<M>,
1087 buf: &mut B,
1088 ctx: DecodeContext,
1089 ) -> Result<(), DecodeError>
1090 where
1091 M: Message + Default,
1092 B: Buf,
1093 {
1094 check_wire_type(WireType::LengthDelimited, wire_type)?;
1095 let mut msg = M::default();
1096 merge(WireType::LengthDelimited, &mut msg, buf, ctx)?;
1097 messages.push(msg);
1098 Ok(())
1099 }
1100
1101 #[inline]
1102 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1103 where
1104 M: Message,
1105 {
1106 let len = msg.encoded_len();
1107 key_len(tag) + encoded_len_varint(len as u64) + len
1108 }
1109
1110 #[inline]
1111 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1112 where
1113 M: Message,
1114 {
1115 key_len(tag) * messages.len()
1116 + messages
1117 .iter()
1118 .map(Message::encoded_len)
1119 .map(|len| len + encoded_len_varint(len as u64))
1120 .sum::<usize>()
1121 }
1122}
1123
1124pub mod group {
1125 use super::*;
1126
1127 pub fn encode<M, B>(tag: u32, msg: &M, buf: &mut B)
1128 where
1129 M: Message,
1130 B: BufMut,
1131 {
1132 encode_key(tag, WireType::StartGroup, buf);
1133 msg.encode_raw(buf);
1134 encode_key(tag, WireType::EndGroup, buf);
1135 }
1136
1137 pub fn merge<M, B>(
1138 tag: u32,
1139 wire_type: WireType,
1140 msg: &mut M,
1141 buf: &mut B,
1142 ctx: DecodeContext,
1143 ) -> Result<(), DecodeError>
1144 where
1145 M: Message,
1146 B: Buf,
1147 {
1148 check_wire_type(WireType::StartGroup, wire_type)?;
1149
1150 ctx.limit_reached()?;
1151 loop {
1152 let (field_tag, field_wire_type) = decode_key(buf)?;
1153 if field_wire_type == WireType::EndGroup {
1154 if field_tag != tag {
1155 return Err(DecodeError::new("unexpected end group tag"));
1156 }
1157 return Ok(());
1158 }
1159
1160 M::merge_field(msg, field_tag, field_wire_type, buf, ctx.enter_recursion())?;
1161 }
1162 }
1163
1164 pub fn encode_repeated<M, B>(tag: u32, messages: &[M], buf: &mut B)
1165 where
1166 M: Message,
1167 B: BufMut,
1168 {
1169 for msg in messages {
1170 encode(tag, msg, buf);
1171 }
1172 }
1173
1174 pub fn merge_repeated<M, B>(
1175 tag: u32,
1176 wire_type: WireType,
1177 messages: &mut Vec<M>,
1178 buf: &mut B,
1179 ctx: DecodeContext,
1180 ) -> Result<(), DecodeError>
1181 where
1182 M: Message + Default,
1183 B: Buf,
1184 {
1185 check_wire_type(WireType::StartGroup, wire_type)?;
1186 let mut msg = M::default();
1187 merge(tag, WireType::StartGroup, &mut msg, buf, ctx)?;
1188 messages.push(msg);
1189 Ok(())
1190 }
1191
1192 #[inline]
1193 pub fn encoded_len<M>(tag: u32, msg: &M) -> usize
1194 where
1195 M: Message,
1196 {
1197 2 * key_len(tag) + msg.encoded_len()
1198 }
1199
1200 #[inline]
1201 pub fn encoded_len_repeated<M>(tag: u32, messages: &[M]) -> usize
1202 where
1203 M: Message,
1204 {
1205 2 * key_len(tag) * messages.len() + messages.iter().map(Message::encoded_len).sum::<usize>()
1206 }
1207}
1208
/// Emits the map encoding functions for the map type `$map_ty`.
///
/// Every map entry is written as a length-delimited field whose body holds the key as field 1
/// and the value as field 2. A key or value equal to its default is omitted from the body.
macro_rules! map {
    ($map_ty:ident) => {
        use crate::encoding::*;
        use core::hash::Hash;

        /// Encodes `values`, treating `V::default()` as the value that may be omitted.
        ///
        /// See `encode_with_default` for the meaning of the function arguments.
        pub fn encode<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            let val_default = V::default();
            encode_with_default(
                key_encode,
                key_encoded_len,
                val_encode,
                val_encoded_len,
                &val_default,
                tag,
                values,
                buf,
            )
        }

        /// Decodes one map entry into `values`, starting the value from `V::default()`.
        ///
        /// See `merge_with_default` for details.
        pub fn merge<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            V: Default,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let val_default = V::default();
            merge_with_default(key_merge, val_merge, val_default, values, buf, ctx)
        }

        /// Number of bytes `encode` writes for `values`.
        pub fn encoded_len<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: Default + PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            let val_default = V::default();
            encoded_len_with_default(key_encoded_len, val_encoded_len, &val_default, tag, values)
        }

        /// Encodes every entry of `values` as a length-delimited field carrying `tag`.
        ///
        /// `key_encode` / `val_encode` write a key or value under the given field number, and
        /// `key_encoded_len` / `val_encoded_len` report how many bytes that takes. A key equal
        /// to `K::default()` or a value equal to `val_default` is left out of the entry body.
        pub fn encode_with_default<K, V, B, KE, KL, VE, VL>(
            key_encode: KE,
            key_encoded_len: KL,
            val_encode: VE,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
            buf: &mut B,
        ) where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            B: BufMut,
            KE: Fn(u32, &K, &mut B),
            KL: Fn(u32, &K) -> usize,
            VE: Fn(u32, &V, &mut B),
            VL: Fn(u32, &V) -> usize,
        {
            for (key, val) in values.iter() {
                let omit_key = key == &K::default();
                let omit_val = val == val_default;

                let key_bytes = if omit_key { 0 } else { key_encoded_len(1, key) };
                let val_bytes = if omit_val { 0 } else { val_encoded_len(2, val) };
                let entry_len = key_bytes + val_bytes;

                encode_key(tag, WireType::LengthDelimited, buf);
                encode_varint(entry_len as u64, buf);
                if !omit_key {
                    key_encode(1, key, buf);
                }
                if !omit_val {
                    val_encode(2, val, buf);
                }
            }
        }

        /// Decodes one map entry from `buf` and inserts it into `values`.
        ///
        /// The key starts as `K::default()` and the value as `val_default`; field 1 of the
        /// entry is merged into the key, field 2 into the value, and any other field is
        /// skipped. The insert replaces any existing entry with the same key.
        ///
        /// # Errors
        ///
        /// Fails when the recursion limit is reached or when the entry cannot be decoded;
        /// nothing is inserted in that case.
        pub fn merge_with_default<K, V, B, KM, VM>(
            key_merge: KM,
            val_merge: VM,
            val_default: V,
            values: &mut $map_ty<K, V>,
            buf: &mut B,
            ctx: DecodeContext,
        ) -> Result<(), DecodeError>
        where
            K: Default + Eq + Hash + Ord,
            B: Buf,
            KM: Fn(WireType, &mut K, &mut B, DecodeContext) -> Result<(), DecodeError>,
            VM: Fn(WireType, &mut V, &mut B, DecodeContext) -> Result<(), DecodeError>,
        {
            let mut entry_key = Default::default();
            let mut entry_val = val_default;
            ctx.limit_reached()?;
            merge_loop(
                &mut (&mut entry_key, &mut entry_val),
                buf,
                ctx.enter_recursion(),
                |&mut (ref mut key_slot, ref mut val_slot), buf, ctx| {
                    let (field_tag, wire_type) = decode_key(buf)?;
                    match field_tag {
                        1 => key_merge(wire_type, key_slot, buf, ctx),
                        2 => val_merge(wire_type, val_slot, buf, ctx),
                        _ => skip_field(wire_type, field_tag, buf, ctx),
                    }
                },
            )?;
            values.insert(entry_key, entry_val);

            Ok(())
        }

        /// Number of bytes `encode_with_default` writes for `values`.
        pub fn encoded_len_with_default<K, V, KL, VL>(
            key_encoded_len: KL,
            val_encoded_len: VL,
            val_default: &V,
            tag: u32,
            values: &$map_ty<K, V>,
        ) -> usize
        where
            K: Default + Eq + Hash + Ord,
            V: PartialEq,
            KL: Fn(u32, &K) -> usize,
            VL: Fn(u32, &V) -> usize,
        {
            let keys = key_len(tag) * values.len();
            let entries = values
                .iter()
                .map(|(key, val)| {
                    let key_bytes = if key == &K::default() {
                        0
                    } else {
                        key_encoded_len(1, key)
                    };
                    let val_bytes = if val == val_default {
                        0
                    } else {
                        val_encoded_len(2, val)
                    };
                    let entry_len = key_bytes + val_bytes;
                    encoded_len_varint(entry_len as u64) + entry_len
                })
                .sum::<usize>();
            keys + entries
        }
    };
}
1395
/// Map encoding functions instantiated for `std::collections::HashMap`; only built with the
/// `std` feature.
#[cfg(feature = "std")]
pub mod hash_map {
    use std::collections::HashMap;

    map!(HashMap);
}
1401
1402pub mod btree_map {
1403 map!(BTreeMap);
1404}
1405
1406#[cfg(test)]
1407mod test {
1408 use alloc::string::ToString;
1409 use core::borrow::Borrow;
1410 use core::fmt::Debug;
1411 use core::u64;
1412
1413 use ::bytes::{Bytes, BytesMut};
1414 use proptest::{prelude::*, test_runner::TestCaseResult};
1415
1416 use crate::encoding::*;
1417
1418 pub fn check_type<T, B>(
1419 value: T,
1420 tag: u32,
1421 wire_type: WireType,
1422 encode: fn(u32, &B, &mut BytesMut),
1423 merge: fn(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1424 encoded_len: fn(u32, &B) -> usize,
1425 ) -> TestCaseResult
1426 where
1427 T: Debug + Default + PartialEq + Borrow<B>,
1428 B: ?Sized,
1429 {
1430 prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG);
1431
1432 let expected_len = encoded_len(tag, value.borrow());
1433
1434 let mut buf = BytesMut::with_capacity(expected_len);
1435 encode(tag, value.borrow(), &mut buf);
1436
1437 let mut buf = buf.freeze();
1438
1439 prop_assert_eq!(
1440 buf.remaining(),
1441 expected_len,
1442 "encoded_len wrong; expected: {}, actual: {}",
1443 expected_len,
1444 buf.remaining()
1445 );
1446
1447 if !buf.has_remaining() {
1448 return Ok(());
1450 }
1451
1452 let (decoded_tag, decoded_wire_type) =
1453 decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1454 prop_assert_eq!(
1455 tag,
1456 decoded_tag,
1457 "decoded tag does not match; expected: {}, actual: {}",
1458 tag,
1459 decoded_tag
1460 );
1461
1462 prop_assert_eq!(
1463 wire_type,
1464 decoded_wire_type,
1465 "decoded wire type does not match; expected: {:?}, actual: {:?}",
1466 wire_type,
1467 decoded_wire_type,
1468 );
1469
1470 match wire_type {
1471 WireType::SixtyFourBit if buf.remaining() != 8 => Err(TestCaseError::fail(format!(
1472 "64bit wire type illegal remaining: {}, tag: {}",
1473 buf.remaining(),
1474 tag
1475 ))),
1476 WireType::ThirtyTwoBit if buf.remaining() != 4 => Err(TestCaseError::fail(format!(
1477 "32bit wire type illegal remaining: {}, tag: {}",
1478 buf.remaining(),
1479 tag
1480 ))),
1481 _ => Ok(()),
1482 }?;
1483
1484 let mut roundtrip_value = T::default();
1485 merge(
1486 wire_type,
1487 &mut roundtrip_value,
1488 &mut buf,
1489 DecodeContext::default(),
1490 )
1491 .map_err(|error| TestCaseError::fail(error.to_string()))?;
1492
1493 prop_assert!(
1494 !buf.has_remaining(),
1495 "expected buffer to be empty, remaining: {}",
1496 buf.remaining()
1497 );
1498
1499 prop_assert_eq!(value, roundtrip_value);
1500
1501 Ok(())
1502 }
1503
1504 pub fn check_collection_type<T, B, E, M, L>(
1505 value: T,
1506 tag: u32,
1507 wire_type: WireType,
1508 encode: E,
1509 mut merge: M,
1510 encoded_len: L,
1511 ) -> TestCaseResult
1512 where
1513 T: Debug + Default + PartialEq + Borrow<B>,
1514 B: ?Sized,
1515 E: FnOnce(u32, &B, &mut BytesMut),
1516 M: FnMut(WireType, &mut T, &mut Bytes, DecodeContext) -> Result<(), DecodeError>,
1517 L: FnOnce(u32, &B) -> usize,
1518 {
1519 prop_assume!(MIN_TAG <= tag && tag <= MAX_TAG);
1520
1521 let expected_len = encoded_len(tag, value.borrow());
1522
1523 let mut buf = BytesMut::with_capacity(expected_len);
1524 encode(tag, value.borrow(), &mut buf);
1525
1526 let mut buf = buf.freeze();
1527
1528 prop_assert_eq!(
1529 buf.remaining(),
1530 expected_len,
1531 "encoded_len wrong; expected: {}, actual: {}",
1532 expected_len,
1533 buf.remaining()
1534 );
1535
1536 let mut roundtrip_value = Default::default();
1537 while buf.has_remaining() {
1538 let (decoded_tag, decoded_wire_type) =
1539 decode_key(&mut buf).map_err(|error| TestCaseError::fail(error.to_string()))?;
1540
1541 prop_assert_eq!(
1542 tag,
1543 decoded_tag,
1544 "decoded tag does not match; expected: {}, actual: {}",
1545 tag,
1546 decoded_tag
1547 );
1548
1549 prop_assert_eq!(
1550 wire_type,
1551 decoded_wire_type,
1552 "decoded wire type does not match; expected: {:?}, actual: {:?}",
1553 wire_type,
1554 decoded_wire_type
1555 );
1556
1557 merge(
1558 wire_type,
1559 &mut roundtrip_value,
1560 &mut buf,
1561 DecodeContext::default(),
1562 )
1563 .map_err(|error| TestCaseError::fail(error.to_string()))?;
1564 }
1565
1566 prop_assert_eq!(value, roundtrip_value);
1567
1568 Ok(())
1569 }
1570
/// Merging a length-delimited payload that is not valid UTF-8 must fail and
/// must leave the destination string empty.
#[test]
fn string_merge_invalid_utf8() {
    // Length prefix 2, followed by two bytes that do not form valid UTF-8.
    let payload: &[u8] = b"\x02\x80\x80";
    let mut reader = payload;
    let mut target = String::new();

    let outcome = string::merge(
        WireType::LengthDelimited,
        &mut target,
        &mut reader,
        DecodeContext::default(),
    );

    outcome.expect_err("must be an error");
    assert!(target.is_empty());
}
1585
/// Checks varint encoding, length prediction and both decoders against known
/// byte sequences at every 7-bit boundary.
#[test]
fn varint() {
    /// Asserts that `value` and `encoded` convert into each other through
    /// every varint helper.
    fn check(value: u64, encoded: &[u8]) {
        // Encode once into a tiny buffer and once into a roomy one.
        for &capacity in &[1usize, 100] {
            let mut out = Vec::with_capacity(capacity);
            encode_varint(value, &mut out);
            assert_eq!(out, encoded);
        }

        assert_eq!(encoded_len_varint(value), encoded.len());

        // `&[u8]` is `Copy`, so each decoder advances its own cursor over the
        // same bytes.
        let mut fast_input = encoded;
        let decoded = decode_varint(&mut fast_input).expect("decoding failed");
        assert_eq!(value, decoded);

        let mut slow_input = encoded;
        let decoded = decode_varint_slow(&mut slow_input).expect("slow decoding failed");
        assert_eq!(value, decoded);
    }

    let cases: &[(u64, &[u8])] = &[
        (0, &[0x00]),
        (1, &[0x01]),
        ((1u64 << 7) - 1, &[0x7F]),
        (1u64 << 7, &[0x80, 0x01]),
        (300, &[0xAC, 0x02]),
        ((1u64 << 14) - 1, &[0xFF, 0x7F]),
        (1u64 << 14, &[0x80, 0x80, 0x01]),
        ((1u64 << 21) - 1, &[0xFF, 0xFF, 0x7F]),
        (1u64 << 21, &[0x80, 0x80, 0x80, 0x01]),
        ((1u64 << 28) - 1, &[0xFF, 0xFF, 0xFF, 0x7F]),
        (1u64 << 28, &[0x80, 0x80, 0x80, 0x80, 0x01]),
        ((1u64 << 35) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
        (1u64 << 35, &[0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
        ((1u64 << 42) - 1, &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]),
        (1u64 << 42, &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01]),
        (
            (1u64 << 49) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        ),
        (
            1u64 << 49,
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        ),
        (
            (1u64 << 56) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        ),
        (
            1u64 << 56,
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        ),
        (
            (1u64 << 63) - 1,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
        ),
        (
            1u64 << 63,
            &[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x01],
        ),
        (
            u64::MAX,
            &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01],
        ),
    ];

    for &(value, encoded) in cases {
        check(value, encoded);
    }
}
1665
/// A ten-byte varint whose final byte is `0x02` encodes `u64::MAX + 1`; both
/// the fast and the slow decoder must reject it.
#[test]
fn varint_overflow() {
    let u64_max_plus_one: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];

    // Each decoder gets its own cursor over the full input. Sharing one
    // mutable slice would make the slow-path assertion depend on the fast
    // decoder consuming nothing when it fails; on a consumed buffer the slow
    // decoder could fail for an unrelated reason and the test would pass
    // vacuously.
    let mut fast_input = u64_max_plus_one;
    decode_varint(&mut fast_input).expect_err("decoding u64::MAX + 1 succeeded");

    let mut slow_input = u64_max_plus_one;
    decode_varint_slow(&mut slow_input).expect_err("slow decoding u64::MAX + 1 succeeded");
}
1675
// Generates one proptest round-trip test for every combination of map type
// (`HashMap`, `BTreeMap`), key codec and value codec. The resulting layout is
// `<map module>::<key codec>::<value codec>`, e.g. `hash_map::int32::string`.
#[cfg(feature = "std")]
macro_rules! map_tests {
    // Entry point: fan out over both map implementations.
    (keys: $key_list:tt,
     vals: $val_list:tt) => {
        mod hash_map {
            map_tests!(@each_key HashMap, hash_map, $key_list, $val_list);
        }
        mod btree_map {
            map_tests!(@each_key BTreeMap, btree_map, $key_list, $val_list);
        }
    };

    // One module per key codec, holding the imports shared by its tests.
    (@each_key $map:ident,
               $codec:ident,
               [$(($k_ty:ty, $k_mod:ident)),*],
               $val_list:tt) => {
        $(
            mod $k_mod {
                use std::collections::$map;

                use proptest::prelude::*;

                use crate::encoding::*;
                use crate::encoding::test::check_collection_type;

                map_tests!(@each_val $map, $codec, ($k_ty, $k_mod), $val_list);
            }
        )*
    };

    // One proptest per value codec for a fixed map type and key codec.
    (@each_val $map:ident,
               $codec:ident,
               ($k_ty:ty, $k_mod:ident),
               [$(($v_ty:ty, $v_mod:ident)),*]) => {
        $(
            proptest! {
                #[test]
                fn $v_mod(values: $map<$k_ty, $v_ty>, tag in MIN_TAG..=MAX_TAG) {
                    check_collection_type(
                        values,
                        tag,
                        WireType::LengthDelimited,
                        |field_tag, entries, out| {
                            $codec::encode(
                                $k_mod::encode,
                                $k_mod::encoded_len,
                                $v_mod::encode,
                                $v_mod::encoded_len,
                                field_tag,
                                entries,
                                out,
                            )
                        },
                        |actual_wire_type, entries, input, ctx| {
                            check_wire_type(WireType::LengthDelimited, actual_wire_type)?;
                            $codec::merge($k_mod::merge, $v_mod::merge, entries, input, ctx)
                        },
                        |field_tag, entries| {
                            $codec::encoded_len(
                                $k_mod::encoded_len,
                                $v_mod::encoded_len,
                                field_tag,
                                entries,
                            )
                        },
                    )?;
                }
            }
        )*
    };
}
1746
// Instantiates the map round-trip tests. Each pair is (Rust type, codec module).
#[cfg(feature = "std")]
map_tests!(
    keys: [
        (i32, int32), (i64, int64),
        (u32, uint32), (u64, uint64),
        (i32, sint32), (i64, sint64),
        (u32, fixed32), (u64, fixed64),
        (i32, sfixed32), (i64, sfixed64),
        (bool, bool),
        (String, string)
    ],
    vals: [
        (f32, float), (f64, double),
        (i32, int32), (i64, int64),
        (u32, uint32), (u64, uint64),
        (i32, sint32), (i64, sint64),
        (u32, fixed32), (u64, fixed64),
        (i32, sfixed32), (i64, sfixed64),
        (bool, bool),
        (String, string),
        (Vec<u8>, bytes)
    ]
);
1779}