1use anyhow::{bail, Error};
2use proc_macro2::{Span, TokenStream};
3use quote::quote;
4use syn::{Ident, Lit, Meta, MetaNameValue, NestedMeta};
5
6use crate::field::{scalar, set_option, tag_attr};
7
/// The collection type backing a protobuf map field.
///
/// Selected by the attribute name: `map` / `hash_map` select [`MapTy::HashMap`], and
/// `btree_map` selects [`MapTy::BTreeMap`].
#[derive(Debug, Clone)]
pub enum MapTy {
    /// Hash-based map (`map` or `hash_map` attribute).
    HashMap,
    /// Ordered B-tree map (`btree_map` attribute).
    BTreeMap,
}
13
14impl MapTy {
15 fn from_str(s: &str) -> Option<MapTy> {
16 match s {
17 "map" | "hash_map" => Some(MapTy::HashMap),
18 "btree_map" => Some(MapTy::BTreeMap),
19 _ => None,
20 }
21 }
22
23 fn module(&self) -> Ident {
24 match *self {
25 MapTy::HashMap => Ident::new("hash_map", Span::call_site()),
26 MapTy::BTreeMap => Ident::new("btree_map", Span::call_site()),
27 }
28 }
29
30 fn lib(&self) -> TokenStream {
31 match self {
32 MapTy::HashMap => quote! { std },
33 MapTy::BTreeMap => quote! { prost::alloc },
34 }
35 }
36}
37
38fn fake_scalar(ty: scalar::Ty) -> scalar::Field {
39 let kind = scalar::Kind::Plain(scalar::DefaultValue::new(&ty));
40 scalar::Field {
41 ty,
42 kind,
43 tag: 0, }
45}
46
47#[derive(Clone)]
48pub struct Field {
49 pub map_ty: MapTy,
50 pub key_ty: scalar::Ty,
51 pub value_ty: ValueTy,
52 pub tag: u32,
53}
54
55impl Field {
56 pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
57 let mut types = None;
58 let mut tag = None;
59
60 for attr in attrs {
61 if let Some(t) = tag_attr(attr)? {
62 set_option(&mut tag, t, "duplicate tag attributes")?;
63 } else if let Some(map_ty) = attr
64 .path()
65 .get_ident()
66 .and_then(|i| MapTy::from_str(&i.to_string()))
67 {
68 let (k, v): (String, String) = match &*attr {
69 Meta::NameValue(MetaNameValue {
70 lit: Lit::Str(lit), ..
71 }) => {
72 let items = lit.value();
73 let mut items = items.split(',').map(ToString::to_string);
74 let k = items.next().unwrap();
75 let v = match items.next() {
76 Some(k) => k,
77 None => bail!("invalid map attribute: must have key and value types"),
78 };
79 if items.next().is_some() {
80 bail!("invalid map attribute: {:?}", attr);
81 }
82 (k, v)
83 }
84 Meta::List(meta_list) => {
85 if meta_list.nested.len() != 2 {
87 bail!("invalid map attribute: must contain key and value types");
88 }
89 let k = match &meta_list.nested[0] {
90 NestedMeta::Meta(Meta::Path(k)) if k.get_ident().is_some() => {
91 k.get_ident().unwrap().to_string()
92 }
93 _ => bail!("invalid map attribute: key must be an identifier"),
94 };
95 let v = match &meta_list.nested[1] {
96 NestedMeta::Meta(Meta::Path(v)) if v.get_ident().is_some() => {
97 v.get_ident().unwrap().to_string()
98 }
99 _ => bail!("invalid map attribute: value must be an identifier"),
100 };
101 (k, v)
102 }
103 _ => return Ok(None),
104 };
105 set_option(
106 &mut types,
107 (map_ty, key_ty_from_str(&k)?, ValueTy::from_str(&v)?),
108 "duplicate map type attribute",
109 )?;
110 } else {
111 return Ok(None);
112 }
113 }
114
115 Ok(match (types, tag.or(inferred_tag)) {
116 (Some((map_ty, key_ty, value_ty)), Some(tag)) => Some(Field {
117 map_ty,
118 key_ty,
119 value_ty,
120 tag,
121 }),
122 _ => None,
123 })
124 }
125
126 pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
127 Field::new(attrs, None)
128 }
129
130 pub fn encode(&self, ident: TokenStream) -> TokenStream {
132 let tag = self.tag;
133 let key_mod = self.key_ty.module();
134 let ke = quote!(::prost::encoding::#key_mod::encode);
135 let kl = quote!(::prost::encoding::#key_mod::encoded_len);
136 let module = self.map_ty.module();
137 match &self.value_ty {
138 ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
139 let default = quote!(#ty::default() as i32);
140 quote! {
141 ::prost::encoding::#module::encode_with_default(
142 #ke,
143 #kl,
144 ::prost::encoding::int32::encode,
145 ::prost::encoding::int32::encoded_len,
146 &(#default),
147 #tag,
148 &#ident,
149 buf,
150 );
151 }
152 }
153 ValueTy::Scalar(value_ty) => {
154 let val_mod = value_ty.module();
155 let ve = quote!(::prost::encoding::#val_mod::encode);
156 let vl = quote!(::prost::encoding::#val_mod::encoded_len);
157 quote! {
158 ::prost::encoding::#module::encode(
159 #ke,
160 #kl,
161 #ve,
162 #vl,
163 #tag,
164 &#ident,
165 buf,
166 );
167 }
168 }
169 ValueTy::Message => quote! {
170 ::prost::encoding::#module::encode(
171 #ke,
172 #kl,
173 ::prost::encoding::message::encode,
174 ::prost::encoding::message::encoded_len,
175 #tag,
176 &#ident,
177 buf,
178 );
179 },
180 }
181 }
182
183 pub fn merge(&self, ident: TokenStream) -> TokenStream {
186 let key_mod = self.key_ty.module();
187 let km = quote!(::prost::encoding::#key_mod::merge);
188 let module = self.map_ty.module();
189 match &self.value_ty {
190 ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
191 let default = quote!(#ty::default() as i32);
192 quote! {
193 ::prost::encoding::#module::merge_with_default(
194 #km,
195 ::prost::encoding::int32::merge,
196 #default,
197 &mut #ident,
198 buf,
199 ctx,
200 )
201 }
202 }
203 ValueTy::Scalar(value_ty) => {
204 let val_mod = value_ty.module();
205 let vm = quote!(::prost::encoding::#val_mod::merge);
206 quote!(::prost::encoding::#module::merge(#km, #vm, &mut #ident, buf, ctx))
207 }
208 ValueTy::Message => quote! {
209 ::prost::encoding::#module::merge(
210 #km,
211 ::prost::encoding::message::merge,
212 &mut #ident,
213 buf,
214 ctx,
215 )
216 },
217 }
218 }
219
220 pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
222 let tag = self.tag;
223 let key_mod = self.key_ty.module();
224 let kl = quote!(::prost::encoding::#key_mod::encoded_len);
225 let module = self.map_ty.module();
226 match &self.value_ty {
227 ValueTy::Scalar(scalar::Ty::Enumeration(ty)) => {
228 let default = quote!(#ty::default() as i32);
229 quote! {
230 ::prost::encoding::#module::encoded_len_with_default(
231 #kl,
232 ::prost::encoding::int32::encoded_len,
233 &(#default),
234 #tag,
235 &#ident,
236 )
237 }
238 }
239 ValueTy::Scalar(value_ty) => {
240 let val_mod = value_ty.module();
241 let vl = quote!(::prost::encoding::#val_mod::encoded_len);
242 quote!(::prost::encoding::#module::encoded_len(#kl, #vl, #tag, &#ident))
243 }
244 ValueTy::Message => quote! {
245 ::prost::encoding::#module::encoded_len(
246 #kl,
247 ::prost::encoding::message::encoded_len,
248 #tag,
249 &#ident,
250 )
251 },
252 }
253 }
254
255 pub fn clear(&self, ident: TokenStream) -> TokenStream {
256 quote!(#ident.clear())
257 }
258
259 pub fn methods(&self, ident: &Ident) -> Option<TokenStream> {
261 if let ValueTy::Scalar(scalar::Ty::Enumeration(ty)) = &self.value_ty {
262 let key_ty = self.key_ty.rust_type();
263 let key_ref_ty = self.key_ty.rust_ref_type();
264
265 let get = Ident::new(&format!("get_{}", ident), Span::call_site());
266 let insert = Ident::new(&format!("insert_{}", ident), Span::call_site());
267 let take_ref = if self.key_ty.is_numeric() {
268 quote!(&)
269 } else {
270 quote!()
271 };
272
273 let get_doc = format!(
274 "Returns the enum value for the corresponding key in `{}`, \
275 or `None` if the entry does not exist or it is not a valid enum value.",
276 ident,
277 );
278 let insert_doc = format!("Inserts a key value pair into `{}`.", ident);
279 Some(quote! {
280 #[doc=#get_doc]
281 pub fn #get(&self, key: #key_ref_ty) -> ::core::option::Option<#ty> {
282 self.#ident.get(#take_ref key).cloned().and_then(#ty::from_i32)
283 }
284 #[doc=#insert_doc]
285 pub fn #insert(&mut self, key: #key_ty, value: #ty) -> ::core::option::Option<#ty> {
286 self.#ident.insert(key, value as i32).and_then(#ty::from_i32)
287 }
288 })
289 } else {
290 None
291 }
292 }
293
294 pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
299 let type_name = match self.map_ty {
300 MapTy::HashMap => Ident::new("HashMap", Span::call_site()),
301 MapTy::BTreeMap => Ident::new("BTreeMap", Span::call_site()),
302 };
303
304 let key_wrapper = fake_scalar(self.key_ty.clone()).debug(quote!(KeyWrapper));
306 let key = self.key_ty.rust_type();
307 let value_wrapper = self.value_ty.debug();
308 let libname = self.map_ty.lib();
309 let fmt = quote! {
310 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
311 #key_wrapper
312 #value_wrapper
313 let mut builder = f.debug_map();
314 for (k, v) in self.0 {
315 builder.entry(&KeyWrapper(k), &ValueWrapper(v));
316 }
317 builder.finish()
318 }
319 };
320 match &self.value_ty {
321 ValueTy::Scalar(ty) => {
322 if let &scalar::Ty::Bytes(_) = ty {
323 return quote! {
324 struct #wrapper_name<'a>(&'a dyn ::core::fmt::Debug);
325 impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
326 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
327 self.0.fmt(f)
328 }
329 }
330 };
331 }
332
333 let value = ty.rust_type();
334 quote! {
335 struct #wrapper_name<'a>(&'a ::#libname::collections::#type_name<#key, #value>);
336 impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
337 #fmt
338 }
339 }
340 }
341 ValueTy::Message => quote! {
342 struct #wrapper_name<'a, V: 'a>(&'a ::#libname::collections::#type_name<#key, V>);
343 impl<'a, V> ::core::fmt::Debug for #wrapper_name<'a, V>
344 where
345 V: ::core::fmt::Debug + 'a,
346 {
347 #fmt
348 }
349 },
350 }
351 }
352}
353
354fn key_ty_from_str(s: &str) -> Result<scalar::Ty, Error> {
355 let ty = scalar::Ty::from_str(s)?;
356 match ty {
357 scalar::Ty::Int32
358 | scalar::Ty::Int64
359 | scalar::Ty::Uint32
360 | scalar::Ty::Uint64
361 | scalar::Ty::Sint32
362 | scalar::Ty::Sint64
363 | scalar::Ty::Fixed32
364 | scalar::Ty::Fixed64
365 | scalar::Ty::Sfixed32
366 | scalar::Ty::Sfixed64
367 | scalar::Ty::Bool
368 | scalar::Ty::String => Ok(ty),
369 _ => bail!("invalid map key type: {}", s),
370 }
371}
372
373#[derive(Clone, Debug, PartialEq, Eq)]
375pub enum ValueTy {
376 Scalar(scalar::Ty),
377 Message,
378}
379
380impl ValueTy {
381 fn from_str(s: &str) -> Result<ValueTy, Error> {
382 if let Ok(ty) = scalar::Ty::from_str(s) {
383 Ok(ValueTy::Scalar(ty))
384 } else if s.trim() == "message" {
385 Ok(ValueTy::Message)
386 } else {
387 bail!("invalid map value type: {}", s);
388 }
389 }
390
391 fn debug(&self) -> TokenStream {
396 match self {
397 ValueTy::Scalar(ty) => fake_scalar(ty.clone()).debug(quote!(ValueWrapper)),
398 ValueTy::Message => quote!(
399 fn ValueWrapper<T>(v: T) -> T {
400 v
401 }
402 ),
403 }
404 }
405}