prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/prost-derive/0.9.0")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15    FieldsUnnamed, Ident, Variant,
16};
17
18mod field;
19use crate::field::Field;
20
21fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22    let input: DeriveInput = syn::parse(input)?;
23
24    let ident = input.ident;
25
26    let variant_data = match input.data {
27        Data::Struct(variant_data) => variant_data,
28        Data::Enum(..) => bail!("Message can not be derived for an enum"),
29        Data::Union(..) => bail!("Message can not be derived for a union"),
30    };
31
32    let generics = &input.generics;
33    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34
35    let fields = match variant_data {
36        DataStruct {
37            fields: Fields::Named(FieldsNamed { named: fields, .. }),
38            ..
39        }
40        | DataStruct {
41            fields:
42                Fields::Unnamed(FieldsUnnamed {
43                    unnamed: fields, ..
44                }),
45            ..
46        } => fields.into_iter().collect(),
47        DataStruct {
48            fields: Fields::Unit,
49            ..
50        } => Vec::new(),
51    };
52
53    let mut next_tag: u32 = 1;
54    let mut fields = fields
55        .into_iter()
56        .enumerate()
57        .flat_map(|(idx, field)| {
58            let field_ident = field
59                .ident
60                .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
61            match Field::new(field.attrs, Some(next_tag)) {
62                Ok(Some(field)) => {
63                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
64                    Some(Ok((field_ident, field)))
65                }
66                Ok(None) => None,
67                Err(err) => Some(Err(
68                    err.context(format!("invalid message field {}.{}", ident, field_ident))
69                )),
70            }
71        })
72        .collect::<Result<Vec<_>, _>>()?;
73
74    // We want Debug to be in declaration order
75    let unsorted_fields = fields.clone();
76
77    // Sort the fields by tag number so that fields will be encoded in tag order.
78    // TODO: This encodes oneof fields in the position of their lowest tag,
79    // regardless of the currently occupied variant, is that consequential?
80    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
81    fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
82    let fields = fields;
83
84    let mut tags = fields
85        .iter()
86        .flat_map(|&(_, ref field)| field.tags())
87        .collect::<Vec<_>>();
88    let num_tags = tags.len();
89    tags.sort_unstable();
90    tags.dedup();
91    if tags.len() != num_tags {
92        bail!("message {} has fields with duplicate tags", ident);
93    }
94
95    let encoded_len = fields
96        .iter()
97        .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
98
99    let encode = fields
100        .iter()
101        .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
102
103    let merge = fields.iter().map(|&(ref field_ident, ref field)| {
104        let merge = field.merge(quote!(value));
105        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
106        let tags = Itertools::intersperse(tags, quote!(|));
107
108        quote! {
109            #(#tags)* => {
110                let mut value = &mut self.#field_ident;
111                #merge.map_err(|mut error| {
112                    error.push(STRUCT_NAME, stringify!(#field_ident));
113                    error
114                })
115            },
116        }
117    });
118
119    let struct_name = if fields.is_empty() {
120        quote!()
121    } else {
122        quote!(
123            const STRUCT_NAME: &'static str = stringify!(#ident);
124        )
125    };
126
127    // TODO
128    let is_struct = true;
129
130    let clear = fields
131        .iter()
132        .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
133
134    let default = fields.iter().map(|&(ref field_ident, ref field)| {
135        let value = field.default();
136        quote!(#field_ident: #value,)
137    });
138
139    let methods = fields
140        .iter()
141        .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
142        .collect::<Vec<_>>();
143    let methods = if methods.is_empty() {
144        quote!()
145    } else {
146        quote! {
147            #[allow(dead_code)]
148            impl #impl_generics #ident #ty_generics #where_clause {
149                #(#methods)*
150            }
151        }
152    };
153
154    let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
155        let wrapper = field.debug(quote!(self.#field_ident));
156        let call = if is_struct {
157            quote!(builder.field(stringify!(#field_ident), &wrapper))
158        } else {
159            quote!(builder.field(&wrapper))
160        };
161        quote! {
162             let builder = {
163                 let wrapper = #wrapper;
164                 #call
165             };
166        }
167    });
168    let debug_builder = if is_struct {
169        quote!(f.debug_struct(stringify!(#ident)))
170    } else {
171        quote!(f.debug_tuple(stringify!(#ident)))
172    };
173
174    let expanded = quote! {
175        impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176            #[allow(unused_variables)]
177            fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178                #(#encode)*
179            }
180
181            #[allow(unused_variables)]
182            fn merge_field<B>(
183                &mut self,
184                tag: u32,
185                wire_type: ::prost::encoding::WireType,
186                buf: &mut B,
187                ctx: ::prost::encoding::DecodeContext,
188            ) -> ::core::result::Result<(), ::prost::DecodeError>
189            where B: ::prost::bytes::Buf {
190                #struct_name
191                match tag {
192                    #(#merge)*
193                    _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194                }
195            }
196
197            #[inline]
198            fn encoded_len(&self) -> usize {
199                0 #(+ #encoded_len)*
200            }
201
202            fn clear(&mut self) {
203                #(#clear;)*
204            }
205        }
206
207        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
208            fn default() -> Self {
209                #ident {
210                    #(#default)*
211                }
212            }
213        }
214
215        impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
216            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
217                let mut builder = #debug_builder;
218                #(#debugs;)*
219                builder.finish()
220            }
221        }
222
223        #methods
224    };
225
226    Ok(expanded.into())
227}
228
229#[proc_macro_derive(Message, attributes(prost))]
230pub fn message(input: TokenStream) -> TokenStream {
231    try_message(input).unwrap()
232}
233
234fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
235    let input: DeriveInput = syn::parse(input)?;
236    let ident = input.ident;
237
238    let generics = &input.generics;
239    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
240
241    let punctuated_variants = match input.data {
242        Data::Enum(DataEnum { variants, .. }) => variants,
243        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
244        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
245    };
246
247    // Map the variants into 'fields'.
248    let mut variants: Vec<(Ident, Expr)> = Vec::new();
249    for Variant {
250        ident,
251        fields,
252        discriminant,
253        ..
254    } in punctuated_variants
255    {
256        match fields {
257            Fields::Unit => (),
258            Fields::Named(_) | Fields::Unnamed(_) => {
259                bail!("Enumeration variants may not have fields")
260            }
261        }
262
263        match discriminant {
264            Some((_, expr)) => variants.push((ident, expr)),
265            None => bail!("Enumeration variants must have a disriminant"),
266        }
267    }
268
269    if variants.is_empty() {
270        panic!("Enumeration must have at least one variant");
271    }
272
273    let default = variants[0].0.clone();
274
275    let is_valid = variants
276        .iter()
277        .map(|&(_, ref value)| quote!(#value => true));
278    let from = variants.iter().map(
279        |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
280    );
281
282    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
283    let from_i32_doc = format!(
284        "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
285        ident
286    );
287
288    let expanded = quote! {
289        impl #impl_generics #ident #ty_generics #where_clause {
290            #[doc=#is_valid_doc]
291            pub fn is_valid(value: i32) -> bool {
292                match value {
293                    #(#is_valid,)*
294                    _ => false,
295                }
296            }
297
298            #[doc=#from_i32_doc]
299            pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
300                match value {
301                    #(#from,)*
302                    _ => ::core::option::Option::None,
303                }
304            }
305        }
306
307        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
308            fn default() -> #ident {
309                #ident::#default
310            }
311        }
312
313        impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
314            fn from(value: #ident) -> i32 {
315                value as i32
316            }
317        }
318    };
319
320    Ok(expanded.into())
321}
322
323#[proc_macro_derive(Enumeration, attributes(prost))]
324pub fn enumeration(input: TokenStream) -> TokenStream {
325    try_enumeration(input).unwrap()
326}
327
328fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
329    let input: DeriveInput = syn::parse(input)?;
330
331    let ident = input.ident;
332
333    let variants = match input.data {
334        Data::Enum(DataEnum { variants, .. }) => variants,
335        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
336        Data::Union(..) => bail!("Oneof can not be derived for a union"),
337    };
338
339    let generics = &input.generics;
340    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
341
342    // Map the variants into 'fields'.
343    let mut fields: Vec<(Ident, Field)> = Vec::new();
344    for Variant {
345        attrs,
346        ident: variant_ident,
347        fields: variant_fields,
348        ..
349    } in variants
350    {
351        let variant_fields = match variant_fields {
352            Fields::Unit => Punctuated::new(),
353            Fields::Named(FieldsNamed { named: fields, .. })
354            | Fields::Unnamed(FieldsUnnamed {
355                unnamed: fields, ..
356            }) => fields,
357        };
358        if variant_fields.len() != 1 {
359            bail!("Oneof enum variants must have a single field");
360        }
361        match Field::new_oneof(attrs)? {
362            Some(field) => fields.push((variant_ident, field)),
363            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
364        }
365    }
366
367    let mut tags = fields
368        .iter()
369        .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
370            if field.tags().len() > 1 {
371                bail!(
372                    "invalid oneof variant {}::{}: oneof variants may only have a single tag",
373                    ident,
374                    variant_ident
375                );
376            }
377            Ok(field.tags()[0])
378        })
379        .collect::<Vec<_>>();
380    tags.sort_unstable();
381    tags.dedup();
382    if tags.len() != fields.len() {
383        panic!("invalid oneof {}: variants have duplicate tags", ident);
384    }
385
386    let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
387        let encode = field.encode(quote!(*value));
388        quote!(#ident::#variant_ident(ref value) => { #encode })
389    });
390
391    let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
392        let tag = field.tags()[0];
393        let merge = field.merge(quote!(value));
394        quote! {
395            #tag => {
396                match field {
397                    ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
398                        #merge
399                    },
400                    _ => {
401                        let mut owned_value = ::core::default::Default::default();
402                        let value = &mut owned_value;
403                        #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
404                    },
405                }
406            }
407        }
408    });
409
410    let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
411        let encoded_len = field.encoded_len(quote!(*value));
412        quote!(#ident::#variant_ident(ref value) => #encoded_len)
413    });
414
415    let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
416        let wrapper = field.debug(quote!(*value));
417        quote!(#ident::#variant_ident(ref value) => {
418            let wrapper = #wrapper;
419            f.debug_tuple(stringify!(#variant_ident))
420                .field(&wrapper)
421                .finish()
422        })
423    });
424
425    let expanded = quote! {
426        impl #impl_generics #ident #ty_generics #where_clause {
427            pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
428                match *self {
429                    #(#encode,)*
430                }
431            }
432
433            pub fn merge<B>(
434                field: &mut ::core::option::Option<#ident #ty_generics>,
435                tag: u32,
436                wire_type: ::prost::encoding::WireType,
437                buf: &mut B,
438                ctx: ::prost::encoding::DecodeContext,
439            ) -> ::core::result::Result<(), ::prost::DecodeError>
440            where B: ::prost::bytes::Buf {
441                match tag {
442                    #(#merge,)*
443                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
444                }
445            }
446
447            #[inline]
448            pub fn encoded_len(&self) -> usize {
449                match *self {
450                    #(#encoded_len,)*
451                }
452            }
453        }
454
455        impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
456            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
457                match *self {
458                    #(#debug,)*
459                }
460            }
461        }
462    };
463
464    Ok(expanded.into())
465}
466
467#[proc_macro_derive(Oneof, attributes(prost))]
468pub fn oneof(input: TokenStream) -> TokenStream {
469    try_oneof(input).unwrap()
470}