diff --git a/uefi-macros/src/lib.rs b/uefi-macros/src/lib.rs index 371246f20..da7b589c4 100644 --- a/uefi-macros/src/lib.rs +++ b/uefi-macros/src/lib.rs @@ -4,10 +4,14 @@ extern crate proc_macro; use proc_macro::TokenStream; -use proc_macro2::Span; -use quote::{quote, TokenStreamExt}; -use syn::parse::{Parse, ParseStream}; -use syn::{parse_macro_input, DeriveInput, Generics, Ident, ItemFn, ItemType, LitStr}; +use proc_macro2::{TokenStream as TokenStream2, TokenTree}; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, + spanned::Spanned, + DeriveInput, Error, Generics, Ident, ItemFn, ItemType, LitStr, Visibility, +}; /// Parses a type definition, extracts its identifier and generic parameters struct TypeDefinition { @@ -33,51 +37,34 @@ impl Parse for TypeDefinition { } } +macro_rules! err { + ($span:expr, $message:expr $(,)?) => { + Error::new($span.span(), $message).to_compile_error() + }; + ($span:expr, $message:expr, $($args:expr),*) => { + Error::new($span.span(), format!($message, $($args),*)).to_compile_error() + }; +} + /// `unsafe_guid` attribute macro, implements the `Identify` trait for any type /// (mostly works like a custom derive, but also supports type aliases) #[proc_macro_attribute] pub fn unsafe_guid(args: TokenStream, input: TokenStream) -> TokenStream { // Parse the arguments and input using Syn - let guid_str = parse_macro_input!(args as LitStr).value(); - let mut result: proc_macro2::TokenStream = input.clone().into(); - let type_definition = parse_macro_input!(input as TypeDefinition); + let (time_low, time_mid, time_high_and_version, clock_seq_and_variant, node) = + match parse_guid(parse_macro_input!(args as LitStr)) { + Ok(data) => data, + Err(tokens) => return tokens.into(), + }; - // We expect a canonical GUID string, such as "12345678-9abc-def0-fedc-ba9876543210" - if guid_str.len() != 36 { - panic!( - "\"{}\" is not a canonical GUID string (expected 36 bytes, found {})", - guid_str, - guid_str.len() - ); - } - let mut guid_hex_iter = guid_str.split('-'); - let mut next_guid_int = |expected_num_bits: usize| -> u64 { - let guid_hex_component = guid_hex_iter.next().unwrap(); - if guid_hex_component.len() != expected_num_bits / 4 { - panic!( - "GUID component \"{}\" is not a {}-bit hexadecimal string", - guid_hex_component, expected_num_bits - ); - } - match u64::from_str_radix(guid_hex_component, 16) { - Ok(number) => number, - _ => panic!( - "GUID component \"{}\" is not a hexadecimal number", - guid_hex_component - ), - } - }; + let mut result: TokenStream2 = input.clone().into(); - // The GUID string is composed of a 32-bit integer, three 16-bit ones, and a 48-bit one - let time_low = next_guid_int(32) as u32; - let time_mid = next_guid_int(16) as u16; - let time_high_and_version = next_guid_int(16) as u16; - let clock_seq_and_variant = next_guid_int(16) as u16; - let node = next_guid_int(48); + let type_definition = parse_macro_input!(input as TypeDefinition); // At this point, we know everything we need to implement Identify - let ident = type_definition.ident.clone(); + let ident = &type_definition.ident; let (impl_generics, ty_generics, where_clause) = type_definition.generics.split_for_impl(); + result.append_all(quote! { unsafe impl #impl_generics ::uefi::Identify for #ident #ty_generics #where_clause { #[doc(hidden)] @@ -94,6 +81,61 @@ pub fn unsafe_guid(args: TokenStream, input: TokenStream) -> TokenStream { result.into() } +fn parse_guid(guid_lit: LitStr) -> Result<(u32, u16, u16, u16, u64), TokenStream2> { + let guid_str = guid_lit.value(); + + // We expect a canonical GUID string, such as "12345678-9abc-def0-fedc-ba9876543210" + if guid_str.len() != 36 { + return Err(err!( + guid_lit, + "\"{}\" is not a canonical GUID string (expected 36 bytes, found {})", + guid_str, + guid_str.len() + )); + } + let mut offset = 1; // 1 is for the starting quote + let mut guid_hex_iter = guid_str.split('-'); + let mut next_guid_int = |len: usize| -> Result { + let guid_hex_component = guid_hex_iter.next().unwrap(); + + // convert syn::LitStr to proc_macro2::Literal.. + let lit = match guid_lit.to_token_stream().into_iter().next().unwrap() { + TokenTree::Literal(lit) => lit, + _ => unreachable!(), + }; + // ..so that we can call subspan and nightly users (us) will get the fancy span + let span = lit + .subspan(offset..offset + guid_hex_component.len()) + .unwrap_or_else(|| lit.span()); + + if guid_hex_component.len() != len * 2 { + return Err(err!( + span, + "GUID component \"{}\" is not a {}-bit hexadecimal string", + guid_hex_component, + len * 8 + )); + } + offset += guid_hex_component.len() + 1; // + 1 for the dash + u64::from_str_radix(guid_hex_component, 16).map_err(|_| { + err!( + span, + "GUID component \"{}\" is not a hexadecimal number", + guid_hex_component + ) + }) + }; + + // The GUID string is composed of a 32-bit integer, three 16-bit ones, and a 48-bit one + Ok(( + next_guid_int(4)? as u32, + next_guid_int(2)? as u16, + next_guid_int(2)? as u16, + next_guid_int(2)? as u16, + next_guid_int(6)?, + )) +} + /// Custom derive for the `Protocol` trait #[proc_macro_derive(Protocol)] pub fn derive_protocol(item: TokenStream) -> TokenStream { @@ -122,19 +164,51 @@ pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream { // This code is inspired by the approach in this embedded Rust crate: // https://github.com/rust-embedded/cortex-m-rt/blob/965bf1e3291571e7e3b34834864117dc020fb391/macros/src/lib.rs#L85 + let mut errors = TokenStream2::new(); + if !args.is_empty() { - panic!("This attribute accepts no arguments"); + errors.append_all(err!( + TokenStream2::from(args), + "Entry attribute accepts no arguments" + )); } let mut f = parse_macro_input!(input as ItemFn); - // force the exported symbol to be 'efi_main' - f.sig.ident = Ident::new("efi_main", Span::call_site()); + if let Some(ref abi) = f.sig.abi { + errors.append_all(err!(abi, "Entry method must have no ABI modifier")); + } + if let Some(asyncness) = f.sig.asyncness { + errors.append_all(err!(asyncness, "Entry method should not be async")); + } + if let Some(constness) = f.sig.constness { + errors.append_all(err!(constness, "Entry method should not be const")); + } + if !f.sig.generics.params.is_empty() { + errors.append_all(err!( + &f.sig.generics.params, + "Entry method should not be generic" + )); + } + + // show most errors at once instead of one by one + if !errors.is_empty() { + return errors.into(); + } + + // allow the entry function to be unsafe (by moving the keyword around so that it actually works) + let unsafety = f.sig.unsafety.take(); + // strip any visibility modifiers + f.vis = Visibility::Inherited; + + let ident = &f.sig.ident; let result = quote! { - static _UEFI_ENTRY_POINT_TYPE_CHECK: extern "efiapi" fn(uefi::Handle, uefi::table::SystemTable) -> uefi::Status = efi_main; - #[no_mangle] - pub extern "efiapi" #f + #[export_name = "efi_main"] + #unsafety extern "efiapi" #f + + // typecheck the function pointer + const _: #unsafety extern "efiapi" fn(::uefi::Handle, ::uefi::table::SystemTable<::uefi::table::Boot>) -> ::uefi::Status = #ident; }; result.into() }