Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 128 additions & 106 deletions core/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::{
DeriveInput, FnArg, ImplItem, ImplItemFn, ItemEnum, ItemTrait, LitStr, Meta, Pat, TraitItem,
Visibility, parse_macro_input, parse_quote,
Visibility, parse_macro_input, parse_quote, Error, Result,
};

/// Define an enum whose variants each implement a trait.
Expand Down Expand Up @@ -40,30 +41,37 @@ use syn::{
/// impl MyTrait for Object {}
/// ```
#[proc_macro_attribute]
#[cfg(feature = "enum-trait-object")]
pub fn enum_trait_object(args: TokenStream, item: TokenStream) -> TokenStream {
// Parse the input.
let mut input_trait = parse_macro_input!(item as ItemTrait);
let enum_input = parse_macro_input!(args as ItemEnum);

// Using a result here makes it much easier to report errors via syn::Error.
match expand_enum_trait_object(&mut input_trait, enum_input) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}

fn expand_enum_trait_object(input_trait: &mut ItemTrait, enum_input: ItemEnum) -> Result<TokenStream2> {
let trait_name = &input_trait.ident;
let trait_generics = &input_trait.generics;
let enum_input = parse_macro_input!(args as ItemEnum);
let enum_name = &enum_input.ident;

// TODO: Revise whether the first two asserts are needed at all, and whether
// the second condition should be `== 0` instead, based on the error message.
assert!(
trait_generics.lifetimes().count() <= 1,
"Only one lifetime parameter is currently supported"
);
if trait_generics.lifetimes().count() > 1 {
return Err(Error::new_spanned(trait_generics, "Only one lifetime parameter is currently supported"));
}

assert!(
trait_generics.type_params().count() <= 1,
"Generic type parameters are currently unsupported"
);
if trait_generics.type_params().count() > 1 {
return Err(Error::new_spanned(trait_generics, "Generic type parameters are currently unsupported"));
}

assert_eq!(
trait_generics, &enum_input.generics,
"Trait and enum should have the same generic parameters"
);
if trait_generics != &enum_input.generics {
return Err(Error::new_spanned(&enum_input.generics, "Trait and enum should have the same generic parameters"));
}

/// An hacky way to prevent accidental method overriding.
///
Expand All @@ -82,10 +90,10 @@ pub fn enum_trait_object(args: TokenStream, item: TokenStream) -> TokenStream {
impl NoOverrideModule {
fn make(trait_name: &syn::Ident) -> Self {
let mod_name = syn::Ident::new(
&format!("__{trait_name}_do_not_override"),
Span::call_site(),
&format!("__ruffle_{trait_name}_do_not_override"),
Span::mixed_site(),
);
let lt = syn::Lifetime::new("'no_dyn", Span::call_site());
let lt = syn::Lifetime::new("'no_dyn", Span::mixed_site());
let contents = quote! {
#[automatically_derived]
#[doc(hidden)]
Expand Down Expand Up @@ -115,69 +123,44 @@ pub fn enum_trait_object(args: TokenStream, item: TokenStream) -> TokenStream {
}

let mut no_override: Option<NoOverrideModule> = None;

// We check if the trait has a lifetime so we can correctly specify the enum type in the delegation.
let has_lifetime = trait_generics.lifetimes().next().is_some();
let enum_ty = if has_lifetime { quote!(#enum_name<'_>) } else { quote!(#enum_name) };

// Implement each trait. This will match against each enum variant and delegate
// to the underlying type.
let trait_methods: Vec<_> = input_trait
.items
.iter_mut()
.map(|item| match item {
let mut trait_methods = Vec::new();
for item in &mut input_trait.items {
match item {
TraitItem::Fn(method) => {
let mut is_no_dynamic = false;

method.attrs.retain(|attr| match &attr.meta {
Meta::Path(path) => {
if path.is_ident("no_dynamic") {
is_no_dynamic = true;

// Remove the #[no_dynamic] attribute from the
// list of method attributes.
false
} else {
true
}
}
_ => true,
});

let params: Vec<_> = method
.sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(arg) = arg && let Pat::Ident(i) = &*arg.pat {
return Some(i.ident.clone());
}
None
})
.collect();
let (is_no_dynamic, params) = parse_trait_method_meta(method);

let method_block = if is_no_dynamic {
no_override
.get_or_insert_with(|| NoOverrideModule::make(trait_name))
.adjust_method(method);

let method_name = &method.sig.ident;
let deref = if let Some(syn::Receiver {
colon_token: None,
reference,
..
}) = method.sig.receiver()
{
reference.is_some().then(|| quote!(*))
let deref = if let Some(receiver) = method.sig.receiver() {
if receiver.colon_token.is_none() && receiver.reference.is_some() {
quote!(*)
} else {
quote!()
}
} else {
panic!("#[no_dynamic] method `{method_name}` must take `self`, `&self`, or `&mut self`")
return Err(Error::new_spanned(&method.sig.ident, format!("#[no_dynamic] method `{method_name}` must take `self`, `&self`, or `&mut self`")));
};

// Moves the provided default body to the enum's generated trait impl,
// and replace it by an impl that delegates to the enum.
method
.default
.replace(parse_quote!({
let mut o: #enum_name<'_> = (#deref self).into();
o.#method_name(#(#params),*)
}))
.expect("#[no_dynamic] method `{method_name}` must have a default body")
match method.default.replace(parse_quote!({
let o: #enum_ty = (#deref self).into();
o.#method_name(#(#params),*)
})) {
Some(body) => body,
None => return Err(Error::new_spanned(&method.sig.ident, format!("#[no_dynamic] method `{method_name}` must have a default body"))),
}
} else {
let method_name = &method.sig.ident;
let match_arms: Vec<_> = enum_input
Expand All @@ -198,46 +181,24 @@ pub fn enum_trait_object(args: TokenStream, item: TokenStream) -> TokenStream {
})
};

ImplItem::Fn(ImplItemFn {
trait_methods.push(ImplItem::Fn(ImplItemFn {
attrs: method.attrs.clone(),
vis: Visibility::Inherited,
defaultness: None,
sig: method.sig.clone(),
block: method_block,
})
}));
}
_ => panic!("Unsupported trait item: {item:?}"),
})
.collect();
_ => return Err(Error::new_spanned(item, format!("Unsupported trait item: {item:?}"))),
}
}

let (impl_generics, ty_generics, where_clause) = trait_generics.split_for_impl();
let from_impls = generate_from_impls(enum_name, &enum_input, trait_generics);
let no_override_tokens = no_override.map(|s| s.contents).into_iter();

// Implement `From` for each variant type.
let from_impls: Vec<_> = enum_input
.variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;
let variant_type = &variant
.fields
.iter()
.next()
.expect("Missing field for enum variant")
.ty;

quote!(
impl #impl_generics From<#variant_type> for #enum_name #ty_generics {
fn from(obj: #variant_type) -> #enum_name #trait_generics {
#enum_name::#variant_name(obj)
}
}
)
})
.collect();

let no_override = no_override.map(|s| s.contents).into_iter();
let out = quote!(
#(#no_override)*
Ok(quote!(
#(#no_override_tokens)*

#input_trait

Expand All @@ -248,12 +209,64 @@ pub fn enum_trait_object(args: TokenStream, item: TokenStream) -> TokenStream {
}

#(#from_impls)*
);
))
}

fn parse_trait_method_meta(method: &mut syn::TraitItemFn) -> (bool, Vec<syn::Ident>) {
let mut is_no_dynamic = false;

method.attrs.retain(|attr| {
if let Meta::Path(path) = &attr.meta {
if path.is_ident("no_dynamic") {
is_no_dynamic = true;
// Remove the #[no_dynamic] attribute from the list of method attributes.
return false;
}
}
true
});

out.into()
let params: Vec<_> = method
.sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(arg) = arg {
if let Pat::Ident(i) = &*arg.pat {
return Some(i.ident.clone());
}
}
None
})
.collect();

(is_no_dynamic, params)
}

fn generate_from_impls(enum_name: &syn::Ident, enum_input: &ItemEnum, trait_generics: &syn::Generics) -> Vec<TokenStream2> {
let (impl_generics, ty_generics, _) = trait_generics.split_for_impl();
enum_input
.variants
.iter()
.filter_map(|variant| {
let variant_name = &variant.ident;
let field = variant.fields.iter().next()?;
let variant_type = &field.ty;

Some(quote!(
#[automatically_derived]
impl #impl_generics From<#variant_type> for #enum_name #ty_generics {
fn from(obj: #variant_type) -> #enum_name #ty_generics {
#enum_name::#variant_name(obj)
}
}
))
})
.collect()
}

#[proc_macro_derive(HasPrefixField)]
#[cfg(feature = "prefix-field")]
pub fn derive_has_prefix_field(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);

Expand All @@ -268,19 +281,21 @@ pub fn derive_has_prefix_field(input: TokenStream) -> TokenStream {
}
}

let Some(first_field) = ({
if let syn::Data::Struct(data) = &input.data {
let first_field = match &input.data {
syn::Data::Struct(data) => {
data.fields
.iter()
.next()
.filter(|f| is_repr_c && f.ident.is_some())
} else {
None
}
}) else {
panic!(
_ => None,
};

let Some(first_field) = first_field else {
return Error::new_spanned(
&input.ident,
"`HasPrefixField` can only be derived for repr(C) structs with at least one named field"
);
).to_compile_error().into();
};

let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
Expand Down Expand Up @@ -333,12 +348,14 @@ pub fn derive_has_prefix_field(input: TokenStream) -> TokenStream {
/// activation.context.strings.common().ascii_chars[65 /* 'A' */];
/// ```
#[proc_macro]
#[cfg(feature = "atoms")]
pub fn atom(item: TokenStream) -> TokenStream {
atom_internal(item, |atom| atom)
}

/// Like `atom!`, but returns an `AvmString` instead of an `AvmAtom`.
#[proc_macro]
#[cfg(feature = "atoms")]
pub fn istr(item: TokenStream) -> TokenStream {
atom_internal(item, |atom| {
quote!(
Expand Down Expand Up @@ -370,10 +387,15 @@ fn atom_internal(
}

let input = parse_macro_input!(item as Input);

let string = input.str.value();

// We verify that the string is actually safe for use in an identifier to prevent broken output.
if !string.chars().all(|c| c.is_alphanumeric() || c == '_') && string.len() != 1 {
return Error::new_spanned(&input.str, "Atom string contains characters that are invalid for identifiers").to_compile_error().into();
}

let (string_ident, array_index) = if string.len() == 1 && string.is_ascii() {
// Special case: a single ASCII char.
// Special case: a single ASCII char has a fast-path lookup in AVM.
let c = string.as_bytes()[0];
(format_ident!("ascii_chars"), Some(c as usize))
} else {
Expand Down
Loading