tracing_attributes/
expand.rs

1use std::iter;
2
3use proc_macro2::TokenStream;
4use quote::TokenStreamExt;
5use quote::{quote, quote_spanned, ToTokens};
6use syn::visit_mut::VisitMut;
7use syn::{
8    punctuated::Punctuated, spanned::Spanned, Expr, ExprAsync, ExprCall, FieldPat, FnArg, Ident,
9    Item, ItemFn, Pat, PatIdent, PatReference, PatStruct, PatTuple, PatTupleStruct, PatType, Path,
10    ReturnType, Signature, Stmt, Token, Type, TypePath,
11};
12
13use crate::{
14    attr::{Field, FieldName, Fields, FormatMode, InstrumentArgs, Level},
15    MaybeItemFn, MaybeItemFnRef,
16};
17
18/// Given an existing function, generate an instrumented version of that function
19pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
20    input: MaybeItemFnRef<'a, B>,
21    args: InstrumentArgs,
22    instrumented_function_name: &str,
23    self_type: Option<&TypePath>,
24) -> proc_macro2::TokenStream {
25    // these are needed ahead of time, as ItemFn contains the function body _and_
26    // isn't representable inside a quote!/quote_spanned! macro
27    // (Syn's ToTokens isn't implemented for ItemFn)
28    let MaybeItemFnRef {
29        outer_attrs,
30        inner_attrs,
31        vis,
32        sig,
33        brace_token,
34        block,
35    } = input;
36
37    let Signature {
38        output,
39        inputs: params,
40        unsafety,
41        asyncness,
42        constness,
43        abi,
44        ident,
45        generics:
46            syn::Generics {
47                params: gen_params,
48                where_clause,
49                lt_token,
50                gt_token,
51            },
52        fn_token,
53        paren_token,
54        variadic,
55    } = sig;
56
57    let warnings = args.warnings();
58
59    let (return_type, return_span) = if let ReturnType::Type(_, return_type) = &output {
60        (erase_impl_trait(return_type), return_type.span())
61    } else {
62        // Point at function name if we don't have an explicit return type
63        (syn::parse_quote! { () }, ident.span())
64    };
65    // Install a fake return statement as the first thing in the function
66    // body, so that we eagerly infer that the return type is what we
67    // declared in the async fn signature.
68    // The `#[allow(..)]` is given because the return statement is
69    // unreachable, but does affect inference, so it needs to be written
70    // exactly that way for it to do its magic.
71    let fake_return_edge = quote_spanned! {return_span=>
72        #[allow(
73            unknown_lints,
74            unreachable_code,
75            clippy::diverging_sub_expression,
76            clippy::empty_loop,
77            clippy::let_unit_value,
78            clippy::let_with_type_underscore,
79            clippy::needless_return,
80            clippy::unreachable
81        )]
82        if false {
83            let __tracing_attr_fake_return: #return_type = loop {};
84            return __tracing_attr_fake_return;
85        }
86    };
87    let block = quote! {
88        {
89            #fake_return_edge
90            { #block }
91        }
92    };
93
94    let body = gen_block(
95        &block,
96        params,
97        asyncness.is_some(),
98        args,
99        instrumented_function_name,
100        self_type,
101    );
102
103    let mut result = quote!(
104        #(#outer_attrs) *
105        #vis #constness #asyncness #unsafety #abi #fn_token #ident
106        #lt_token #gen_params #gt_token
107    );
108
109    paren_token.surround(&mut result, |tokens| {
110        params.to_tokens(tokens);
111        variadic.to_tokens(tokens);
112    });
113
114    output.to_tokens(&mut result);
115    where_clause.to_tokens(&mut result);
116
117    brace_token.surround(&mut result, |tokens| {
118        tokens.append_all(inner_attrs);
119        warnings.to_tokens(tokens);
120        body.to_tokens(tokens);
121    });
122
123    result
124}
125
/// Instrument a block: wrap `block` in code that creates a `tracing` span
/// built from `args`. The span is entered for a synchronous body, and
/// attached via `::tracing::Instrument::instrument` for an async body.
///
/// - `params`: the function's parameters. Each one that is not skipped
///   (`skip(...)` / `skip_all`) and not shadowed by a custom field with the
///   same single-segment name is recorded as a span field.
/// - `async_context`: when `true`, the block is turned into an `async move`
///   future. The future is `.await`ed, instrumented only if the span is not
///   disabled.
/// - `instrumented_function_name`: the span name, unless `args.name`
///   overrides it.
/// - `self_type`: when `Some`, a parameter named `_self` is exposed to the
///   user as `self`, and `Self` inside custom field expressions is replaced
///   by this type (used for `async-trait <= 0.1.43` output).
///
/// When `err` and/or `ret` are configured, the block's result is inspected.
/// An event is emitted for `Err` values and/or for the returned value (the
/// `Ok` payload when `err` is also set).
///
/// If a skipped name does not match any parameter, the span expression is
/// replaced by a `compile_error!` spanned on that name.
fn gen_block<B: ToTokens>(
    block: &B,
    params: &Punctuated<FnArg, Token![,]>,
    async_context: bool,
    mut args: InstrumentArgs,
    instrumented_function_name: &str,
    self_type: Option<&TypePath>,
) -> proc_macro2::TokenStream {
    // generate the span's name
    let span_name = args
        // did the user override the span's name?
        .name
        .as_ref()
        .map(|name| quote!(#name))
        .unwrap_or_else(|| quote!(#instrumented_function_name));

    // `args_level` is handed to the `ret` event below as its fallback level;
    // `level` is the copy interpolated into the span and the enabled-checks.
    let args_level = args.level();
    let level = args_level.clone();

    // For each `follows_from` expression in `args`, emit a loop that links
    // every cause it yields to the freshly created span.
    let follows_from = args.follows_from.iter();
    let follows_from = quote! {
        #(for cause in #follows_from {
            __tracing_attr_span.follows_from(cause);
        })*
    };

    // generate this inside a closure, so we can return early on errors.
    let span = (|| {
        // Pull out the arguments-to-be-skipped first, so we can filter results
        // below.
        let param_names: Vec<(Ident, (Ident, RecordType))> = params
            .clone()
            .into_iter()
            .flat_map(|param| match param {
                FnArg::Typed(PatType { pat, ty, .. }) => {
                    param_names(*pat, RecordType::parse_from_ty(&ty))
                }
                FnArg::Receiver(_) => Box::new(iter::once((
                    Ident::new("self", param.span()),
                    RecordType::Debug,
                ))),
            })
            // Little dance with new (user-exposed) names and old (internal)
            // names of identifiers. That way, we could do the following
            // even though async_trait (<=0.1.43) rewrites "self" as "_self":
            // ```
            // #[async_trait]
            // impl Foo for FooImpl {
            //     #[instrument(skip(self))]
            //     async fn foo(&self, v: usize) {}
            // }
            // ```
            .map(|(x, record_type)| {
                // if we are inside a function generated by async-trait <=0.1.43, we need to
                // take care to rewrite "_self" as "self" for 'user convenience'
                if self_type.is_some() && x == "_self" {
                    (Ident::new("self", x.span()), (x, record_type))
                } else {
                    (x.clone(), (x, record_type))
                }
            })
            .collect();

        // Every skipped name must match a parameter by its user-facing name;
        // otherwise report the mistake at the offending `skip(...)` entry.
        for skip in &args.skips {
            if !param_names.iter().map(|(user, _)| user).any(|y| y == skip) {
                return quote_spanned! {skip.span()=>
                    compile_error!("attempting to skip non-existent parameter")
                };
            }
        }

        let target = args.target();

        let parent = args.parent.iter();

        // filter out skipped fields
        let quoted_fields: Vec<_> = param_names
            .iter()
            .filter(|(param, _)| {
                if args.skip_all || args.skips.contains(param) {
                    return false;
                }

                // If any parameters have the same name as a custom field, skip
                // and allow them to be formatted by the custom field.
                if let Some(ref fields) = args.fields {
                    fields.0.iter().all(|Field { ref name, .. }| {
                        match name {
                            // #3158: Expressions cannot be evaluated at compile time and will
                            // incur a runtime cost to de-duplicate.
                            FieldName::Expr(_) => true,
                            FieldName::Punctuated(punctuated) => {
                                let first = punctuated.first();
                                first != punctuated.last()
                                    || !first.iter().any(|name| name == &param)
                            }
                        }
                    })
                } else {
                    true
                }
            })
            .map(|(user_name, (real_name, record_type))| match record_type {
                RecordType::Value => quote!(#user_name = #real_name),
                RecordType::Debug => quote!(#user_name = ::tracing::field::debug(&#real_name)),
            })
            .collect();

        // replace every use of a variable with its original name
        if let Some(Fields(ref mut fields)) = args.fields {
            let mut replacer = IdentAndTypesRenamer {
                idents: param_names.into_iter().map(|(a, (b, _))| (a, b)).collect(),
                types: Vec::new(),
            };

            // when async-trait <=0.1.43 is in use, replace instances
            // of the "Self" type inside the fields values
            if let Some(self_type) = self_type {
                replacer.types.push(("Self", self_type.clone()));
            }

            for e in fields.iter_mut().filter_map(|f| f.value.as_mut()) {
                syn::visit_mut::visit_expr_mut(&mut replacer, e);
            }
        }

        let custom_fields = &args.fields;

        quote!(::tracing::span!(
            target: #target,
            #(parent: #parent,)*
            #level,
            #span_name,
            #(#quoted_fields,)*
            #custom_fields

        ))
    })();

    // Target shared by the optional `err` and `ret` events.
    let target = args.target();

    // `err`: `Level::Error` is the fallback level; the error is formatted
    // with Display unless Debug mode was requested.
    let err_event = match args.err_args {
        Some(event_args) => {
            let level_tokens = event_args.level(Level::Error);
            match event_args.mode {
                FormatMode::Default | FormatMode::Display => Some(quote!(
                    ::tracing::event!(target: #target, #level_tokens, error = %e)
                )),
                FormatMode::Debug => Some(quote!(
                    ::tracing::event!(target: #target, #level_tokens, error = ?e)
                )),
            }
        }
        _ => None,
    };

    // `ret`: the span's own level is the fallback level; the value is
    // formatted with Debug unless Display mode was requested.
    let ret_event = match args.ret_args {
        Some(event_args) => {
            let level_tokens = event_args.level(args_level);
            match event_args.mode {
                FormatMode::Display => Some(quote!(
                    ::tracing::event!(target: #target, #level_tokens, return = %x)
                )),
                FormatMode::Default | FormatMode::Debug => Some(quote!(
                    ::tracing::event!(target: #target, #level_tokens, return = ?x)
                )),
            }
        }
        _ => None,
    };

    // Generate the instrumented function body.
    // If the function is an `async fn`, this will wrap it in an async block,
    // which is `instrument`ed using `tracing-futures`. Otherwise, this will
    // enter the span and then perform the rest of the body.
    // If `err` is in args, instrument any resulting `Err`s.
    // If `ret` is in args, instrument any resulting `Ok`s when the function
    // returns `Result`s, otherwise instrument any resulting values.
    if async_context {
        // The user block is awaited as its own `async move` block so that a
        // `return` inside it produces the value inspected below instead of
        // leaving before the events are emitted.
        let mk_fut = match (err_event, ret_event) {
            (Some(err_event), Some(ret_event)) => quote_spanned!(block.span()=>
                async move {
                    let __match_scrutinee = async move #block.await;
                    match  __match_scrutinee {
                        #[allow(clippy::unit_arg)]
                        Ok(x) => {
                            #ret_event;
                            Ok(x)
                        },
                        Err(e) => {
                            #err_event;
                            Err(e)
                        }
                    }
                }
            ),
            (Some(err_event), None) => quote_spanned!(block.span()=>
                async move {
                    match async move #block.await {
                        #[allow(clippy::unit_arg)]
                        Ok(x) => Ok(x),
                        Err(e) => {
                            #err_event;
                            Err(e)
                        }
                    }
                }
            ),
            (None, Some(ret_event)) => quote_spanned!(block.span()=>
                async move {
                    let x = async move #block.await;
                    #ret_event;
                    x
                }
            ),
            (None, None) => quote_spanned!(block.span()=>
                async move #block
            ),
        };

        // The span is created before the future; `follows_from` links and
        // instrumentation are only applied when the span is not disabled.
        return quote!(
            let __tracing_attr_span = #span;
            let __tracing_instrument_future = #mk_fut;
            if !__tracing_attr_span.is_disabled() {
                #follows_from
                ::tracing::Instrument::instrument(
                    __tracing_instrument_future,
                    __tracing_attr_span
                )
                .await
            } else {
                __tracing_instrument_future.await
            }
        );
    }

    let span = quote!(
        // These variables are left uninitialized and initialized only
        // if the tracing level is statically enabled at this point.
        // While the tracing level is also checked at span creation
        // time, that will still create a dummy span, and a dummy guard
        // and drop the dummy guard later. By lazily initializing these
        // variables, Rust will generate a drop flag for them and thus
        // only drop the guard if it was created. This creates code that
        // is very straightforward for LLVM to optimize out if the tracing
        // level is statically disabled, while not causing any performance
        // regression in case the level is enabled.
        let __tracing_attr_span;
        let __tracing_attr_guard;
        if ::tracing::level_enabled!(#level) || ::tracing::if_log_enabled!(#level, {true} else {false}) {
            __tracing_attr_span = #span;
            #follows_from
            __tracing_attr_guard = __tracing_attr_span.enter();
        }
    );

    // In the synchronous arms the user block runs inside an immediately
    // invoked closure, so a `return` in it yields the value inspected here
    // rather than exiting the function before the events are emitted.
    match (err_event, ret_event) {
        (Some(err_event), Some(ret_event)) => quote_spanned! {block.span()=>
            #span
            #[allow(clippy::redundant_closure_call)]
            match (move || #block)() {
                #[allow(clippy::unit_arg)]
                Ok(x) => {
                    #ret_event;
                    Ok(x)
                },
                Err(e) => {
                    #err_event;
                    Err(e)
                }
            }
        },
        (Some(err_event), None) => quote_spanned!(block.span()=>
            #span
            #[allow(clippy::redundant_closure_call)]
            match (move || #block)() {
                #[allow(clippy::unit_arg)]
                Ok(x) => Ok(x),
                Err(e) => {
                    #err_event;
                    Err(e)
                }
            }
        ),
        (None, Some(ret_event)) => quote_spanned!(block.span()=>
            #span
            #[allow(clippy::redundant_closure_call)]
            let x = (move || #block)();
            #ret_event;
            x
        ),
        (None, None) => quote_spanned!(block.span() =>
            // Because `quote` produces a stream of tokens _without_ whitespace, the
            // `if` and the block will appear directly next to each other. This
            // generates a clippy lint about suspicious `if/else` formatting.
            // Therefore, suppress the lint inside the generated code...
            #[allow(clippy::suspicious_else_formatting)]
            {
                #span
                // ...but turn the lint back on inside the function body.
                #[warn(clippy::suspicious_else_formatting)]
                #block
            }
        ),
    }
}
433
/// How a function parameter is written into the generated span: directly
/// through its `tracing::Value` implementation, or wrapped in
/// `::tracing::field::debug`.
enum RecordType {
    /// Emit `name = name`, relying on the type's `Value` implementation.
    Value,
    /// Emit `name = ::tracing::field::debug(&name)`.
    Debug,
}
441
442impl RecordType {
443    /// Array of primitive types which should be recorded as [RecordType::Value].
444    const TYPES_FOR_VALUE: &'static [&'static str] = &[
445        "bool",
446        "str",
447        "u8",
448        "i8",
449        "u16",
450        "i16",
451        "u32",
452        "i32",
453        "u64",
454        "i64",
455        "u128",
456        "i128",
457        "f32",
458        "f64",
459        "usize",
460        "isize",
461        "String",
462        "NonZeroU8",
463        "NonZeroI8",
464        "NonZeroU16",
465        "NonZeroI16",
466        "NonZeroU32",
467        "NonZeroI32",
468        "NonZeroU64",
469        "NonZeroI64",
470        "NonZeroU128",
471        "NonZeroI128",
472        "NonZeroUsize",
473        "NonZeroIsize",
474        "Wrapping",
475    ];
476
477    /// Parse `RecordType` from [Type] by looking up
478    /// the [RecordType::TYPES_FOR_VALUE] array.
479    fn parse_from_ty(ty: &Type) -> Self {
480        match ty {
481            Type::Path(TypePath { path, .. })
482                if path
483                    .segments
484                    .iter()
485                    .next_back()
486                    .map(|path_segment| {
487                        let ident = path_segment.ident.to_string();
488                        Self::TYPES_FOR_VALUE.iter().any(|&t| t == ident)
489                    })
490                    .unwrap_or(false) =>
491            {
492                RecordType::Value
493            }
494            Type::Reference(syn::TypeReference { elem, .. }) => RecordType::parse_from_ty(elem),
495            _ => RecordType::Debug,
496        }
497    }
498}
499
500fn param_names(pat: Pat, record_type: RecordType) -> Box<dyn Iterator<Item = (Ident, RecordType)>> {
501    match pat {
502        Pat::Ident(PatIdent { ident, .. }) => Box::new(iter::once((ident, record_type))),
503        Pat::Reference(PatReference { pat, .. }) => param_names(*pat, record_type),
504        // We can't get the concrete type of fields in the struct/tuple
505        // patterns by using `syn`. e.g. `fn foo(Foo { x, y }: Foo) {}`.
506        // Therefore, the struct/tuple patterns in the arguments will just
507        // always be recorded as `RecordType::Debug`.
508        Pat::Struct(PatStruct { fields, .. }) => Box::new(
509            fields
510                .into_iter()
511                .flat_map(|FieldPat { pat, .. }| param_names(*pat, RecordType::Debug)),
512        ),
513        Pat::Tuple(PatTuple { elems, .. }) => Box::new(
514            elems
515                .into_iter()
516                .flat_map(|p| param_names(p, RecordType::Debug)),
517        ),
518        Pat::TupleStruct(PatTupleStruct { elems, .. }) => Box::new(
519            elems
520                .into_iter()
521                .flat_map(|p| param_names(p, RecordType::Debug)),
522        ),
523
524        // The above *should* cover all cases of irrefutable patterns,
525        // but we purposefully don't do any funny business here
526        // (such as panicking) because that would obscure rustc's
527        // much more informative error message.
528        _ => Box::new(iter::empty()),
529    }
530}
531
532/// The specific async code pattern that was detected
533enum AsyncKind<'a> {
534    /// Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
535    /// `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
536    Function(&'a ItemFn),
537    /// A function returning an async (move) block, optionally `Box::pin`-ed,
538    /// as generated by `async-trait >= 0.1.44`:
539    /// `Box::pin(async move { ... })`
540    Async {
541        async_expr: &'a ExprAsync,
542        pinned_box: bool,
543    },
544}
545
546pub(crate) struct AsyncInfo<'block> {
547    // statement that must be patched
548    source_stmt: &'block Stmt,
549    kind: AsyncKind<'block>,
550    self_type: Option<TypePath>,
551    input: &'block ItemFn,
552}
553
554impl<'block> AsyncInfo<'block> {
555    /// Get the AST of the inner function we need to hook, if it looks like a
556    /// manual future implementation.
557    ///
558    /// When we are given a function that returns a (pinned) future containing the
559    /// user logic, it is that (pinned) future that needs to be instrumented.
560    /// Were we to instrument its parent, we would only collect information
561    /// regarding the allocation of that future, and not its own span of execution.
562    ///
563    /// We inspect the block of the function to find if it matches any of the
564    /// following patterns:
565    ///
566    /// - Immediately-invoked async fn, as generated by `async-trait <= 0.1.43`:
567    ///   `async fn foo<...>(...) {...}; Box::pin(foo<...>(...))`
568    ///
569    /// - A function returning an async (move) block, optionally `Box::pin`-ed,
570    ///   as generated by `async-trait >= 0.1.44`:
571    ///   `Box::pin(async move { ... })`
572    ///
573    /// We the return the statement that must be instrumented, along with some
574    /// other information.
575    /// 'gen_body' will then be able to use that information to instrument the
576    /// proper function/future.
577    ///
578    /// (this follows the approach suggested in
579    /// https://github.com/dtolnay/async-trait/issues/45#issuecomment-571245673)
580    pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> {
581        // are we in an async context? If yes, this isn't a manual async-like pattern
582        if input.sig.asyncness.is_some() {
583            return None;
584        }
585
586        let block = &input.block;
587
588        // list of async functions declared inside the block
589        let inside_funs = block.stmts.iter().filter_map(|stmt| {
590            if let Stmt::Item(Item::Fn(fun)) = &stmt {
591                // If the function is async, this is a candidate
592                if fun.sig.asyncness.is_some() {
593                    return Some((stmt, fun));
594                }
595            }
596            None
597        });
598
599        // last expression of the block: it determines the return value of the
600        // block, this is quite likely a `Box::pin` statement or an async block
601        let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| {
602            if let Stmt::Expr(expr, _semi) = stmt {
603                Some((stmt, expr))
604            } else {
605                None
606            }
607        })?;
608
609        // is the last expression an async block?
610        if let Expr::Async(async_expr) = last_expr {
611            return Some(AsyncInfo {
612                source_stmt: last_expr_stmt,
613                kind: AsyncKind::Async {
614                    async_expr,
615                    pinned_box: false,
616                },
617                self_type: None,
618                input,
619            });
620        }
621
622        // is the last expression a function call?
623        let (outside_func, outside_args) = match last_expr {
624            Expr::Call(ExprCall { func, args, .. }) => (func, args),
625            _ => return None,
626        };
627
628        // is it a call to `Box::pin()`?
629        let path = match outside_func.as_ref() {
630            Expr::Path(path) => &path.path,
631            _ => return None,
632        };
633        if !path_to_string(path).ends_with("Box::pin") {
634            return None;
635        }
636
637        // Does the call take an argument? If it doesn't,
638        // it's not gonna compile anyway, but that's no reason
639        // to (try to) perform an out of bounds access
640        if outside_args.is_empty() {
641            return None;
642        }
643
644        // Is the argument to Box::pin an async block that
645        // captures its arguments?
646        if let Expr::Async(async_expr) = &outside_args[0] {
647            return Some(AsyncInfo {
648                source_stmt: last_expr_stmt,
649                kind: AsyncKind::Async {
650                    async_expr,
651                    pinned_box: true,
652                },
653                self_type: None,
654                input,
655            });
656        }
657
658        // Is the argument to Box::pin a function call itself?
659        let func = match &outside_args[0] {
660            Expr::Call(ExprCall { func, .. }) => func,
661            _ => return None,
662        };
663
664        // "stringify" the path of the function called
665        let func_name = match **func {
666            Expr::Path(ref func_path) => path_to_string(&func_path.path),
667            _ => return None,
668        };
669
670        // Was that function defined inside of the current block?
671        // If so, retrieve the statement where it was declared and the function itself
672        let (stmt_func_declaration, func) = inside_funs
673            .into_iter()
674            .find(|(_, fun)| fun.sig.ident == func_name)?;
675
676        // If "_self" is present as an argument, we store its type to be able to rewrite "Self" (the
677        // parameter type) with the type of "_self"
678        let mut self_type = None;
679        for arg in &func.sig.inputs {
680            if let FnArg::Typed(ty) = arg {
681                if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat {
682                    if ident == "_self" {
683                        let mut ty = *ty.ty.clone();
684                        // extract the inner type if the argument is "&self" or "&mut self"
685                        if let Type::Reference(syn::TypeReference { elem, .. }) = ty {
686                            ty = *elem;
687                        }
688
689                        if let Type::Path(tp) = ty {
690                            self_type = Some(tp);
691                            break;
692                        }
693                    }
694                }
695            }
696        }
697
698        Some(AsyncInfo {
699            source_stmt: stmt_func_declaration,
700            kind: AsyncKind::Function(func),
701            self_type,
702            input,
703        })
704    }
705
706    pub(crate) fn gen_async(
707        self,
708        args: InstrumentArgs,
709        instrumented_function_name: &str,
710    ) -> Result<proc_macro::TokenStream, syn::Error> {
711        // let's rewrite some statements!
712        let mut out_stmts: Vec<TokenStream> = self
713            .input
714            .block
715            .stmts
716            .iter()
717            .map(|stmt| stmt.to_token_stream())
718            .collect();
719
720        if let Some((iter, _stmt)) = self
721            .input
722            .block
723            .stmts
724            .iter()
725            .enumerate()
726            .find(|(_iter, stmt)| *stmt == self.source_stmt)
727        {
728            // instrument the future by rewriting the corresponding statement
729            out_stmts[iter] = match self.kind {
730                // `Box::pin(immediately_invoked_async_fn())`
731                AsyncKind::Function(fun) => {
732                    let fun = MaybeItemFn::from(fun.clone());
733                    gen_function(
734                        fun.as_ref(),
735                        args,
736                        instrumented_function_name,
737                        self.self_type.as_ref(),
738                    )
739                }
740                // `async move { ... }`, optionally pinned
741                AsyncKind::Async {
742                    async_expr,
743                    pinned_box,
744                } => {
745                    let instrumented_block = gen_block(
746                        &async_expr.block,
747                        &self.input.sig.inputs,
748                        true,
749                        args,
750                        instrumented_function_name,
751                        None,
752                    );
753                    let async_attrs = &async_expr.attrs;
754                    if pinned_box {
755                        quote! {
756                            ::std::boxed::Box::pin(#(#async_attrs) * async move { #instrumented_block })
757                        }
758                    } else {
759                        quote! {
760                            #(#async_attrs) * async move { #instrumented_block }
761                        }
762                    }
763                }
764            };
765        }
766
767        let vis = &self.input.vis;
768        let sig = &self.input.sig;
769        let attrs = &self.input.attrs;
770        Ok(quote!(
771            #(#attrs) *
772            #vis #sig {
773                #(#out_stmts) *
774            }
775        )
776        .into())
777    }
778}
779
780// Return a path as a String
781fn path_to_string(path: &Path) -> String {
782    use std::fmt::Write;
783    // some heuristic to prevent too many allocations
784    let mut res = String::with_capacity(path.segments.len() * 5);
785    for i in 0..path.segments.len() {
786        write!(&mut res, "{}", path.segments[i].ident)
787            .expect("writing to a String should never fail");
788        if i < path.segments.len() - 1 {
789            res.push_str("::");
790        }
791    }
792    res
793}
794
795/// A visitor struct to replace idents and types in some piece
796/// of code (e.g. the "self" and "Self" tokens in user-supplied
797/// fields expressions when the function is generated by an old
798/// version of async-trait).
799struct IdentAndTypesRenamer<'a> {
800    types: Vec<(&'a str, TypePath)>,
801    idents: Vec<(Ident, Ident)>,
802}
803
804impl VisitMut for IdentAndTypesRenamer<'_> {
805    // we deliberately compare strings because we want to ignore the spans
806    // If we apply clippy's lint, the behavior changes
807    #[allow(clippy::cmp_owned)]
808    fn visit_ident_mut(&mut self, id: &mut Ident) {
809        for (old_ident, new_ident) in &self.idents {
810            if id.to_string() == old_ident.to_string() {
811                *id = new_ident.clone();
812            }
813        }
814    }
815
816    fn visit_type_mut(&mut self, ty: &mut Type) {
817        for (type_name, new_type) in &self.types {
818            if let Type::Path(TypePath { path, .. }) = ty {
819                if path_to_string(path) == *type_name {
820                    *ty = Type::Path(new_type.clone());
821                }
822            }
823        }
824    }
825}
826
827// Replaces any `impl Trait` with `_` so it can be used as the type in
828// a `let` statement's LHS.
829struct ImplTraitEraser;
830
831impl VisitMut for ImplTraitEraser {
832    fn visit_type_mut(&mut self, t: &mut Type) {
833        if let Type::ImplTrait(..) = t {
834            *t = syn::TypeInfer {
835                underscore_token: Token![_](t.span()),
836            }
837            .into();
838        } else {
839            syn::visit_mut::visit_type_mut(self, t);
840        }
841    }
842}
843
844fn erase_impl_trait(ty: &Type) -> Type {
845    let mut ty = ty.clone();
846    ImplTraitEraser.visit_type_mut(&mut ty);
847    ty
848}