Skip to content

Commit 7baface

Browse files
committed
Add support for pyo3(get(attr1, attr2, attr3))
1 parent 18aa0c2 commit 7baface

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

pyo3-macros-backend/src/attributes.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use proc_macro2::TokenStream;
22
use quote::{quote, ToTokens};
3+
use syn::parenthesized;
34
use syn::parse::Parser;
45
use syn::{
56
ext::IdentExt,
@@ -311,6 +312,57 @@ impl ToTokens for TextSignatureAttributeValue {
311312
}
312313
}
313314

315+
#[derive(Clone)]
316+
pub struct GetListAttribute {
317+
pub get_token: kw::get,
318+
pub paren_token: syn::token::Paren,
319+
pub fields: Punctuated<Ident, Token![,]>,
320+
}
321+
322+
impl syn::parse::Parse for GetListAttribute {
323+
fn parse(input: syn::parse::ParseStream<'_>) -> Result<Self> {
324+
// Parse the keyword: get
325+
let get_token: kw::get = input.parse()?;
326+
327+
// Parse the parentheses: ( ... )
328+
let content;
329+
let paren_token = parenthesized!(content in input);
330+
331+
// Parse identifiers inside: a, b, c
332+
let fields =
333+
content.parse_terminated(Ident::parse, Token![,])?;
334+
335+
// Reject empty list: get()
336+
if fields.is_empty() {
337+
return Err(syn::Error::new(
338+
paren_token.span.join(),
339+
"`get(...)` must contain at least one field name",
340+
));
341+
}
342+
343+
Ok(GetListAttribute {
344+
get_token,
345+
paren_token,
346+
fields,
347+
})
348+
}
349+
}
350+
351+
impl ToTokens for GetListAttribute {
352+
fn to_tokens(&self, tokens: &mut TokenStream) {
353+
// keyword `get`
354+
self.get_token.to_tokens(tokens);
355+
356+
// parentheses
357+
let paren_content = self.fields.iter().map(|f| quote::quote! { #f });
358+
let paren_tokens = quote::quote! { ( #( #paren_content ),* ) };
359+
360+
tokens.extend(quote::quote! {
361+
#paren_tokens
362+
});
363+
}
364+
}
365+
314366
pub type ExtendsAttribute = KeywordAttribute<kw::extends, Path>;
315367
pub type FreelistAttribute = KeywordAttribute<kw::freelist, Box<Expr>>;
316368
pub type ModuleAttribute = KeywordAttribute<kw::module, LitStr>;

pyo3-macros-backend/src/pyclass.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::attributes::kw::frozen;
1212
use crate::attributes::{
1313
self, kw, take_pyo3_options, CrateAttribute, ExtendsAttribute, FreelistAttribute,
1414
ModuleAttribute, NameAttribute, NameLitStr, RenameAllAttribute, StrFormatterAttribute,
15+
GetListAttribute
1516
};
1617
use crate::combine_errors::CombineErrors;
1718
#[cfg(feature = "experimental-inspect")]
@@ -95,6 +96,7 @@ pub struct PyClassPyO3Options {
9596
pub generic: Option<kw::generic>,
9697
pub from_py_object: Option<kw::from_py_object>,
9798
pub skip_from_py_object: Option<kw::skip_from_py_object>,
99+
pub get: Option<GetListAttribute>
98100
}
99101

100102
pub enum PyClassPyO3Option {
@@ -122,6 +124,7 @@ pub enum PyClassPyO3Option {
122124
Generic(kw::generic),
123125
FromPyObject(kw::from_py_object),
124126
SkipFromPyObject(kw::skip_from_py_object),
127+
Get(GetListAttribute)
125128
}
126129

127130
impl Parse for PyClassPyO3Option {
@@ -175,6 +178,8 @@ impl Parse for PyClassPyO3Option {
175178
input.parse().map(PyClassPyO3Option::FromPyObject)
176179
} else if lookahead.peek(attributes::kw::skip_from_py_object) {
177180
input.parse().map(PyClassPyO3Option::SkipFromPyObject)
181+
} else if lookahead.peek(attributes::kw::get) {
182+
input.parse().map(PyClassPyO3Option::Get)
178183
} else {
179184
Err(lookahead.error())
180185
}
@@ -268,6 +273,7 @@ impl PyClassPyO3Options {
268273
);
269274
set_option!(from_py_object)
270275
}
276+
PyClassPyO3Option::Get(get) => set_option!(get)
271277
}
272278
Ok(())
273279
}
@@ -353,6 +359,25 @@ pub fn build_py_class(
353359
}
354360
}
355361

362+
if let Some(get_list_attr) = &args.options.get {
363+
// get_list_attr contains the list of desired field names (NameAttribute or Ident)
364+
for name in get_list_attr.fields.iter() {
365+
// find matching field in `field_options`:
366+
if let Some((_, field_opts)) =
367+
field_options.iter_mut().find(|(f, _)| match &f.ident {
368+
Some(ident) => ident == name,
369+
None => false,
370+
}) {
371+
if let Some(old_get) = field_opts.get.replace(Annotated::Struct(kw::get_all::default())) {
372+
return Err(syn::Error::new(old_get.span(), "duplicate get specified"));
373+
}
374+
} else {
375+
return Err(syn::Error::new_spanned(get_list_attr.clone(), format!("no field named `{}`", name)));
376+
}
377+
}
378+
}
379+
380+
356381
impl_class(&class.ident, &args, doc, field_options, methods_type, ctx)
357382
}
358383

0 commit comments

Comments
 (0)