Skip to content

Commit 2ea93ce

Browse files
committed
Add support for pyo3(get(attr1, attr2, attr3))
1 parent aeeaf68 commit 2ea93ce

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

pyo3-macros-backend/src/attributes.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use proc_macro2::TokenStream;
22
use quote::{quote, ToTokens};
3+
use syn::parenthesized;
34
use syn::parse::Parser;
45
use syn::{
56
ext::IdentExt,
@@ -339,6 +340,57 @@ impl ToTokens for NewImplTypeAttributeValue {
339340
}
340341
}
341342

343+
#[derive(Clone)]
344+
pub struct GetListAttribute {
345+
pub get_token: kw::get,
346+
pub paren_token: syn::token::Paren,
347+
pub fields: Punctuated<Ident, Token![,]>,
348+
}
349+
350+
impl syn::parse::Parse for GetListAttribute {
351+
fn parse(input: syn::parse::ParseStream<'_>) -> Result<Self> {
352+
// Parse the keyword: get
353+
let get_token: kw::get = input.parse()?;
354+
355+
// Parse the parentheses: ( ... )
356+
let content;
357+
let paren_token = parenthesized!(content in input);
358+
359+
// Parse identifiers inside: a, b, c
360+
let fields =
361+
content.parse_terminated(Ident::parse, Token![,])?;
362+
363+
// Reject empty list: get()
364+
if fields.is_empty() {
365+
return Err(syn::Error::new(
366+
paren_token.span.join(),
367+
"`get(...)` must contain at least one field name",
368+
));
369+
}
370+
371+
Ok(GetListAttribute {
372+
get_token,
373+
paren_token,
374+
fields,
375+
})
376+
}
377+
}
378+
379+
impl ToTokens for GetListAttribute {
380+
fn to_tokens(&self, tokens: &mut TokenStream) {
381+
// keyword `get`
382+
self.get_token.to_tokens(tokens);
383+
384+
// parentheses
385+
let paren_content = self.fields.iter().map(|f| quote::quote! { #f });
386+
let paren_tokens = quote::quote! { ( #( #paren_content ),* ) };
387+
388+
tokens.extend(quote::quote! {
389+
#paren_tokens
390+
});
391+
}
392+
}
393+
342394
pub type ExtendsAttribute = KeywordAttribute<kw::extends, Path>;
343395
pub type FreelistAttribute = KeywordAttribute<kw::freelist, Box<Expr>>;
344396
pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;

pyo3-macros-backend/src/pyclass.rs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use crate::attributes::kw::frozen;
1212
use crate::attributes::{
1313
self, kw, take_pyo3_options, CrateAttribute, ExtendsAttribute, FreelistAttribute,
1414
ModuleAttribute, NameAttribute, NameLitStr, NewImplTypeAttribute, NewImplTypeAttributeValue,
15-
RenameAllAttribute, StrFormatterAttribute,
15+
RenameAllAttribute, StrFormatterAttribute, GetListAttribute
1616
};
1717
use crate::combine_errors::CombineErrors;
1818
#[cfg(feature = "experimental-inspect")]
@@ -97,6 +97,7 @@ pub struct PyClassPyO3Options {
9797
pub generic: Option<kw::generic>,
9898
pub from_py_object: Option<kw::from_py_object>,
9999
pub skip_from_py_object: Option<kw::skip_from_py_object>,
100+
pub get: Option<GetListAttribute>
100101
}
101102

102103
pub enum PyClassPyO3Option {
@@ -125,6 +126,7 @@ pub enum PyClassPyO3Option {
125126
Generic(kw::generic),
126127
FromPyObject(kw::from_py_object),
127128
SkipFromPyObject(kw::skip_from_py_object),
129+
Get(GetListAttribute)
128130
}
129131

130132
impl Parse for PyClassPyO3Option {
@@ -180,6 +182,8 @@ impl Parse for PyClassPyO3Option {
180182
input.parse().map(PyClassPyO3Option::FromPyObject)
181183
} else if lookahead.peek(attributes::kw::skip_from_py_object) {
182184
input.parse().map(PyClassPyO3Option::SkipFromPyObject)
185+
} else if lookahead.peek(attributes::kw::get) {
186+
input.parse().map(PyClassPyO3Option::Get)
183187
} else {
184188
Err(lookahead.error())
185189
}
@@ -274,6 +278,7 @@ impl PyClassPyO3Options {
274278
);
275279
set_option!(from_py_object)
276280
}
281+
PyClassPyO3Option::Get(get) => set_option!(get)
277282
}
278283
Ok(())
279284
}
@@ -359,6 +364,24 @@ pub fn build_py_class(
359364
}
360365
}
361366

367+
if let Some(get_list_attr) = &args.options.get {
368+
// get_list_attr contains the list of desired field names (NameAttribute or Ident)
369+
for name in get_list_attr.fields.iter() {
370+
// find matching field in `field_options`:
371+
if let Some((_, field_opts)) =
372+
field_options.iter_mut().find(|(f, _)| match &f.ident {
373+
Some(ident) => ident == name,
374+
None => false,
375+
}) {
376+
if let Some(old_get) = field_opts.get.replace(Annotated::Struct(kw::get_all::default())) {
377+
return Err(syn::Error::new(old_get.span(), "duplicate get specified"));
378+
}
379+
} else {
380+
return Err(syn::Error::new_spanned(get_list_attr.clone(), format!("no field named `{}`", name)));
381+
}
382+
}
383+
}
384+
362385
impl_class(&class.ident, &args, doc, field_options, methods_type, ctx)
363386
}
364387

0 commit comments

Comments
 (0)