Skip to content

Commit a15e4b1

Browse files
authored
Allow pymodule functions to take a single Bound<'_, PyModule> arg (#3905)
1 parent 6f03a54 commit a15e4b1

File tree

5 files changed

+45
-1
lines changed

5 files changed

+45
-1
lines changed

newsfragments/3905.changed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
The `#[pymodule]` macro now supports module functions that take a single argument as a `&Bound<'_, PyModule>`.

pyo3-macros-backend/src/module.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
201201
let doc = get_doc(&function.attrs, None);
202202

203203
let initialization = module_initialization(options, ident);
204+
205+
// Module function called with optional Python<'_> marker as first arg, followed by the module.
206+
let mut module_args = Vec::new();
207+
if function.sig.inputs.len() == 2 {
208+
module_args.push(quote!(module.py()));
209+
}
210+
module_args.push(quote!(::std::convert::Into::into(BoundRef(module))));
211+
204212
Ok(quote! {
205213
#function
206214
#vis mod #ident {
@@ -218,7 +226,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result<TokenStream>
218226
use #krate::impl_::pymethods::BoundRef;
219227

220228
fn __pyo3_pymodule(module: &#krate::Bound<'_, #krate::types::PyModule>) -> #krate::PyResult<()> {
221-
#ident(module.py(), ::std::convert::Into::into(BoundRef(module)))
229+
#ident(#(#module_args),*)
222230
}
223231

224232
impl #ident::MakeDef {

src/tests/hygiene/pyfunction.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,11 @@ fn invoke_wrap_pyfunction() {
1414
crate::py_run!(py, func, r#"func(5)"#);
1515
});
1616
}
17+
18+
#[test]
19+
fn invoke_wrap_pyfunction_bound() {
20+
crate::Python::with_gil(|py| {
21+
let func = crate::wrap_pyfunction_bound!(do_something, py).unwrap();
22+
crate::py_run!(py, func, r#"func(5)"#);
23+
});
24+
}

src/tests/hygiene/pymodule.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,18 @@ fn my_module(_py: crate::Python<'_>, m: &crate::types::PyModule) -> crate::PyRes
2121

2222
::std::result::Result::Ok(())
2323
}
24+
25+
#[crate::pymodule]
26+
#[pyo3(crate = "crate")]
27+
fn my_module_bound(m: &crate::Bound<'_, crate::types::PyModule>) -> crate::PyResult<()> {
28+
<crate::Bound<'_, crate::types::PyModule> as crate::types::PyModuleMethods>::add_function(
29+
m,
30+
crate::wrap_pyfunction_bound!(do_something, m)?,
31+
)?;
32+
<crate::Bound<'_, crate::types::PyModule> as crate::types::PyModuleMethods>::add_wrapped(
33+
m,
34+
crate::wrap_pymodule!(foo),
35+
)?;
36+
37+
::std::result::Result::Ok(())
38+
}

tests/test_no_imports.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ fn basic_module(_py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyRes
2222
Ok(())
2323
}
2424

25+
#[pyo3::pymodule]
26+
fn basic_module_bound(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> {
27+
#[pyfn(m)]
28+
fn answer() -> usize {
29+
42
30+
}
31+
32+
m.add_function(pyo3::wrap_pyfunction_bound!(basic_function, m)?)?;
33+
34+
Ok(())
35+
}
36+
2537
#[pyo3::pyclass]
2638
struct BasicClass {
2739
#[pyo3(get)]

0 commit comments

Comments
 (0)