prost_derive/field/scalar.rs

1use std::convert::TryFrom;
2use std::fmt;
3
4use anyhow::{anyhow, bail, Error};
5use proc_macro2::{Span, TokenStream};
6use quote::{quote, ToTokens, TokenStreamExt};
7use syn::{parse_str, Ident, Lit, LitByteStr, Meta, MetaList, MetaNameValue, NestedMeta, Path};
8
9use crate::field::{bool_attr, set_option, tag_attr, Label};
10
11/// A scalar protobuf field.
12#[derive(Clone)]
13pub struct Field {
14    pub ty: Ty,
15    pub kind: Kind,
16    pub tag: u32,
17}
18
19impl Field {
20    pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
21        let mut ty = None;
22        let mut label = None;
23        let mut packed = None;
24        let mut default = None;
25        let mut tag = None;
26
27        let mut unknown_attrs = Vec::new();
28
29        for attr in attrs {
30            if let Some(t) = Ty::from_attr(attr)? {
31                set_option(&mut ty, t, "duplicate type attributes")?;
32            } else if let Some(p) = bool_attr("packed", attr)? {
33                set_option(&mut packed, p, "duplicate packed attributes")?;
34            } else if let Some(t) = tag_attr(attr)? {
35                set_option(&mut tag, t, "duplicate tag attributes")?;
36            } else if let Some(l) = Label::from_attr(attr) {
37                set_option(&mut label, l, "duplicate label attributes")?;
38            } else if let Some(d) = DefaultValue::from_attr(attr)? {
39                set_option(&mut default, d, "duplicate default attributes")?;
40            } else {
41                unknown_attrs.push(attr);
42            }
43        }
44
45        let ty = match ty {
46            Some(ty) => ty,
47            None => return Ok(None),
48        };
49
50        match unknown_attrs.len() {
51            0 => (),
52            1 => bail!("unknown attribute: {:?}", unknown_attrs[0]),
53            _ => bail!("unknown attributes: {:?}", unknown_attrs),
54        }
55
56        let tag = match tag.or(inferred_tag) {
57            Some(tag) => tag,
58            None => bail!("missing tag attribute"),
59        };
60
61        let has_default = default.is_some();
62        let default = default.map_or_else(
63            || Ok(DefaultValue::new(&ty)),
64            |lit| DefaultValue::from_lit(&ty, lit),
65        )?;
66
67        let kind = match (label, packed, has_default) {
68            (None, Some(true), _)
69            | (Some(Label::Optional), Some(true), _)
70            | (Some(Label::Required), Some(true), _) => {
71                bail!("packed attribute may only be applied to repeated fields");
72            }
73            (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => {
74                bail!("packed attribute may only be applied to numeric types");
75            }
76            (Some(Label::Repeated), _, true) => {
77                bail!("repeated fields may not have a default value");
78            }
79
80            (None, _, _) => Kind::Plain(default),
81            (Some(Label::Optional), _, _) => Kind::Optional(default),
82            (Some(Label::Required), _, _) => Kind::Required(default),
83            (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => {
84                Kind::Packed
85            }
86            (Some(Label::Repeated), _, false) => Kind::Repeated,
87        };
88
89        Ok(Some(Field { ty, kind, tag }))
90    }
91
92    pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
93        if let Some(mut field) = Field::new(attrs, None)? {
94            match field.kind {
95                Kind::Plain(default) => {
96                    field.kind = Kind::Required(default);
97                    Ok(Some(field))
98                }
99                Kind::Optional(..) => bail!("invalid optional attribute on oneof field"),
100                Kind::Required(..) => bail!("invalid required attribute on oneof field"),
101                Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"),
102            }
103        } else {
104            Ok(None)
105        }
106    }
107
108    pub fn encode(&self, ident: TokenStream) -> TokenStream {
109        let module = self.ty.module();
110        let encode_fn = match self.kind {
111            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode),
112            Kind::Repeated => quote!(encode_repeated),
113            Kind::Packed => quote!(encode_packed),
114        };
115        let encode_fn = quote!(::prost::encoding::#module::#encode_fn);
116        let tag = self.tag;
117
118        match self.kind {
119            Kind::Plain(ref default) => {
120                let default = default.typed();
121                quote! {
122                    if #ident != #default {
123                        #encode_fn(#tag, &#ident, buf);
124                    }
125                }
126            }
127            Kind::Optional(..) => quote! {
128                if let ::core::option::Option::Some(ref value) = #ident {
129                    #encode_fn(#tag, value, buf);
130                }
131            },
132            Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
133                #encode_fn(#tag, &#ident, buf);
134            },
135        }
136    }
137
138    /// Returns an expression which evaluates to the result of merging a decoded
139    /// scalar value into the field.
140    pub fn merge(&self, ident: TokenStream) -> TokenStream {
141        let module = self.ty.module();
142        let merge_fn = match self.kind {
143            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge),
144            Kind::Repeated | Kind::Packed => quote!(merge_repeated),
145        };
146        let merge_fn = quote!(::prost::encoding::#module::#merge_fn);
147
148        match self.kind {
149            Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
150                #merge_fn(wire_type, #ident, buf, ctx)
151            },
152            Kind::Optional(..) => quote! {
153                #merge_fn(wire_type,
154                          #ident.get_or_insert_with(::core::default::Default::default),
155                          buf,
156                          ctx)
157            },
158        }
159    }
160
161    /// Returns an expression which evaluates to the encoded length of the field.
162    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
163        let module = self.ty.module();
164        let encoded_len_fn = match self.kind {
165            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len),
166            Kind::Repeated => quote!(encoded_len_repeated),
167            Kind::Packed => quote!(encoded_len_packed),
168        };
169        let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn);
170        let tag = self.tag;
171
172        match self.kind {
173            Kind::Plain(ref default) => {
174                let default = default.typed();
175                quote! {
176                    if #ident != #default {
177                        #encoded_len_fn(#tag, &#ident)
178                    } else {
179                        0
180                    }
181                }
182            }
183            Kind::Optional(..) => quote! {
184                #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value))
185            },
186            Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
187                #encoded_len_fn(#tag, &#ident)
188            },
189        }
190    }
191
192    pub fn clear(&self, ident: TokenStream) -> TokenStream {
193        match self.kind {
194            Kind::Plain(ref default) | Kind::Required(ref default) => {
195                let default = default.typed();
196                match self.ty {
197                    Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
198                    _ => quote!(#ident = #default),
199                }
200            }
201            Kind::Optional(_) => quote!(#ident = ::core::option::Option::None),
202            Kind::Repeated | Kind::Packed => quote!(#ident.clear()),
203        }
204    }
205
206    /// Returns an expression which evaluates to the default value of the field.
207    pub fn default(&self) -> TokenStream {
208        match self.kind {
209            Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(),
210            Kind::Optional(_) => quote!(::core::option::Option::None),
211            Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()),
212        }
213    }
214
215    /// An inner debug wrapper, around the base type.
216    fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream {
217        if let Ty::Enumeration(ref ty) = self.ty {
218            quote! {
219                struct #wrap_name<'a>(&'a i32);
220                impl<'a> ::core::fmt::Debug for #wrap_name<'a> {
221                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
222                        match #ty::from_i32(*self.0) {
223                            None => ::core::fmt::Debug::fmt(&self.0, f),
224                            Some(en) => ::core::fmt::Debug::fmt(&en, f),
225                        }
226                    }
227                }
228            }
229        } else {
230            quote! {
231                fn #wrap_name<T>(v: T) -> T { v }
232            }
233        }
234    }
235
236    /// Returns a fragment for formatting the field `ident` in `Debug`.
237    pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
238        let wrapper = self.debug_inner(quote!(Inner));
239        let inner_ty = self.ty.rust_type();
240        match self.kind {
241            Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name),
242            Kind::Optional(_) => quote! {
243                struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>);
244                impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
245                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
246                        #wrapper
247                        ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f)
248                    }
249                }
250            },
251            Kind::Repeated | Kind::Packed => {
252                quote! {
253                    struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>);
254                    impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
255                        fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
256                            let mut vec_builder = f.debug_list();
257                            for v in self.0 {
258                                #wrapper
259                                vec_builder.entry(&Inner(v));
260                            }
261                            vec_builder.finish()
262                        }
263                    }
264                }
265            }
266        }
267    }
268
269    /// Returns methods to embed in the message.
270    pub fn methods(&self, ident: &Ident) -> Option<TokenStream> {
271        let mut ident_str = ident.to_string();
272        if ident_str.starts_with("r#") {
273            ident_str = ident_str[2..].to_owned();
274        }
275
276        if let Ty::Enumeration(ref ty) = self.ty {
277            let set = Ident::new(&format!("set_{}", ident_str), Span::call_site());
278            let set_doc = format!("Sets `{}` to the provided enum value.", ident_str);
279            Some(match self.kind {
280                Kind::Plain(ref default) | Kind::Required(ref default) => {
281                    let get_doc = format!(
282                        "Returns the enum value of `{}`, \
283                         or the default if the field is set to an invalid enum value.",
284                        ident_str,
285                    );
286                    quote! {
287                        #[doc=#get_doc]
288                        pub fn #ident(&self) -> #ty {
289                            #ty::from_i32(self.#ident).unwrap_or(#default)
290                        }
291
292                        #[doc=#set_doc]
293                        pub fn #set(&mut self, value: #ty) {
294                            self.#ident = value as i32;
295                        }
296                    }
297                }
298                Kind::Optional(ref default) => {
299                    let get_doc = format!(
300                        "Returns the enum value of `{}`, \
301                         or the default if the field is unset or set to an invalid enum value.",
302                        ident_str,
303                    );
304                    quote! {
305                        #[doc=#get_doc]
306                        pub fn #ident(&self) -> #ty {
307                            self.#ident.and_then(#ty::from_i32).unwrap_or(#default)
308                        }
309
310                        #[doc=#set_doc]
311                        pub fn #set(&mut self, value: #ty) {
312                            self.#ident = ::core::option::Option::Some(value as i32);
313                        }
314                    }
315                }
316                Kind::Repeated | Kind::Packed => {
317                    let iter_doc = format!(
318                        "Returns an iterator which yields the valid enum values contained in `{}`.",
319                        ident_str,
320                    );
321                    let push = Ident::new(&format!("push_{}", ident_str), Span::call_site());
322                    let push_doc = format!("Appends the provided enum value to `{}`.", ident_str);
323                    quote! {
324                        #[doc=#iter_doc]
325                        pub fn #ident(&self) -> ::core::iter::FilterMap<
326                            ::core::iter::Cloned<::core::slice::Iter<i32>>,
327                            fn(i32) -> ::core::option::Option<#ty>,
328                        > {
329                            self.#ident.iter().cloned().filter_map(#ty::from_i32)
330                        }
331                        #[doc=#push_doc]
332                        pub fn #push(&mut self, value: #ty) {
333                            self.#ident.push(value as i32);
334                        }
335                    }
336                }
337            })
338        } else if let Kind::Optional(ref default) = self.kind {
339            let ty = self.ty.rust_ref_type();
340
341            let match_some = if self.ty.is_numeric() {
342                quote!(::core::option::Option::Some(val) => val,)
343            } else {
344                quote!(::core::option::Option::Some(ref val) => &val[..],)
345            };
346
347            let get_doc = format!(
348                "Returns the value of `{0}`, or the default value if `{0}` is unset.",
349                ident_str,
350            );
351
352            Some(quote! {
353                #[doc=#get_doc]
354                pub fn #ident(&self) -> #ty {
355                    match self.#ident {
356                        #match_some
357                        ::core::option::Option::None => #default,
358                    }
359                }
360            })
361        } else {
362            None
363        }
364    }
365}
366
367/// A scalar protobuf field type.
368#[derive(Clone, PartialEq, Eq)]
369pub enum Ty {
370    Double,
371    Float,
372    Int32,
373    Int64,
374    Uint32,
375    Uint64,
376    Sint32,
377    Sint64,
378    Fixed32,
379    Fixed64,
380    Sfixed32,
381    Sfixed64,
382    Bool,
383    String,
384    Bytes(BytesTy),
385    Enumeration(Path),
386}
387
/// Rust representation chosen for a protobuf `bytes` field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BytesTy {
    /// Held as `::prost::alloc::vec::Vec<u8>`.
    Vec,
    /// Held as `::prost::bytes::Bytes`.
    Bytes,
}
393
394impl BytesTy {
395    fn try_from_str(s: &str) -> Result<Self, Error> {
396        match s {
397            "vec" => Ok(BytesTy::Vec),
398            "bytes" => Ok(BytesTy::Bytes),
399            _ => bail!("Invalid bytes type: {}", s),
400        }
401    }
402
403    fn rust_type(&self) -> TokenStream {
404        match self {
405            BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> },
406            BytesTy::Bytes => quote! { ::prost::bytes::Bytes },
407        }
408    }
409}
410
411impl Ty {
412    pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
413        let ty = match *attr {
414            Meta::Path(ref name) if name.is_ident("float") => Ty::Float,
415            Meta::Path(ref name) if name.is_ident("double") => Ty::Double,
416            Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32,
417            Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64,
418            Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32,
419            Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64,
420            Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32,
421            Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64,
422            Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32,
423            Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64,
424            Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
425            Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
426            Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
427            Meta::Path(ref name) if name.is_ident("string") => Ty::String,
428            Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
429            Meta::NameValue(MetaNameValue {
430                ref path,
431                lit: Lit::Str(ref l),
432                ..
433            }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?),
434            Meta::NameValue(MetaNameValue {
435                ref path,
436                lit: Lit::Str(ref l),
437                ..
438            }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?),
439            Meta::List(MetaList {
440                ref path,
441                ref nested,
442                ..
443            }) if path.is_ident("enumeration") => {
444                // TODO(rustlang/rust#23121): slice pattern matching would make this much nicer.
445                if nested.len() == 1 {
446                    if let NestedMeta::Meta(Meta::Path(ref path)) = nested[0] {
447                        Ty::Enumeration(path.clone())
448                    } else {
449                        bail!("invalid enumeration attribute: item must be an identifier");
450                    }
451                } else {
452                    bail!("invalid enumeration attribute: only a single identifier is supported");
453                }
454            }
455            _ => return Ok(None),
456        };
457        Ok(Some(ty))
458    }
459
460    pub fn from_str(s: &str) -> Result<Ty, Error> {
461        let enumeration_len = "enumeration".len();
462        let error = Err(anyhow!("invalid type: {}", s));
463        let ty = match s.trim() {
464            "float" => Ty::Float,
465            "double" => Ty::Double,
466            "int32" => Ty::Int32,
467            "int64" => Ty::Int64,
468            "uint32" => Ty::Uint32,
469            "uint64" => Ty::Uint64,
470            "sint32" => Ty::Sint32,
471            "sint64" => Ty::Sint64,
472            "fixed32" => Ty::Fixed32,
473            "fixed64" => Ty::Fixed64,
474            "sfixed32" => Ty::Sfixed32,
475            "sfixed64" => Ty::Sfixed64,
476            "bool" => Ty::Bool,
477            "string" => Ty::String,
478            "bytes" => Ty::Bytes(BytesTy::Vec),
479            s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
480                let s = &s[enumeration_len..].trim();
481                match s.chars().next() {
482                    Some('<') | Some('(') => (),
483                    _ => return error,
484                }
485                match s.chars().next_back() {
486                    Some('>') | Some(')') => (),
487                    _ => return error,
488                }
489
490                Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?)
491            }
492            _ => return error,
493        };
494        Ok(ty)
495    }
496
497    /// Returns the type as it appears in protobuf field declarations.
498    pub fn as_str(&self) -> &'static str {
499        match *self {
500            Ty::Double => "double",
501            Ty::Float => "float",
502            Ty::Int32 => "int32",
503            Ty::Int64 => "int64",
504            Ty::Uint32 => "uint32",
505            Ty::Uint64 => "uint64",
506            Ty::Sint32 => "sint32",
507            Ty::Sint64 => "sint64",
508            Ty::Fixed32 => "fixed32",
509            Ty::Fixed64 => "fixed64",
510            Ty::Sfixed32 => "sfixed32",
511            Ty::Sfixed64 => "sfixed64",
512            Ty::Bool => "bool",
513            Ty::String => "string",
514            Ty::Bytes(..) => "bytes",
515            Ty::Enumeration(..) => "enum",
516        }
517    }
518
519    // TODO: rename to 'owned_type'.
520    pub fn rust_type(&self) -> TokenStream {
521        match self {
522            Ty::String => quote!(::prost::alloc::string::String),
523            Ty::Bytes(ty) => ty.rust_type(),
524            _ => self.rust_ref_type(),
525        }
526    }
527
528    // TODO: rename to 'ref_type'
529    pub fn rust_ref_type(&self) -> TokenStream {
530        match *self {
531            Ty::Double => quote!(f64),
532            Ty::Float => quote!(f32),
533            Ty::Int32 => quote!(i32),
534            Ty::Int64 => quote!(i64),
535            Ty::Uint32 => quote!(u32),
536            Ty::Uint64 => quote!(u64),
537            Ty::Sint32 => quote!(i32),
538            Ty::Sint64 => quote!(i64),
539            Ty::Fixed32 => quote!(u32),
540            Ty::Fixed64 => quote!(u64),
541            Ty::Sfixed32 => quote!(i32),
542            Ty::Sfixed64 => quote!(i64),
543            Ty::Bool => quote!(bool),
544            Ty::String => quote!(&str),
545            Ty::Bytes(..) => quote!(&[u8]),
546            Ty::Enumeration(..) => quote!(i32),
547        }
548    }
549
550    pub fn module(&self) -> Ident {
551        match *self {
552            Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
553            _ => Ident::new(self.as_str(), Span::call_site()),
554        }
555    }
556
557    /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
558    pub fn is_numeric(&self) -> bool {
559        !matches!(self, Ty::String | Ty::Bytes(..))
560    }
561}
562
563impl fmt::Debug for Ty {
564    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
565        f.write_str(self.as_str())
566    }
567}
568
569impl fmt::Display for Ty {
570    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
571        f.write_str(self.as_str())
572    }
573}
574
575/// Scalar Protobuf field types.
576#[derive(Clone)]
577pub enum Kind {
578    /// A plain proto3 scalar field.
579    Plain(DefaultValue),
580    /// An optional scalar field.
581    Optional(DefaultValue),
582    /// A required proto2 scalar field.
583    Required(DefaultValue),
584    /// A repeated scalar field.
585    Repeated,
586    /// A packed repeated scalar field.
587    Packed,
588}
589
590/// Scalar Protobuf field default value.
591#[derive(Clone, Debug)]
592pub enum DefaultValue {
593    F64(f64),
594    F32(f32),
595    I32(i32),
596    I64(i64),
597    U32(u32),
598    U64(u64),
599    Bool(bool),
600    String(String),
601    Bytes(Vec<u8>),
602    Enumeration(TokenStream),
603    Path(Path),
604}
605
606impl DefaultValue {
607    pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> {
608        if !attr.path().is_ident("default") {
609            Ok(None)
610        } else if let Meta::NameValue(ref name_value) = *attr {
611            Ok(Some(name_value.lit.clone()))
612        } else {
613            bail!("invalid default value attribute: {:?}", attr)
614        }
615    }
616
617    pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> {
618        let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32;
619        let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64;
620
621        let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32;
622        let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64;
623
624        let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty();
625
626        let default = match lit {
627            Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
628                DefaultValue::I32(lit.base10_parse()?)
629            }
630            Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
631                DefaultValue::I64(lit.base10_parse()?)
632            }
633            Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => {
634                DefaultValue::U32(lit.base10_parse()?)
635            }
636            Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => {
637                DefaultValue::U64(lit.base10_parse()?)
638            }
639
640            Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => {
641                DefaultValue::F32(lit.base10_parse()?)
642            }
643            Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?),
644
645            Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => {
646                DefaultValue::F64(lit.base10_parse()?)
647            }
648            Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),
649
650            Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
651            Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
652            Lit::ByteStr(ref lit)
653                if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
654            {
655                DefaultValue::Bytes(lit.value())
656            }
657
658            Lit::Str(ref lit) => {
659                let value = lit.value();
660                let value = value.trim();
661
662                if let Ty::Enumeration(ref path) = *ty {
663                    let variant = Ident::new(value, Span::call_site());
664                    return Ok(DefaultValue::Enumeration(quote!(#path::#variant)));
665                }
666
667                // Parse special floating point values.
668                if *ty == Ty::Float {
669                    match value {
670                        "inf" => {
671                            return Ok(DefaultValue::Path(parse_str::<Path>(
672                                "::core::f32::INFINITY",
673                            )?));
674                        }
675                        "-inf" => {
676                            return Ok(DefaultValue::Path(parse_str::<Path>(
677                                "::core::f32::NEG_INFINITY",
678                            )?));
679                        }
680                        "nan" => {
681                            return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?));
682                        }
683                        _ => (),
684                    }
685                }
686                if *ty == Ty::Double {
687                    match value {
688                        "inf" => {
689                            return Ok(DefaultValue::Path(parse_str::<Path>(
690                                "::core::f64::INFINITY",
691                            )?));
692                        }
693                        "-inf" => {
694                            return Ok(DefaultValue::Path(parse_str::<Path>(
695                                "::core::f64::NEG_INFINITY",
696                            )?));
697                        }
698                        "nan" => {
699                            return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?));
700                        }
701                        _ => (),
702                    }
703                }
704
705                // Rust doesn't have a negative literals, so they have to be parsed specially.
706                if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) {
707                    match lit {
708                        Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
709                            // Initially parse into an i64, so that i32::MIN does not overflow.
710                            let value: i64 = -lit.base10_parse()?;
711                            return Ok(i32::try_from(value).map(DefaultValue::I32)?);
712                        }
713                        Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
714                            // Initially parse into an i128, so that i64::MIN does not overflow.
715                            let value: i128 = -lit.base10_parse()?;
716                            return Ok(i64::try_from(value).map(DefaultValue::I64)?);
717                        }
718                        Lit::Float(ref lit)
719                            if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) =>
720                        {
721                            return Ok(DefaultValue::F32(-lit.base10_parse()?));
722                        }
723                        Lit::Float(ref lit)
724                            if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) =>
725                        {
726                            return Ok(DefaultValue::F64(-lit.base10_parse()?));
727                        }
728                        Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => {
729                            return Ok(DefaultValue::F32(-lit.base10_parse()?));
730                        }
731                        Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => {
732                            return Ok(DefaultValue::F64(-lit.base10_parse()?));
733                        }
734                        _ => (),
735                    }
736                }
737                match syn::parse_str::<Lit>(&value) {
738                    Ok(Lit::Str(_)) => (),
739                    Ok(lit) => return DefaultValue::from_lit(ty, lit),
740                    _ => (),
741                }
742                bail!("invalid default value: {}", quote!(#value));
743            }
744            _ => bail!("invalid default value: {}", quote!(#lit)),
745        };
746
747        Ok(default)
748    }
749
750    pub fn new(ty: &Ty) -> DefaultValue {
751        match *ty {
752            Ty::Float => DefaultValue::F32(0.0),
753            Ty::Double => DefaultValue::F64(0.0),
754            Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0),
755            Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0),
756            Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0),
757            Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),
758
759            Ty::Bool => DefaultValue::Bool(false),
760            Ty::String => DefaultValue::String(String::new()),
761            Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
762            Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
763        }
764    }
765
766    pub fn owned(&self) -> TokenStream {
767        match *self {
768            DefaultValue::String(ref value) if value.is_empty() => {
769                quote!(::prost::alloc::string::String::new())
770            }
771            DefaultValue::String(ref value) => quote!(#value.into()),
772            DefaultValue::Bytes(ref value) if value.is_empty() => {
773                quote!(::core::default::Default::default())
774            }
775            DefaultValue::Bytes(ref value) => {
776                let lit = LitByteStr::new(value, Span::call_site());
777                quote!(#lit.as_ref().into())
778            }
779
780            ref other => other.typed(),
781        }
782    }
783
784    pub fn typed(&self) -> TokenStream {
785        if let DefaultValue::Enumeration(_) = *self {
786            quote!(#self as i32)
787        } else {
788            quote!(#self)
789        }
790    }
791}
792
793impl ToTokens for DefaultValue {
794    fn to_tokens(&self, tokens: &mut TokenStream) {
795        match *self {
796            DefaultValue::F64(value) => value.to_tokens(tokens),
797            DefaultValue::F32(value) => value.to_tokens(tokens),
798            DefaultValue::I32(value) => value.to_tokens(tokens),
799            DefaultValue::I64(value) => value.to_tokens(tokens),
800            DefaultValue::U32(value) => value.to_tokens(tokens),
801            DefaultValue::U64(value) => value.to_tokens(tokens),
802            DefaultValue::Bool(value) => value.to_tokens(tokens),
803            DefaultValue::String(ref value) => value.to_tokens(tokens),
804            DefaultValue::Bytes(ref value) => {
805                let byte_str = LitByteStr::new(value, Span::call_site());
806                tokens.append_all(quote!(#byte_str as &[u8]));
807            }
808            DefaultValue::Enumeration(ref value) => value.to_tokens(tokens),
809            DefaultValue::Path(ref value) => value.to_tokens(tokens),
810        }
811    }
812}