Skip to content

Improve macro errors #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 119 additions & 45 deletions uefi-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ extern crate proc_macro;

use proc_macro::TokenStream;

use proc_macro2::Span;
use quote::{quote, TokenStreamExt};
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, DeriveInput, Generics, Ident, ItemFn, ItemType, LitStr};
use proc_macro2::{TokenStream as TokenStream2, TokenTree};
use quote::{quote, ToTokens, TokenStreamExt};
use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
DeriveInput, Error, Generics, Ident, ItemFn, ItemType, LitStr, Visibility,
};

/// Parses a type definition, extracts its identifier and generic parameters
struct TypeDefinition {
Expand All @@ -33,51 +37,34 @@ impl Parse for TypeDefinition {
}
}

macro_rules! err {
($span:expr, $message:expr $(,)?) => {
Error::new($span.span(), $message).to_compile_error()
};
($span:expr, $message:expr, $($args:expr),*) => {
Error::new($span.span(), format!($message, $($args),*)).to_compile_error()
};
}

/// `unsafe_guid` attribute macro, implements the `Identify` trait for any type
/// (mostly works like a custom derive, but also supports type aliases)
#[proc_macro_attribute]
pub fn unsafe_guid(args: TokenStream, input: TokenStream) -> TokenStream {
// Parse the arguments and input using Syn
let guid_str = parse_macro_input!(args as LitStr).value();
let mut result: proc_macro2::TokenStream = input.clone().into();
let type_definition = parse_macro_input!(input as TypeDefinition);
let (time_low, time_mid, time_high_and_version, clock_seq_and_variant, node) =
match parse_guid(parse_macro_input!(args as LitStr)) {
Ok(data) => data,
Err(tokens) => return tokens.into(),
};

// We expect a canonical GUID string, such as "12345678-9abc-def0-fedc-ba9876543210"
if guid_str.len() != 36 {
panic!(
"\"{}\" is not a canonical GUID string (expected 36 bytes, found {})",
guid_str,
guid_str.len()
);
}
let mut guid_hex_iter = guid_str.split('-');
let mut next_guid_int = |expected_num_bits: usize| -> u64 {
let guid_hex_component = guid_hex_iter.next().unwrap();
if guid_hex_component.len() != expected_num_bits / 4 {
panic!(
"GUID component \"{}\" is not a {}-bit hexadecimal string",
guid_hex_component, expected_num_bits
);
}
match u64::from_str_radix(guid_hex_component, 16) {
Ok(number) => number,
_ => panic!(
"GUID component \"{}\" is not a hexadecimal number",
guid_hex_component
),
}
};
let mut result: TokenStream2 = input.clone().into();

// The GUID string is composed of a 32-bit integer, three 16-bit ones, and a 48-bit one
let time_low = next_guid_int(32) as u32;
let time_mid = next_guid_int(16) as u16;
let time_high_and_version = next_guid_int(16) as u16;
let clock_seq_and_variant = next_guid_int(16) as u16;
let node = next_guid_int(48);
let type_definition = parse_macro_input!(input as TypeDefinition);

// At this point, we know everything we need to implement Identify
let ident = type_definition.ident.clone();
let ident = &type_definition.ident;
let (impl_generics, ty_generics, where_clause) = type_definition.generics.split_for_impl();

result.append_all(quote! {
unsafe impl #impl_generics ::uefi::Identify for #ident #ty_generics #where_clause {
#[doc(hidden)]
Expand All @@ -94,6 +81,61 @@ pub fn unsafe_guid(args: TokenStream, input: TokenStream) -> TokenStream {
result.into()
}

fn parse_guid(guid_lit: LitStr) -> Result<(u32, u16, u16, u16, u64), TokenStream2> {
let guid_str = guid_lit.value();

// We expect a canonical GUID string, such as "12345678-9abc-def0-fedc-ba9876543210"
if guid_str.len() != 36 {
return Err(err!(
guid_lit,
"\"{}\" is not a canonical GUID string (expected 36 bytes, found {})",
guid_str,
guid_str.len()
));
}
let mut offset = 1; // 1 is for the starting quote
let mut guid_hex_iter = guid_str.split('-');
let mut next_guid_int = |len: usize| -> Result<u64, TokenStream2> {
let guid_hex_component = guid_hex_iter.next().unwrap();

// convert syn::LitStr to proc_macro2::Literal..
let lit = match guid_lit.to_token_stream().into_iter().next().unwrap() {
TokenTree::Literal(lit) => lit,
_ => unreachable!(),
};
// ..so that we can call subspan and nightly users (us) will get the fancy span
let span = lit
.subspan(offset..offset + guid_hex_component.len())
.unwrap_or_else(|| lit.span());

if guid_hex_component.len() != len * 2 {
return Err(err!(
span,
"GUID component \"{}\" is not a {}-bit hexadecimal string",
guid_hex_component,
len * 8
));
}
offset += guid_hex_component.len() + 1; // + 1 for the dash
u64::from_str_radix(guid_hex_component, 16).map_err(|_| {
err!(
span,
"GUID component \"{}\" is not a hexadecimal number",
guid_hex_component
)
})
};

// The GUID string is composed of a 32-bit integer, three 16-bit ones, and a 48-bit one
Ok((
next_guid_int(4)? as u32,
next_guid_int(2)? as u16,
next_guid_int(2)? as u16,
next_guid_int(2)? as u16,
next_guid_int(6)?,
))
}

/// Custom derive for the `Protocol` trait
#[proc_macro_derive(Protocol)]
pub fn derive_protocol(item: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -122,19 +164,51 @@ pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
// This code is inspired by the approach in this embedded Rust crate:
// https://github.com/rust-embedded/cortex-m-rt/blob/965bf1e3291571e7e3b34834864117dc020fb391/macros/src/lib.rs#L85

let mut errors = TokenStream2::new();

if !args.is_empty() {
panic!("This attribute accepts no arguments");
errors.append_all(err!(
TokenStream2::from(args),
"Entry attribute accepts no arguments"
));
}

let mut f = parse_macro_input!(input as ItemFn);

// force the exported symbol to be 'efi_main'
f.sig.ident = Ident::new("efi_main", Span::call_site());
if let Some(ref abi) = f.sig.abi {
errors.append_all(err!(abi, "Entry method must have no ABI modifier"));
}
if let Some(asyncness) = f.sig.asyncness {
errors.append_all(err!(asyncness, "Entry method should not be async"));
}
if let Some(constness) = f.sig.constness {
errors.append_all(err!(constness, "Entry method should not be const"));
}
if !f.sig.generics.params.is_empty() {
errors.append_all(err!(
&f.sig.generics.params,
"Entry method should not be generic"
));
}

// show most errors at once instead of one by one
if !errors.is_empty() {
return errors.into();
}

// allow the entry function to be unsafe (by moving the keyword around so that it actually works)
let unsafety = f.sig.unsafety.take();
// strip any visibility modifiers
f.vis = Visibility::Inherited;

let ident = &f.sig.ident;

let result = quote! {
static _UEFI_ENTRY_POINT_TYPE_CHECK: extern "efiapi" fn(uefi::Handle, uefi::table::SystemTable<uefi::table::Boot>) -> uefi::Status = efi_main;
#[no_mangle]
pub extern "efiapi" #f
#[export_name = "efi_main"]
#unsafety extern "efiapi" #f

// typecheck the function pointer
const _: #unsafety extern "efiapi" fn(::uefi::Handle, ::uefi::table::SystemTable<::uefi::table::Boot>) -> ::uefi::Status = #ident;
};
result.into()
}