prost_derive/field/map.rs

1use anyhow::{bail, Error};
2use proc_macro2::{Span, TokenStream};
3use quote::quote;
4use syn::{Ident, Lit, Meta, MetaNameValue, NestedMeta};
5
6use crate::field::{scalar, set_option, tag_attr};
7
/// The Rust collection used to back a protobuf map field.
///
/// This is a fieldless enum, so it is `Copy` and comparable, like `ValueTy`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MapTy {
    /// Selected by the `map` or `hash_map` attribute.
    HashMap,
    /// Selected by the `btree_map` attribute.
    BTreeMap,
}
13
14impl MapTy {
15    fn from_str(s: &str) -> Option<MapTy> {
16        match s {
17            "map" | "hash_map" => Some(MapTy::HashMap),
18            "btree_map" => Some(MapTy::BTreeMap),
19            _ => None,
20        }
21    }
22
23    fn module(&self) -> Ident {
24        match *self {
25            MapTy::HashMap => Ident::new("hash_map", Span::call_site()),
26            MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()),
27        }
28    }
29
30    fn lib(&self) -> TokenStream {
31        match self {
32            MapTy::HashMap => quote! { std },
33            MapTy::BTreeMap => quote! { prost::alloc },
34        }
35    }
36}
37
38fn fake_scalar(ty: scalar::Ty) -> scalar::Field {
39    let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty));
40    scalar::Field {
41        ty,
42        kind,
43        tag: 0, // Not used here
44    }
45}
46
47#[derive(Clone)]
48pub struct Field {
49    pub map_ty: MapTy,
50    pub key_ty: scalar::Ty,
51    pub value_ty: ValueTy,
52    pub tag: u32,
53}
54
55impl Field {
56    pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
57        let mut types = None;
58        let mut tag = None;
59
60        for attr in attrs {
61            if let Some(t) = tag_attr(attr)? {
62                set_option(&mut tag, t, "duplicate tag attributes")?;
63            } else if let Some(map_ty) = attr
64                .path()
65                .get_ident()
66                .and_then(|i| MapTy::from_str(&i.to_string()))
67            {
68                let (k, v): (String, String) = match &*attr {
69                    Meta::NameValue(MetaNameValue {
70                        lit: Lit::Str(lit), ..
71                    }) => {
72                        let items = lit.value();
73                        let mut items = items.split(',').map(ToString::to_string);
74                        let k = items.next().unwrap();
75                        let v = match items.next() {
76                            Some(k) => k,
77                            None => bail!("invalid map attribute: must have key and value types"),
78                        };
79                        if items.next().is_some() {
80                            bail!("invalid map attribute: {:?}", attr);
81                        }
82                        (k, v)
83                    }
84                    Meta::List(meta_list) => {
85                        // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer.
86                        if meta_list.nested.len() != 2 {
87                            bail!("invalid map attribute: must contain key and value types");
88                        }
89                        let k = match &meta_list.nested[0] {
90                            NestedMeta::Meta(Meta::Path(k)) if k.get_ident().is_some() => {
91                                k.get_ident().unwrap().to_string()
92                            }
93                            _ => bail!("invalid map attribute: key must be an identifier"),
94                        };
95                        let v = match &meta_list.nested[1] {
96                            NestedMeta::Meta(Meta::Path(v)) if v.get_ident().is_some() => {
97                                v.get_ident().unwrap().to_string()
98                            }
99                            _ => bail!("invalid map attribute: value must be an identifier"),
100                        };
101                        (k, v)
102                    }
103                    _ => return Ok(None),
104                };
105                set_option(
106                    &mut types,
107                    (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?),
108                    "duplicate map type attribute",
109                )?;
110            } else {
111                return Ok(None);
112            }
113        }
114
115        Ok(match (types, tag.or(inferred_tag)) {
116            (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field {
117                map_ty,
118                key_ty,
119                value_ty,
120                tag,
121            }),
122            _ => None,
123        })
124    }
125
126    pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
127        Field::new(attrs, None)
128    }
129
130    /// Returns a statement which encodes the map field.
131    pub fn encode(&self, ident: TokenStream) -> TokenStream {
132        let tag = self.tag;
133        let key_mod = self.key_ty.module();
134        let ke = quote!(::prost::encoding::#key_mod::encode);
135        let kl = quote!(::prost::encoding::#key_mod::encoded_len);
136        let module = self.map_ty.module();
137        match &self.value_ty {
138            ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
139                let default = quote!(#ty::default() as i32);
140                quote! {
141                    ::prost::encoding::#module::encode_with_default(
142                        #ke,
143                        #kl,
144                        ::prost::encoding::int32::encode,
145                        ::prost::encoding::int32::encoded_len,
146                        &(#default),
147                        #tag,
148                        &#ident,
149                        buf,
150                    );
151                }
152            }
153            ValueTy::Scalar(value_ty) => {
154                let val_mod = value_ty.module();
155                let ve = quote!(::prost::encoding::#val_mod::encode);
156                let vl = quote!(::prost::encoding::#val_mod::encoded_len);
157                quote! {
158                    ::prost::encoding::#module::encode(
159                        #ke,
160                        #kl,
161                        #ve,
162                        #vl,
163                        #tag,
164                        &#ident,
165                        buf,
166                    );
167                }
168            }
169            ValueTy::Message => quote! {
170                ::prost::encoding::#module::encode(
171                    #ke,
172                    #kl,
173                    ::prost::encoding::message::encode,
174                    ::prost::encoding::message::encoded_len,
175                    #tag,
176                    &#ident,
177                    buf,
178                );
179            },
180        }
181    }
182
183    /// Returns an expression which evaluates to the result of merging a decoded key value pair
184    /// into the map.
185    pub fn merge(&self, ident: TokenStream) -> TokenStream {
186        let key_mod = self.key_ty.module();
187        let km = quote!(::prost::encoding::#key_mod::merge);
188        let module = self.map_ty.module();
189        match &self.value_ty {
190            ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
191                let default = quote!(#ty::default() as i32);
192                quote! {
193                    ::prost::encoding::#module::merge_with_default(
194                        #km,
195                        ::prost::encoding::int32::merge,
196                        #default,
197                        &mut #ident,
198                        buf,
199                        ctx,
200                    )
201                }
202            }
203            ValueTy::Scalar(value_ty) => {
204                let val_mod = value_ty.module();
205                let vm = quote!(::prost::encoding::#val_mod::merge);
206                quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
207            }
208            ValueTy::Message => quote! {
209                ::prost::encoding::#module::merge(
210                    #km,
211                    ::prost::encoding::message::merge,
212                    &mut #ident,
213                    buf,
214                    ctx,
215                )
216            },
217        }
218    }
219
220    /// Returns an expression which evaluates to the encoded length of the map.
221    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
222        let tag = self.tag;
223        let key_mod = self.key_ty.module();
224        let kl = quote!(::prost::encoding::#key_mod::encoded_len);
225        let module = self.map_ty.module();
226        match &self.value_ty {
227            ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
228                let default = quote!(#ty::default() as i32);
229                quote! {
230                    ::prost::encoding::#module::encoded_len_with_default(
231                        #kl,
232                        ::prost::encoding::int32::encoded_len,
233                        &(#default),
234                        #tag,
235                        &#ident,
236                    )
237                }
238            }
239            ValueTy::Scalar(value_ty) => {
240                let val_mod = value_ty.module();
241                let vl = quote!(::prost::encoding::#val_mod::encoded_len);
242                quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident))
243            }
244            ValueTy::Message => quote! {
245                ::prost::encoding::#module::encoded_len(
246                    #kl,
247                    ::prost::encoding::message::encoded_len,
248                    #tag,
249                    &#ident,
250                )
251            },
252        }
253    }
254
255    pub fn clear(&self, ident: TokenStream) -> TokenStream {
256        quote!(#ident.clear())
257    }
258
259    /// Returns methods to embed in the message.
260    pub fn methods(&self, ident: &Ident) -> Option<TokenStream> {
261        if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty {
262            let key_ty = self.key_ty.rust_type();
263            let key_ref_ty = self.key_ty.rust_ref_type();
264
265            let get = Ident::new(&format!("get_{}", ident), Span::call_site());
266            let insert = Ident::new(&format!("insert_{}", ident), Span::call_site());
267            let take_ref = if self.key_ty.is_numeric() {
268                quote!(&)
269            } else {
270                quote!()
271            };
272
273            let get_doc = format!(
274                "Returns the enum value for the corresponding key in `{}`, \
275                 or `None` if the entry does not exist or it is not a valid enum value.",
276                ident,
277            );
278            let insert_doc = format!("Inserts a key value pair into `{}`.", ident);
279            Some(quote! {
280                #[doc=#get_doc]
281                pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> {
282                    self.#ident.get(#take_ref key).cloned().and_then(#ty::from_i32)
283                }
284                #[doc=#insert_doc]
285                pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> {
286                    self.#ident.insert(key, value as i32).and_then(#ty::from_i32)
287                }
288            })
289        } else {
290            None
291        }
292    }
293
294    /// Returns a newtype wrapper around the map, implementing nicer Debug
295    ///
296    /// The Debug tries to convert any enumerations met into the variants if possible, instead of
297    /// outputting the raw numbers.
298    pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
299        let type_name = match self.map_ty {
300            MapTy::HashMap => Ident::new("HashMap", Span::call_site()),
301            MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()),
302        };
303
304        // A fake field for generating the debug wrapper
305        let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper));
306        let key = self.key_ty.rust_type();
307        let value_wrapper = self.value_ty.debug();
308        let libname = self.map_ty.lib();
309        let fmt = quote! {
310            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
311                #key_wrapper
312                #value_wrapper
313                let mut builder = f.debug_map();
314                for (k, v) in self.0 {
315                    builder.entry(&KeyWrapper(k), &ValueWrapper(v));
316                }
317                builder.finish()
318            }
319        };
320        match &self.value_ty {
321            ValueTy::Scalar(ty) => {
322                if let &scalar::Ty::Bytes(_) = ty {
323                    return quote! {
324                        struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug);
325                        impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
326                            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
327                                self.0.fmt(f)
328                            }
329                        }
330                    };
331                }
332
333                let value = ty.rust_type();
334                quote! {
335                    struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>);
336                    impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
337                        #fmt
338                    }
339                }
340            }
341            ValueTy::Message => quote! {
342                struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>);
343                impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V>
344                where
345                    V: ::core::fmt::Debug + 'a,
346                {
347                    #fmt
348                }
349            },
350        }
351    }
352}
353
354fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
355    let ty = scalar::Ty::from_str(s)?;
356    match ty {
357        scalar::Ty::Int32
358        | scalar::Ty::Int64
359        | scalar::Ty::Uint32
360        | scalar::Ty::Uint64
361        | scalar::Ty::Sint32
362        | scalar::Ty::Sint64
363        | scalar::Ty::Fixed32
364        | scalar::Ty::Fixed64
365        | scalar::Ty::Sfixed32
366        | scalar::Ty::Sfixed64
367        | scalar::Ty::Bool
368        | scalar::Ty::String => Ok(ty),
369        _ => bail!("invalid map key type: {}", s),
370    }
371}
372
373/// A map value type.
374#[derive(Clone, Debug, PartialEq, Eq)]
375pub enum ValueTy {
376    Scalar(scalar::Ty),
377    Message,
378}
379
380impl ValueTy {
381    fn from_str(s: &str) -> Result<ValueTy, Error> {
382        if let Ok(ty) = scalar::Ty::from_str(s) {
383            Ok(ValueTy::Scalar(ty))
384        } else if s.trim() == "message" {
385            Ok(ValueTy::Message)
386        } else {
387            bail!("invalid map value type: {}", s);
388        }
389    }
390
391    /// Returns a newtype wrapper around the ValueTy for nicer debug.
392    ///
393    /// If the contained value is enumeration, it tries to convert it to the variant. If not, it
394    /// just forwards the implementation.
395    fn debug(&self) -> TokenStream {
396        match self {
397            ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)),
398            ValueTy::Message => quote!(
399                fn ValueWrapper<T>(v: T) -> T {
400                    v
401                }
402            ),
403        }
404    }
405}