1#![doc(html_root_url = "https://docs.rs/prost-derive/0.9.0")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15 FieldsUnnamed, Ident, Variant,
16};
17
18mod field;
19use crate::field::Field;
20
21fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22 let input: DeriveInput = syn::parse(input)?;
23
24 let ident = input.ident;
25
26 let variant_data = match input.data {
27 Data::Struct(variant_data) => variant_data,
28 Data::Enum(..) => bail!("Message can not be derived for an enum"),
29 Data::Union(..) => bail!("Message can not be derived for a union"),
30 };
31
32 let generics = &input.generics;
33 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
34
35 let fields = match variant_data {
36 DataStruct {
37 fields: Fields::Named(FieldsNamed { named: fields, .. }),
38 ..
39 }
40 | DataStruct {
41 fields:
42 Fields::Unnamed(FieldsUnnamed {
43 unnamed: fields, ..
44 }),
45 ..
46 } => fields.into_iter().collect(),
47 DataStruct {
48 fields: Fields::Unit,
49 ..
50 } => Vec::new(),
51 };
52
53 let mut next_tag: u32 = 1;
54 let mut fields = fields
55 .into_iter()
56 .enumerate()
57 .flat_map(|(idx, field)| {
58 let field_ident = field
59 .ident
60 .unwrap_or_else(|| Ident::new(&idx.to_string(), Span::call_site()));
61 match Field::new(field.attrs, Some(next_tag)) {
62 Ok(Some(field)) => {
63 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
64 Some(Ok((field_ident, field)))
65 }
66 Ok(None) => None,
67 Err(err) => Some(Err(
68 err.context(format!("invalid message field {}.{}", ident, field_ident))
69 )),
70 }
71 })
72 .collect::<Result<Vec<_>, _>>()?;
73
74 let unsorted_fields = fields.clone();
76
77 fields.sort_by_key(|&(_, ref field)| field.tags().into_iter().min().unwrap());
82 let fields = fields;
83
84 let mut tags = fields
85 .iter()
86 .flat_map(|&(_, ref field)| field.tags())
87 .collect::<Vec<_>>();
88 let num_tags = tags.len();
89 tags.sort_unstable();
90 tags.dedup();
91 if tags.len() != num_tags {
92 bail!("message {} has fields with duplicate tags", ident);
93 }
94
95 let encoded_len = fields
96 .iter()
97 .map(|&(ref field_ident, ref field)| field.encoded_len(quote!(self.#field_ident)));
98
99 let encode = fields
100 .iter()
101 .map(|&(ref field_ident, ref field)| field.encode(quote!(self.#field_ident)));
102
103 let merge = fields.iter().map(|&(ref field_ident, ref field)| {
104 let merge = field.merge(quote!(value));
105 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
106 let tags = Itertools::intersperse(tags, quote!(|));
107
108 quote! {
109 #(#tags)* => {
110 let mut value = &mut self.#field_ident;
111 #merge.map_err(|mut error| {
112 error.push(STRUCT_NAME, stringify!(#field_ident));
113 error
114 })
115 },
116 }
117 });
118
119 let struct_name = if fields.is_empty() {
120 quote!()
121 } else {
122 quote!(
123 const STRUCT_NAME: &'static str = stringify!(#ident);
124 )
125 };
126
127 let is_struct = true;
129
130 let clear = fields
131 .iter()
132 .map(|&(ref field_ident, ref field)| field.clear(quote!(self.#field_ident)));
133
134 let default = fields.iter().map(|&(ref field_ident, ref field)| {
135 let value = field.default();
136 quote!(#field_ident: #value,)
137 });
138
139 let methods = fields
140 .iter()
141 .flat_map(|&(ref field_ident, ref field)| field.methods(field_ident))
142 .collect::<Vec<_>>();
143 let methods = if methods.is_empty() {
144 quote!()
145 } else {
146 quote! {
147 #[allow(dead_code)]
148 impl #impl_generics #ident #ty_generics #where_clause {
149 #(#methods)*
150 }
151 }
152 };
153
154 let debugs = unsorted_fields.iter().map(|&(ref field_ident, ref field)| {
155 let wrapper = field.debug(quote!(self.#field_ident));
156 let call = if is_struct {
157 quote!(builder.field(stringify!(#field_ident), &wrapper))
158 } else {
159 quote!(builder.field(&wrapper))
160 };
161 quote! {
162 let builder = {
163 let wrapper = #wrapper;
164 #call
165 };
166 }
167 });
168 let debug_builder = if is_struct {
169 quote!(f.debug_struct(stringify!(#ident)))
170 } else {
171 quote!(f.debug_tuple(stringify!(#ident)))
172 };
173
174 let expanded = quote! {
175 impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176 #[allow(unused_variables)]
177 fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178 #(#encode)*
179 }
180
181 #[allow(unused_variables)]
182 fn merge_field<B>(
183 &mut self,
184 tag: u32,
185 wire_type: ::prost::encoding::WireType,
186 buf: &mut B,
187 ctx: ::prost::encoding::DecodeContext,
188 ) -> ::core::result::Result<(), ::prost::DecodeError>
189 where B: ::prost::bytes::Buf {
190 #struct_name
191 match tag {
192 #(#merge)*
193 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194 }
195 }
196
197 #[inline]
198 fn encoded_len(&self) -> usize {
199 0 #(+ #encoded_len)*
200 }
201
202 fn clear(&mut self) {
203 #(#clear;)*
204 }
205 }
206
207 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
208 fn default() -> Self {
209 #ident {
210 #(#default)*
211 }
212 }
213 }
214
215 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
216 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
217 let mut builder = #debug_builder;
218 #(#debugs;)*
219 builder.finish()
220 }
221 }
222
223 #methods
224 };
225
226 Ok(expanded.into())
227}
228
229#[proc_macro_derive(Message, attributes(prost))]
230pub fn message(input: TokenStream) -> TokenStream {
231 try_message(input).unwrap()
232}
233
234fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
235 let input: DeriveInput = syn::parse(input)?;
236 let ident = input.ident;
237
238 let generics = &input.generics;
239 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
240
241 let punctuated_variants = match input.data {
242 Data::Enum(DataEnum { variants, .. }) => variants,
243 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
244 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
245 };
246
247 let mut variants: Vec<(Ident, Expr)> = Vec::new();
249 for Variant {
250 ident,
251 fields,
252 discriminant,
253 ..
254 } in punctuated_variants
255 {
256 match fields {
257 Fields::Unit => (),
258 Fields::Named(_) | Fields::Unnamed(_) => {
259 bail!("Enumeration variants may not have fields")
260 }
261 }
262
263 match discriminant {
264 Some((_, expr)) => variants.push((ident, expr)),
265 None => bail!("Enumeration variants must have a disriminant"),
266 }
267 }
268
269 if variants.is_empty() {
270 panic!("Enumeration must have at least one variant");
271 }
272
273 let default = variants[0].0.clone();
274
275 let is_valid = variants
276 .iter()
277 .map(|&(_, ref value)| quote!(#value => true));
278 let from = variants.iter().map(
279 |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)),
280 );
281
282 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
283 let from_i32_doc = format!(
284 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
285 ident
286 );
287
288 let expanded = quote! {
289 impl #impl_generics #ident #ty_generics #where_clause {
290 #[doc=#is_valid_doc]
291 pub fn is_valid(value: i32) -> bool {
292 match value {
293 #(#is_valid,)*
294 _ => false,
295 }
296 }
297
298 #[doc=#from_i32_doc]
299 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
300 match value {
301 #(#from,)*
302 _ => ::core::option::Option::None,
303 }
304 }
305 }
306
307 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
308 fn default() -> #ident {
309 #ident::#default
310 }
311 }
312
313 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
314 fn from(value: #ident) -> i32 {
315 value as i32
316 }
317 }
318 };
319
320 Ok(expanded.into())
321}
322
323#[proc_macro_derive(Enumeration, attributes(prost))]
324pub fn enumeration(input: TokenStream) -> TokenStream {
325 try_enumeration(input).unwrap()
326}
327
328fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
329 let input: DeriveInput = syn::parse(input)?;
330
331 let ident = input.ident;
332
333 let variants = match input.data {
334 Data::Enum(DataEnum { variants, .. }) => variants,
335 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
336 Data::Union(..) => bail!("Oneof can not be derived for a union"),
337 };
338
339 let generics = &input.generics;
340 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
341
342 let mut fields: Vec<(Ident, Field)> = Vec::new();
344 for Variant {
345 attrs,
346 ident: variant_ident,
347 fields: variant_fields,
348 ..
349 } in variants
350 {
351 let variant_fields = match variant_fields {
352 Fields::Unit => Punctuated::new(),
353 Fields::Named(FieldsNamed { named: fields, .. })
354 | Fields::Unnamed(FieldsUnnamed {
355 unnamed: fields, ..
356 }) => fields,
357 };
358 if variant_fields.len() != 1 {
359 bail!("Oneof enum variants must have a single field");
360 }
361 match Field::new_oneof(attrs)? {
362 Some(field) => fields.push((variant_ident, field)),
363 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
364 }
365 }
366
367 let mut tags = fields
368 .iter()
369 .flat_map(|&(ref variant_ident, ref field)| -> Result<u32, Error> {
370 if field.tags().len() > 1 {
371 bail!(
372 "invalid oneof variant {}::{}: oneof variants may only have a single tag",
373 ident,
374 variant_ident
375 );
376 }
377 Ok(field.tags()[0])
378 })
379 .collect::<Vec<_>>();
380 tags.sort_unstable();
381 tags.dedup();
382 if tags.len() != fields.len() {
383 panic!("invalid oneof {}: variants have duplicate tags", ident);
384 }
385
386 let encode = fields.iter().map(|&(ref variant_ident, ref field)| {
387 let encode = field.encode(quote!(*value));
388 quote!(#ident::#variant_ident(ref value) => { #encode })
389 });
390
391 let merge = fields.iter().map(|&(ref variant_ident, ref field)| {
392 let tag = field.tags()[0];
393 let merge = field.merge(quote!(value));
394 quote! {
395 #tag => {
396 match field {
397 ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
398 #merge
399 },
400 _ => {
401 let mut owned_value = ::core::default::Default::default();
402 let value = &mut owned_value;
403 #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
404 },
405 }
406 }
407 }
408 });
409
410 let encoded_len = fields.iter().map(|&(ref variant_ident, ref field)| {
411 let encoded_len = field.encoded_len(quote!(*value));
412 quote!(#ident::#variant_ident(ref value) => #encoded_len)
413 });
414
415 let debug = fields.iter().map(|&(ref variant_ident, ref field)| {
416 let wrapper = field.debug(quote!(*value));
417 quote!(#ident::#variant_ident(ref value) => {
418 let wrapper = #wrapper;
419 f.debug_tuple(stringify!(#variant_ident))
420 .field(&wrapper)
421 .finish()
422 })
423 });
424
425 let expanded = quote! {
426 impl #impl_generics #ident #ty_generics #where_clause {
427 pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
428 match *self {
429 #(#encode,)*
430 }
431 }
432
433 pub fn merge<B>(
434 field: &mut ::core::option::Option<#ident #ty_generics>,
435 tag: u32,
436 wire_type: ::prost::encoding::WireType,
437 buf: &mut B,
438 ctx: ::prost::encoding::DecodeContext,
439 ) -> ::core::result::Result<(), ::prost::DecodeError>
440 where B: ::prost::bytes::Buf {
441 match tag {
442 #(#merge,)*
443 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
444 }
445 }
446
447 #[inline]
448 pub fn encoded_len(&self) -> usize {
449 match *self {
450 #(#encoded_len,)*
451 }
452 }
453 }
454
455 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
456 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
457 match *self {
458 #(#debug,)*
459 }
460 }
461 }
462 };
463
464 Ok(expanded.into())
465}
466
467#[proc_macro_derive(Oneof, attributes(prost))]
468pub fn oneof(input: TokenStream) -> TokenStream {
469 try_oneof(input).unwrap()
470}