Skip to content

Commit 9f66846

Browse files
authored
Merge pull request #3595 from davidhewitt/ok-wrap
refactor `OkWrap` to not call `.into_py(py)`
2 parents cbd0630 + c814078 commit 9f66846

File tree

7 files changed

+69
-40
lines changed

7 files changed

+69
-40
lines changed

pyo3-macros-backend/src/method.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -455,11 +455,13 @@ impl<'a> FnSpec<'a> {
455455
let func_name = &self.name;
456456

457457
let rust_call = |args: Vec<TokenStream>| {
458-
let mut call = quote! { function(#self_arg #(#args),*) };
459-
if self.asyncness.is_some() {
460-
call = quote! { _pyo3::impl_::coroutine::wrap_future(#call) };
461-
}
462-
quotes::map_result_into_ptr(quotes::ok_wrap(call))
458+
let call = quote! { function(#self_arg #(#args),*) };
459+
let wrapped_call = if self.asyncness.is_some() {
460+
quote! { _pyo3::PyResult::Ok(_pyo3::impl_::wrap::wrap_future(#call)) }
461+
} else {
462+
quotes::ok_wrap(call)
463+
};
464+
quotes::map_result_into_ptr(wrapped_call)
463465
};
464466

465467
let rust_name = if let Some(cls) = cls {

pyo3-macros-backend/src/pymethod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ fn impl_py_class_attribute(cls: &syn::Type, spec: &FnSpec<'_>) -> syn::Result<Me
458458
let associated_method = quote! {
459459
fn #wrapper_ident(py: _pyo3::Python<'_>) -> _pyo3::PyResult<_pyo3::PyObject> {
460460
let function = #cls::#name; // Shadow the method name to avoid #3017
461-
#body
461+
_pyo3::impl_::wrap::map_result_into_py(py, #body)
462462
}
463463
};
464464

pyo3-macros-backend/src/quotes.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ pub(crate) fn some_wrap(obj: TokenStream) -> TokenStream {
99

1010
pub(crate) fn ok_wrap(obj: TokenStream) -> TokenStream {
1111
quote! {
12-
_pyo3::impl_::wrap::OkWrap::wrap(#obj, py)
12+
_pyo3::impl_::wrap::OkWrap::wrap(#obj)
1313
.map_err(::core::convert::Into::<_pyo3::PyErr>::into)
1414
}
1515
}
1616

1717
pub(crate) fn map_result_into_ptr(result: TokenStream) -> TokenStream {
1818
quote! {
19-
#result.map(_pyo3::PyObject::into_ptr)
19+
_pyo3::impl_::wrap::map_result_into_ptr(py, #result)
2020
}
2121
}

src/impl_.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
//! APIs may may change at any time without documentation in the CHANGELOG and without
77
//! breaking semver guarantees.
88
9-
#[cfg(feature = "macros")]
10-
pub mod coroutine;
119
pub mod deprecations;
1210
pub mod extract_argument;
1311
pub mod freelist;

src/impl_/coroutine.rs

Lines changed: 0 additions & 19 deletions
This file was deleted.

src/impl_/wrap.rs

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
use crate::{IntoPy, Py, PyAny, PyErr, PyObject, PyResult, Python};
1+
use std::convert::Infallible;
2+
3+
use crate::{ffi, IntoPy, PyObject, PyResult, Python};
24

35
/// Used to wrap values in `Option<T>` for default arguments.
46
pub trait SomeWrap<T> {
5-
fn wrap(self) -> T;
7+
fn wrap(self) -> Option<T>;
68
}
79

8-
impl<T> SomeWrap<Option<T>> for T {
10+
impl<T> SomeWrap<T> for T {
911
fn wrap(self) -> Option<T> {
1012
Some(self)
1113
}
1214
}
1315

14-
impl<T> SomeWrap<Option<T>> for Option<T> {
16+
impl<T> SomeWrap<T> for Option<T> {
1517
fn wrap(self) -> Self {
1618
self
1719
}
@@ -20,7 +22,7 @@ impl<T> SomeWrap<Option<T>> for Option<T> {
2022
/// Used to wrap the result of `#[pyfunction]` and `#[pymethods]`.
2123
pub trait OkWrap<T> {
2224
type Error;
23-
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error>;
25+
fn wrap(self) -> Result<T, Self::Error>;
2426
}
2527

2628
// The T: IntoPy<PyObject> bound here is necessary to prevent the
@@ -29,9 +31,10 @@ impl<T> OkWrap<T> for T
2931
where
3032
T: IntoPy<PyObject>,
3133
{
32-
type Error = PyErr;
33-
fn wrap(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
34-
Ok(self.into_py(py))
34+
type Error = Infallible;
35+
#[inline]
36+
fn wrap(self) -> Result<T, Infallible> {
37+
Ok(self)
3538
}
3639
}
3740

@@ -40,11 +43,44 @@ where
4043
T: IntoPy<PyObject>,
4144
{
4245
type Error = E;
43-
fn wrap(self, py: Python<'_>) -> Result<Py<PyAny>, Self::Error> {
44-
self.map(|o| o.into_py(py))
46+
#[inline]
47+
fn wrap(self) -> Result<T, Self::Error> {
48+
self
4549
}
4650
}
4751

52+
/// This is a follow-up function to `OkWrap::wrap` that converts the result into
53+
/// a `*mut ffi::PyObject` pointer.
54+
pub fn map_result_into_ptr<T: IntoPy<PyObject>>(
55+
py: Python<'_>,
56+
result: PyResult<T>,
57+
) -> PyResult<*mut ffi::PyObject> {
58+
result.map(|obj| obj.into_py(py).into_ptr())
59+
}
60+
61+
/// This is a follow-up function to `OkWrap::wrap` that converts the result into
62+
/// a safe wrapper.
63+
pub fn map_result_into_py<T: IntoPy<PyObject>>(
64+
py: Python<'_>,
65+
result: PyResult<T>,
66+
) -> PyResult<PyObject> {
67+
result.map(|err| err.into_py(py))
68+
}
69+
70+
/// Used to wrap the result of async `#[pyfunction]` and `#[pymethods]`.
71+
#[cfg(feature = "macros")]
72+
pub fn wrap_future<F, R, T>(future: F) -> crate::coroutine::Coroutine
73+
where
74+
F: std::future::Future<Output = R> + Send + 'static,
75+
R: OkWrap<T>,
76+
T: IntoPy<PyObject>,
77+
crate::PyErr: From<R::Error>,
78+
{
79+
crate::coroutine::Coroutine::from_future::<_, T, crate::PyErr>(async move {
80+
OkWrap::wrap(future.await).map_err(Into::into)
81+
})
82+
}
83+
4884
#[cfg(test)]
4985
mod tests {
5086
use super::*;
@@ -57,4 +93,16 @@ mod tests {
5793
let b: Option<u8> = SomeWrap::wrap(None);
5894
assert_eq!(b, None);
5995
}
96+
97+
#[test]
98+
fn wrap_result() {
99+
let a: Result<u8, _> = OkWrap::wrap(42u8);
100+
assert!(matches!(a, Ok(42)));
101+
102+
let b: PyResult<u8> = OkWrap::wrap(Ok(42u8));
103+
assert!(matches!(b, Ok(42)));
104+
105+
let c: Result<u8, &str> = OkWrap::wrap(Err("error"));
106+
assert_eq!(c, Err("error"));
107+
}
60108
}

tests/test_compile_error.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ fn test_compile_errors() {
3333
t.compile_fail("tests/ui/invalid_pymethod_receiver.rs");
3434
t.compile_fail("tests/ui/missing_intopy.rs");
3535
// adding extra error conversion impls changes the output
36-
#[cfg(all(target_os = "linux", not(any(feature = "eyre", feature = "anyhow"))))]
36+
#[cfg(not(any(windows, feature = "eyre", feature = "anyhow")))]
3737
t.compile_fail("tests/ui/invalid_result_conversion.rs");
3838
t.compile_fail("tests/ui/not_send.rs");
3939
t.compile_fail("tests/ui/not_send2.rs");

0 commit comments

Comments
 (0)