use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
parse::{Parse, ParseBuffer, ParseStream},
visit_mut::VisitMut,
*,
};
use crate::utils::*;
use super::PIN;
pub(super) fn parse_derive(input: TokenStream) -> Result<TokenStream> {
match syn::parse2(input)? {
Item::Struct(ItemStruct { attrs, vis, ident, generics, fields, .. }) => {
validate_struct(&ident, &fields)?;
let mut cx = Context::new(attrs, vis, ident, generics)?;
let packed_check = cx.ensure_not_packed(&fields)?;
let mut proj_items = cx.parse_struct(&fields)?;
proj_items.extend(packed_check);
proj_items.extend(cx.make_unpin_impl());
proj_items.extend(cx.make_drop_impl());
Ok(proj_items)
}
Item::Enum(ItemEnum { attrs, vis, ident, generics, brace_token, variants, .. }) => {
validate_enum(brace_token, &variants)?;
let mut cx = Context::new(attrs, vis, ident, generics)?;
let mut proj_items = cx.parse_enum(&variants)?;
proj_items.extend(cx.make_unpin_impl());
proj_items.extend(cx.make_drop_impl());
Ok(proj_items)
}
item => Err(error!(item, "#[pin_project] attribute may only be used on structs or enums")),
}
}
fn validate_struct(ident: &Ident, fields: &Fields) -> Result<()> {
match fields {
Fields::Named(FieldsNamed { named: f, .. })
| Fields::Unnamed(FieldsUnnamed { unnamed: f, .. })
if f.is_empty() =>
{
Err(error!(
fields,
"#[pin_project] attribute may not be used on structs with zero fields"
))
}
Fields::Unit => {
Err(error!(ident, "#[pin_project] attribute may not be used on structs with units"))
}
_ => Ok(()),
}
}
fn validate_enum(brace_token: token::Brace, variants: &Variants) -> Result<()> {
if variants.is_empty() {
return Err(syn::Error::new(
brace_token.span,
"#[pin_project] attribute may not be used on enums without variants",
));
}
let has_field = variants.iter().try_fold(false, |has_field, v| {
if let Some((_, e)) = &v.discriminant {
Err(error!(e, "#[pin_project] attribute may not be used on enums with discriminants"))
} else if let Some(attr) = v.attrs.find(PIN) {
Err(error!(attr, "#[pin] attribute may only be used on fields of structs or variants"))
} else if let Fields::Unit = v.fields {
Ok(has_field)
} else {
Ok(true)
}
})?;
if has_field {
Ok(())
} else {
Err(error!(
variants,
"#[pin_project] attribute may not be used on enums that have no field"
))
}
}
/// Arguments forwarded from `#[pin_project(...)]` to this derive.
///
/// Each field holds the span of the argument's identifier, so generated code
/// and diagnostics can point back at it. `None` means the argument was not
/// given.
#[derive(Default)]
struct Args {
    /// Span of the `PinnedDrop` argument, if present.
    pinned_drop: Option<Span>,
    /// Span of the `UnsafeUnpin` argument, if present.
    unsafe_unpin: Option<Span>,
}
/// Message for a `#[pin]` attribute on the item that is not the single
/// forwarded argument attribute.
///
/// `Args::parse` also returns this message for input that lacks the
/// forwarded shape, and `Args::get` compares against it to decide which
/// attribute to blame.
const DUPLICATE_PIN: &str = "duplicate #[pin] attribute";
impl Args {
fn get(attrs: &[Attribute]) -> Result<Self> {
let mut prev: Option<(&Attribute, Result<Args>)> = None;
for attr in attrs {
if attr.path.is_ident(PIN) {
if let Some((prev_attr, prev_res)) = &prev {
let res = syn::parse2::<Self>(attr.tokens.clone());
let span = match (&prev_res, res) {
(Ok(_), Ok(_)) => unreachable!(),
(_, Ok(_)) => prev_attr,
(Ok(_), _) => attr,
(Err(prev_err), Err(_)) => {
if prev_err.to_string() == DUPLICATE_PIN {
attr
} else {
prev_attr
}
}
};
return Err(error!(span, DUPLICATE_PIN));
}
prev = Some((attr, syn::parse2::<Self>(attr.tokens.clone())));
}
}
prev.unwrap().1
}
}
impl Parse for Args {
fn parse(input: ParseStream<'_>) -> Result<Self> {
fn parse_input(input: ParseStream<'_>) -> Result<ParseBuffer<'_>> {
if let Ok(content) = input.parenthesized() {
if let Ok(private) = content.parse::<Ident>() {
if private == CURRENT_PRIVATE_MODULE {
if let Ok(args) = content.parenthesized() {
return Ok(args);
}
}
}
}
Err(error!(TokenStream::new(), DUPLICATE_PIN))
}
let input = parse_input(input)?;
let mut args = Self::default();
while !input.is_empty() {
let ident = input.parse::<Ident>()?;
match &*ident.to_string() {
"PinnedDrop" => {
if args.pinned_drop.is_some() {
return Err(error!(ident, "duplicate `PinnedDrop` argument"));
}
args.pinned_drop = Some(ident.span());
}
"UnsafeUnpin" => {
if args.unsafe_unpin.is_some() {
return Err(error!(ident, "duplicate `UnsafeUnpin` argument"));
}
args.unsafe_unpin = Some(ident.span());
}
_ => return Err(error!(ident, "unexpected argument: {}", ident)),
}
if !input.is_empty() {
let _: token::Comma = input.parse()?;
}
}
Ok(args)
}
}
/// The item the derive is applied to.
struct OriginalType {
    /// The item's attributes. They are searched for the forwarded `#[pin]`
    /// attribute and for `#[repr(...)]`.
    attrs: Vec<Attribute>,
    /// The item's visibility.
    vis: Visibility,
    /// The item's name.
    ident: Ident,
    /// The item's generics. `Context::new` always creates a where clause here
    /// and visits it with `ReplaceReceiver`.
    generics: Generics,
}
/// Names and generics shared by the generated projection types.
struct ProjectedType {
    /// Visibility of the projection types and methods, computed from the
    /// item's visibility by `determine_visibility`.
    vis: Visibility,
    /// Name of the type returned by `project()`.
    mut_ident: Ident,
    /// Name of the type returned by `project_ref()`.
    ref_ident: Ident,
    /// Lifetime of the borrow held by a projection.
    lifetime: Lifetime,
    /// The item's generics with `lifetime` inserted.
    generics: Generics,
}
/// State accumulated while generating code for one item.
struct Context {
    /// The annotated item.
    orig: OriginalType,
    /// Names and generics of the projection types.
    proj: ProjectedType,
    /// Types of the `#[pin]` fields, in visiting order. They are filled in by
    /// `visit_named`/`visit_unnamed` and consumed by `make_unpin_impl`.
    pinned_fields: Vec<Type>,
    /// Span of the `PinnedDrop` argument, if given.
    pinned_drop: Option<Span>,
    /// Span of the `UnsafeUnpin` argument, if given.
    unsafe_unpin: Option<Span>,
}
impl Context {
/// Builds the generation context for one item.
///
/// Steps:
/// 1. Parses the forwarded arguments from `attrs`.
/// 2. Ensures `generics` has a where clause and visits it with
///    `ReplaceReceiver` built from `Ident<TyGenerics>`. This presumably
///    rewrites `Self`; see that visitor.
/// 3. Picks the projection lifetime name with `determine_lifetime_name` and
///    derives the projection generics from the item's generics.
fn new(
    attrs: Vec<Attribute>,
    vis: Visibility,
    ident: Ident,
    mut generics: Generics,
) -> Result<Self> {
    let Args { pinned_drop, unsafe_unpin } = Args::get(&attrs)?;

    let (_, ty_generics, _) = generics.split_for_impl();
    let self_ty = syn::parse_quote!(#ident #ty_generics);
    let mut receiver_visitor = ReplaceReceiver::new(&self_ty);
    receiver_visitor.visit_where_clause_mut(generics.make_where_clause());

    let mut lifetime_name = String::from(DEFAULT_LIFETIME_NAME);
    determine_lifetime_name(&mut lifetime_name, &generics.params);
    let lifetime = Lifetime::new(&lifetime_name, Span::call_site());

    let mut proj_generics = generics.clone();
    insert_lifetime(&mut proj_generics, lifetime.clone());

    // Built before `orig` so `vis` and `ident` can still be borrowed.
    let proj = ProjectedType {
        vis: determine_visibility(&vis),
        mut_ident: proj_ident(&ident, Mutable),
        ref_ident: proj_ident(&ident, Immutable),
        lifetime,
        generics: proj_generics,
    };
    Ok(Self {
        orig: OriginalType { attrs, vis, ident, generics },
        proj,
        pinned_fields: Vec::new(),
        pinned_drop,
        unsafe_unpin,
    })
}
fn parse_struct(&mut self, fields: &Fields) -> Result<TokenStream> {
let (proj_pat, proj_init, proj_fields, proj_ref_fields) = match fields {
Fields::Named(fields) => self.visit_named(fields)?,
Fields::Unnamed(fields) => self.visit_unnamed(fields, true)?,
Fields::Unit => unreachable!(),
};
let orig_ident = &self.orig.ident;
let proj_ident = &self.proj.mut_ident;
let proj_ref_ident = &self.proj.ref_ident;
let vis = &self.proj.vis;
let proj_generics = &self.proj.generics;
let where_clause = self.orig.generics.split_for_impl().2;
let mut proj_items = quote! {
#[allow(clippy::mut_mut)]
#[allow(dead_code)]
#vis struct #proj_ident #proj_generics #where_clause #proj_fields
#[allow(dead_code)]
#vis struct #proj_ref_ident #proj_generics #where_clause #proj_ref_fields
};
let proj_body = quote! {
let #orig_ident #proj_pat = self.get_unchecked_mut();
#proj_ident #proj_init
};
let proj_ref_body = quote! {
let #orig_ident #proj_pat = self.get_ref();
#proj_ref_ident #proj_init
};
proj_items.extend(self.make_proj_impl(&proj_body, &proj_ref_body));
Ok(proj_items)
}
fn parse_enum(&mut self, variants: &Variants) -> Result<TokenStream> {
let (proj_variants, proj_ref_variants, proj_arms, proj_ref_arms) =
self.visit_variants(variants)?;
let proj_ident = &self.proj.mut_ident;
let proj_ref_ident = &self.proj.ref_ident;
let vis = &self.proj.vis;
let proj_generics = &self.proj.generics;
let where_clause = self.orig.generics.split_for_impl().2;
let mut proj_items = quote! {
#[allow(clippy::mut_mut)]
#[allow(dead_code)]
#vis enum #proj_ident #proj_generics #where_clause {
#proj_variants
}
#[allow(dead_code)]
#vis enum #proj_ref_ident #proj_generics #where_clause {
#proj_ref_variants
}
};
let proj_body = quote! {
match self.get_unchecked_mut() {
#proj_arms
}
};
let proj_ref_body = quote! {
match self.get_ref() {
#proj_ref_arms
}
};
proj_items.extend(self.make_proj_impl(&proj_body, &proj_ref_body));
Ok(proj_items)
}
/// Builds the projection pieces for every enum variant.
///
/// Returns `(mut_variants, ref_variants, mut_arms, ref_arms)`:
/// * the variant lists of the mutable and shared projection enums;
/// * the match arms mapping each original variant to each projection.
///
/// Unit variants project to unit variants.
fn visit_variants(
    &mut self,
    variants: &Variants,
) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream)> {
    let mut mut_variants = TokenStream::new();
    let mut ref_variants = TokenStream::new();
    let mut mut_arms = TokenStream::new();
    let mut ref_arms = TokenStream::new();

    for variant in variants {
        let variant_ident = &variant.ident;
        let (pat, body, mut_fields, ref_fields) = match &variant.fields {
            Fields::Named(named) => self.visit_named(named)?,
            Fields::Unnamed(unnamed) => self.visit_unnamed(unnamed, false)?,
            Fields::Unit => {
                (TokenStream::new(), TokenStream::new(), TokenStream::new(), TokenStream::new())
            }
        };

        // Borrowed only after the `&mut self` visit above has finished.
        let orig_ident = &self.orig.ident;
        let mut_ident = &self.proj.mut_ident;
        let ref_ident = &self.proj.ref_ident;

        mut_variants.extend(quote!(#variant_ident #mut_fields,));
        ref_variants.extend(quote!(#variant_ident #ref_fields,));
        mut_arms.extend(quote! {
            #orig_ident::#variant_ident #pat => {
                #mut_ident::#variant_ident #body
            }
        });
        ref_arms.extend(quote! {
            #orig_ident::#variant_ident #pat => {
                #ref_ident::#variant_ident #body
            }
        });
    }

    Ok((mut_variants, ref_variants, mut_arms, ref_arms))
}
/// Visits named fields and returns `(pattern, initializer, mut_fields,
/// ref_fields)`, each brace-delimited:
///
/// * `pattern` binds every field by name, e.g. `{ a, b }`.
/// * `initializer` rebuilds a projection from those bindings. `#[pin]`
///   fields are wrapped in `Pin::new_unchecked`.
/// * `mut_fields` / `ref_fields` are the field lists of the mutable and
///   shared projection types. A `#[pin]` field of type `T` becomes
///   `Pin<&'lt mut (T)>` / `Pin<&'lt (T)>`; any other field becomes
///   `&'lt mut (T)` / `&'lt (T)`, where `'lt` is the projection lifetime.
///
/// The type of every `#[pin]` field is appended to `pinned_fields`.
fn visit_named(
    &mut self,
    FieldsNamed { named: fields, .. }: &FieldsNamed,
) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream)> {
    let len = fields.len();
    let mut pat = Vec::with_capacity(len);
    let mut init = Vec::with_capacity(len);
    let mut mut_fields = Vec::with_capacity(len);
    let mut ref_fields = Vec::with_capacity(len);
    let lifetime = &self.proj.lifetime;

    for Field { attrs, vis, ident, ty, .. } in fields {
        let is_pinned = attrs.find_exact(PIN)?.is_some();
        if is_pinned {
            self.pinned_fields.push(ty.clone());
            mut_fields.push(quote!(#vis #ident: ::core::pin::Pin<&#lifetime mut (#ty)>));
            ref_fields.push(quote!(#vis #ident: ::core::pin::Pin<&#lifetime (#ty)>));
            init.push(quote!(#ident: ::core::pin::Pin::new_unchecked(#ident)));
        } else {
            mut_fields.push(quote!(#vis #ident: &#lifetime mut (#ty)));
            ref_fields.push(quote!(#vis #ident: &#lifetime (#ty)));
            init.push(quote!(#ident));
        }
        pat.push(ident);
    }

    Ok((
        quote!({ #(#pat),* }),
        quote!({ #(#init),* }),
        quote!({ #(#mut_fields),* }),
        quote!({ #(#ref_fields),* }),
    ))
}
/// Visits tuple fields and returns `(pattern, initializer, mut_fields,
/// ref_fields)`, each parenthesized.
///
/// Fields are bound as `_0`, `_1`, … in the pattern. `#[pin]` fields are
/// wrapped in `Pin::new_unchecked` in the initializer and typed as
/// `Pin<&'lt (mut) T>` in the field lists; other fields become
/// `&'lt (mut) T`.
///
/// When `is_struct` is true, a `;` is appended to both field lists, as a
/// tuple struct definition requires; tuple variants get none. The type of
/// every `#[pin]` field is appended to `pinned_fields`.
fn visit_unnamed(
    &mut self,
    FieldsUnnamed { unnamed: fields, .. }: &FieldsUnnamed,
    is_struct: bool,
) -> Result<(TokenStream, TokenStream, TokenStream, TokenStream)> {
    let len = fields.len();
    let mut pat = Vec::with_capacity(len);
    let mut init = Vec::with_capacity(len);
    let mut mut_fields = Vec::with_capacity(len);
    let mut ref_fields = Vec::with_capacity(len);
    let lifetime = &self.proj.lifetime;

    for (index, Field { attrs, vis, ty, .. }) in fields.iter().enumerate() {
        let binding = format_ident!("_{}", index);
        if attrs.find_exact(PIN)?.is_some() {
            self.pinned_fields.push(ty.clone());
            mut_fields.push(quote!(#vis ::core::pin::Pin<&#lifetime mut (#ty)>));
            ref_fields.push(quote!(#vis ::core::pin::Pin<&#lifetime (#ty)>));
            init.push(quote!(::core::pin::Pin::new_unchecked(#binding)));
        } else {
            mut_fields.push(quote!(#vis &#lifetime mut (#ty)));
            ref_fields.push(quote!(#vis &#lifetime (#ty)));
            init.push(quote!(#binding));
        }
        pat.push(binding);
    }

    let pat = quote!((#(#pat),*));
    let init = quote!((#(#init),*));
    let mut_fields = quote!((#(#mut_fields),*));
    let ref_fields = quote!((#(#ref_fields),*));
    if is_struct {
        Ok((pat, init, quote!(#mut_fields;), quote!(#ref_fields;)))
    } else {
        Ok((pat, init, mut_fields, ref_fields))
    }
}
fn make_unpin_impl(&mut self) -> TokenStream {
if let Some(unsafe_unpin) = self.unsafe_unpin {
let mut proj_generics = self.proj.generics.clone();
let orig_ident = &self.orig.ident;
let lifetime = &self.proj.lifetime;
let private = Ident::new(CURRENT_PRIVATE_MODULE, Span::call_site());
proj_generics.make_where_clause().predicates.push(
syn::parse2(quote_spanned! { unsafe_unpin =>
::pin_project::#private::Wrapper<#lifetime, Self>: ::pin_project::UnsafeUnpin
})
.unwrap(),
);
let (impl_generics, _, where_clause) = proj_generics.split_for_impl();
let ty_generics = self.orig.generics.split_for_impl().1;
quote! {
#[allow(single_use_lifetimes)]
impl #impl_generics ::core::marker::Unpin for #orig_ident #ty_generics #where_clause {}
}
} else {
let mut full_where_clause = self.orig.generics.where_clause.as_ref().cloned().unwrap();
let orig_ident = &self.orig.ident;
let make_span = || {
#[cfg(pin_project_show_unpin_struct)]
{
proc_macro::Span::def_site().into()
}
#[cfg(not(pin_project_show_unpin_struct))]
{
Span::call_site()
}
};
let struct_ident = format_ident!("__{}", orig_ident, span = make_span());
let fields: Vec<_> = self
.pinned_fields
.iter()
.enumerate()
.map(|(i, ty)| {
let field_ident = format_ident!("__field{}", i);
quote! {
#field_ident: #ty
}
})
.collect();
let lifetime_fields: Vec<_> = self
.orig
.generics
.lifetimes()
.enumerate()
.map(|(i, LifetimeDef { lifetime, .. })| {
let field_ident = format_ident!("__lifetime{}", i);
quote! {
#field_ident: &#lifetime ()
}
})
.collect();
let scope_ident = format_ident!("__unpin_scope_{}", orig_ident);
let vis = &self.orig.vis;
let lifetime = &self.proj.lifetime;
let type_params: Vec<_> = self.orig.generics.type_params().map(|t| &t.ident).collect();
let proj_generics = &self.proj.generics;
let (impl_generics, proj_ty_generics, _) = proj_generics.split_for_impl();
let (_, ty_generics, where_clause) = self.orig.generics.split_for_impl();
full_where_clause.predicates.push(syn::parse_quote! {
#struct_ident #proj_ty_generics: ::core::marker::Unpin
});
let private = Ident::new(CURRENT_PRIVATE_MODULE, Span::call_site());
let inner_data = quote! {
#vis struct #struct_ident #proj_generics #where_clause {
__pin_project_use_generics: ::pin_project::#private::AlwaysUnpin<#lifetime, (#(#type_params),*)>,
#(#fields,)*
#(#lifetime_fields,)*
}
impl #impl_generics ::core::marker::Unpin for #orig_ident #ty_generics #full_where_clause {}
};
if cfg!(pin_project_show_unpin_struct) {
inner_data
} else {
quote! {
#[allow(non_snake_case)]
fn #scope_ident() {
#inner_data
}
}
}
}
}
/// Generates the drop-related items for the original type.
///
/// With `PinnedDrop`, emits a `Drop` impl that pins `self` and forwards to
/// the private `PinnedDrop::drop`. That call is spanned on the `PinnedDrop`
/// argument.
///
/// Without it, emits three things:
/// * a `<Name>MustNotImplDrop` trait with a blanket impl for every `Drop`
///   type;
/// * a second impl of that trait for the original type, which overlaps the
///   blanket impl (a coherence error) if the user also implements `Drop`;
/// * an empty private `PinnedDrop` impl.
fn make_drop_impl(&self) -> TokenStream {
    let ident = &self.orig.ident;
    let (impl_generics, ty_generics, where_clause) = self.orig.generics.split_for_impl();
    let private = Ident::new(CURRENT_PRIVATE_MODULE, Span::call_site());

    let pinned_drop = match self.pinned_drop {
        Some(span) => span,
        None => {
            let trait_ident = format_ident!("{}MustNotImplDrop", ident);
            return quote! {
                trait #trait_ident {}
                #[allow(clippy::drop_bounds)]
                impl<T: ::core::ops::Drop> #trait_ident for T {}
                #[allow(single_use_lifetimes)]
                impl #impl_generics #trait_ident for #ident #ty_generics #where_clause {}
                #[allow(single_use_lifetimes)]
                impl #impl_generics ::pin_project::#private::PinnedDrop for #ident #ty_generics #where_clause {
                    unsafe fn drop(self: ::core::pin::Pin<&mut Self>) {}
                }
            };
        }
    };

    let call_drop = quote_spanned! { pinned_drop =>
        ::pin_project::#private::PinnedDrop::drop(pinned_self)
    };
    quote! {
        #[allow(single_use_lifetimes)]
        impl #impl_generics ::core::ops::Drop for #ident #ty_generics #where_clause {
            fn drop(&mut self) {
                // `self` is never moved after being pinned here; it is only
                // handed to `PinnedDrop::drop`.
                let pinned_self = unsafe { ::core::pin::Pin::new_unchecked(self) };
                unsafe {
                    #call_drop;
                }
            }
        }
    }
}
fn make_proj_impl(&self, proj_body: &TokenStream, proj_ref_body: &TokenStream) -> TokenStream {
let vis = &self.proj.vis;
let lifetime = &self.proj.lifetime;
let orig_ident = &self.orig.ident;
let proj_ident = &self.proj.mut_ident;
let proj_ref_ident = &self.proj.ref_ident;
let proj_ty_generics = self.proj.generics.split_for_impl().1;
let (impl_generics, ty_generics, where_clause) = self.orig.generics.split_for_impl();
quote! {
impl #impl_generics #orig_ident #ty_generics #where_clause {
#vis fn project<#lifetime>(
self: ::core::pin::Pin<&#lifetime mut Self>,
) -> #proj_ident #proj_ty_generics {
unsafe {
#proj_body
}
}
#vis fn project_ref<#lifetime>(
self: ::core::pin::Pin<&#lifetime Self>,
) -> #proj_ref_ident #proj_ty_generics {
unsafe {
#proj_ref_body
}
}
}
}
}
/// Rejects `#[repr(packed)]` structs.
///
/// A `repr(packed)` or `repr(packed(N))` visible among the item's
/// attributes is reported immediately, spanned on the `packed` entry.
/// Attributes that fail `parse_meta` are skipped.
///
/// In addition, the returned tokens define a function that borrows every
/// field under `#[deny(safe_packed_borrows)]`. Borrowing a field of a packed
/// struct is then a compile error even if the packed layout was not caught
/// above.
fn ensure_not_packed(&self, fields: &Fields) -> Result<TokenStream> {
    for meta in self.orig.attrs.iter().filter_map(|attr| attr.parse_meta().ok()) {
        let list = match meta {
            Meta::List(list) => list,
            _ => continue,
        };
        if !list.path.is_ident("repr") {
            continue;
        }
        for repr in &list.nested {
            match repr {
                NestedMeta::Meta(Meta::Path(path))
                | NestedMeta::Meta(Meta::List(MetaList { path, .. }))
                    if path.is_ident("packed") =>
                {
                    return Err(error!(
                        repr,
                        "#[pin_project] attribute may not be used on #[repr(packed)] types"
                    ));
                }
                _ => {}
            }
        }
    }

    let field_refs: Vec<TokenStream> = match fields {
        Fields::Named(FieldsNamed { named, .. }) => {
            named.iter().map(|Field { ident, .. }| quote!(&val.#ident;)).collect()
        }
        Fields::Unnamed(FieldsUnnamed { unnamed, .. }) => (0..unnamed.len())
            .map(|position| {
                let index = Index::from(position);
                quote!(&val.#index;)
            })
            .collect(),
        Fields::Unit => Vec::new(),
    };

    let (impl_generics, ty_generics, where_clause) = self.orig.generics.split_for_impl();
    let struct_name = &self.orig.ident;
    let method_name = format_ident!("__pin_project_assert_not_repr_packed_{}", self.orig.ident);
    Ok(quote! {
        #[allow(single_use_lifetimes)]
        #[allow(non_snake_case)]
        #[deny(safe_packed_borrows)]
        fn #method_name #impl_generics (val: &#struct_name #ty_generics) #where_clause {
            #(#field_refs)*
        }
    })
}
}