Skip to content

Commit ef0fa17

Browse files
committed
avoid creating PyRef inside __traverse__ handler (#4479)
1 parent eb59c54 commit ef0fa17

File tree

6 files changed

+120
-43
lines changed

6 files changed

+120
-43
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ serde = { version = "1.0", features = ["derive"] }
6161
serde_json = "1.0.61"
6262
rayon = "1.6.1"
6363
futures = "0.3.28"
64+
static_assertions = "1.1.0"
6465

6566
[build-dependencies]
6667
pyo3-build-config = { path = "pyo3-build-config", version = "=0.22.2", features = ["resolve-config"] }

newsfragments/4479.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Remove illegal reference counting op inside implementation of `__traverse__` handlers.

src/impl_/pymethods.rs

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@ use crate::callback::IntoPyCallbackOutput;
22
use crate::exceptions::PyStopAsyncIteration;
33
use crate::gil::LockGIL;
44
use crate::impl_::panic::PanicTrap;
5+
use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout};
6+
use crate::pycell::impl_::PyClassBorrowChecker as _;
57
use crate::pycell::{PyBorrowError, PyBorrowMutError};
68
use crate::pyclass::boolean_struct::False;
79
use crate::types::any::PyAnyMethods;
810
#[cfg(feature = "gil-refs")]
911
use crate::types::{PyModule, PyType};
1012
use crate::{
11-
ffi, Borrowed, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject,
12-
PyRef, PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
13+
ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef,
14+
PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
1315
};
1416
use std::ffi::CStr;
1517
use std::fmt;
18+
use std::marker::PhantomData;
1619
use std::os::raw::{c_int, c_void};
1720
use std::panic::{catch_unwind, AssertUnwindSafe};
1821
use std::ptr::null_mut;
@@ -234,6 +237,40 @@ impl PySetterDef {
234237
}
235238

236239
/// Calls an implementation of __traverse__ for tp_traverse
240+
///
241+
/// NB cannot accept `'static` visitor, this is a sanity check below:
242+
///
243+
/// ```rust,compile_fail
244+
/// use pyo3::prelude::*;
245+
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
246+
///
247+
/// #[pyclass]
248+
/// struct Foo;
249+
///
250+
/// #[pymethods]
251+
/// impl Foo {
252+
/// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> {
253+
/// Ok(())
254+
/// }
255+
/// }
256+
/// ```
257+
///
258+
/// Elided lifetime should compile ok:
259+
///
260+
/// ```rust
261+
/// use pyo3::prelude::*;
262+
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
263+
///
264+
/// #[pyclass]
265+
/// struct Foo;
266+
///
267+
/// #[pymethods]
268+
/// impl Foo {
269+
/// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
270+
/// Ok(())
271+
/// }
272+
/// }
273+
/// ```
237274
#[doc(hidden)]
238275
pub unsafe fn _call_traverse<T>(
239276
slf: *mut ffi::PyObject,
@@ -252,25 +289,43 @@ where
252289
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
253290
// token does not produce any owned objects thereby calling into `register_owned`.
254291
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
292+
let lock = LockGIL::during_traverse();
293+
294+
// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
295+
// traversal is running so no mutations can occur.
296+
let class_object: &PyClassObject<T> = &*slf.cast();
297+
298+
let retval =
299+
// `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so
300+
// do not traverse them if not on their owning thread :(
301+
if class_object.check_threadsafe().is_ok()
302+
// ... and we cannot traverse a type which might be being mutated by a Rust thread
303+
&& class_object.borrow_checker().try_borrow().is_ok() {
304+
struct TraverseGuard<'a, T: PyClass>(&'a PyClassObject<T>);
305+
impl<'a, T: PyClass> Drop for TraverseGuard<'a, T> {
306+
fn drop(&mut self) {
307+
self.0.borrow_checker().release_borrow()
308+
}
309+
}
255310

256-
let py = Python::assume_gil_acquired();
257-
let slf = Borrowed::from_ptr_unchecked(py, slf).downcast_unchecked::<T>();
258-
let borrow = PyRef::try_borrow_threadsafe(&slf);
259-
let visit = PyVisit::from_raw(visit, arg, py);
311+
// `.try_borrow()` above created a borrow, we need to release it when we're done
312+
// traversing the object. This allows us to read `instance` safely.
313+
let _guard = TraverseGuard(class_object);
314+
let instance = &*class_object.contents.value.get();
260315

261-
let retval = if let Ok(borrow) = borrow {
262-
let _lock = LockGIL::during_traverse();
316+
let visit = PyVisit { visit, arg, _guard: PhantomData };
263317

264-
match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) {
265-
Ok(res) => match res {
266-
Ok(()) => 0,
267-
Err(PyTraverseError(value)) => value,
268-
},
318+
match catch_unwind(AssertUnwindSafe(move || impl_(instance, visit))) {
319+
Ok(Ok(())) => 0,
320+
Ok(Err(traverse_error)) => traverse_error.into_inner(),
269321
Err(_err) => -1,
270322
}
271323
} else {
272324
0
273325
};
326+
327+
// Drop lock before trap just in case dropping lock panics
328+
drop(lock);
274329
trap.disarm();
275330
retval
276331
}

src/pycell.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -673,14 +673,6 @@ impl<'py, T: PyClass> PyRef<'py, T> {
673673
.try_borrow()
674674
.map(|_| Self { inner: obj.clone() })
675675
}
676-
677-
pub(crate) fn try_borrow_threadsafe(obj: &Bound<'py, T>) -> Result<Self, PyBorrowError> {
678-
let cell = obj.get_class_object();
679-
cell.check_threadsafe()?;
680-
cell.borrow_checker()
681-
.try_borrow()
682-
.map(|_| Self { inner: obj.clone() })
683-
}
684676
}
685677

686678
impl<'p, T, U> PyRef<'p, T>

src/pyclass.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ mod create_type_object;
99
mod gc;
1010

1111
pub(crate) use self::create_type_object::{create_type_object, PyClassTypeObject};
12+
1213
pub use self::gc::{PyTraverseError, PyVisit};
1314

1415
/// Types that can be used as Python classes.

src/pyclass/gc.rs

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,75 @@
1-
use std::os::raw::{c_int, c_void};
1+
use std::{
2+
marker::PhantomData,
3+
os::raw::{c_int, c_void},
4+
};
25

3-
use crate::{ffi, AsPyPointer, Python};
6+
use crate::{ffi, AsPyPointer};
47

58
/// Error returned by a `__traverse__` visitor implementation.
69
#[repr(transparent)]
7-
pub struct PyTraverseError(pub(crate) c_int);
10+
pub struct PyTraverseError(NonZeroCInt);
11+
12+
impl PyTraverseError {
13+
/// Returns the error code.
14+
pub(crate) fn into_inner(self) -> c_int {
15+
self.0.into()
16+
}
17+
}
818

919
/// Object visitor for GC.
1020
#[derive(Clone)]
11-
pub struct PyVisit<'p> {
21+
pub struct PyVisit<'a> {
1222
pub(crate) visit: ffi::visitproc,
1323
pub(crate) arg: *mut c_void,
14-
/// VisitProc contains a Python instance to ensure that
15-
/// 1) it is cannot be moved out of the traverse() call
16-
/// 2) it cannot be sent to other threads
17-
pub(crate) _py: Python<'p>,
24+
/// Prevents the `PyVisit` from outliving the `__traverse__` call.
25+
pub(crate) _guard: PhantomData<&'a ()>,
1826
}
1927

20-
impl<'p> PyVisit<'p> {
28+
impl<'a> PyVisit<'a> {
2129
/// Visit `obj`.
2230
pub fn call<T>(&self, obj: &T) -> Result<(), PyTraverseError>
2331
where
2432
T: AsPyPointer,
2533
{
2634
let ptr = obj.as_ptr();
2735
if !ptr.is_null() {
28-
let r = unsafe { (self.visit)(ptr, self.arg) };
29-
if r == 0 {
30-
Ok(())
31-
} else {
32-
Err(PyTraverseError(r))
36+
match NonZeroCInt::new(unsafe { (self.visit)(ptr, self.arg) }) {
37+
None => Ok(()),
38+
Some(r) => Err(PyTraverseError(r)),
3339
}
3440
} else {
3541
Ok(())
3642
}
3743
}
44+
}
3845

39-
/// Creates the PyVisit from the arguments to tp_traverse
40-
#[doc(hidden)]
41-
pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, py: Python<'p>) -> Self {
42-
Self {
43-
visit,
44-
arg,
45-
_py: py,
46-
}
46+
/// Workaround for `NonZero<c_int>` not being available until MSRV 1.79
47+
mod get_nonzero_c_int {
48+
pub struct GetNonZeroCInt<const WIDTH: usize>();
49+
50+
pub trait NonZeroCIntType {
51+
type Type;
52+
}
53+
impl NonZeroCIntType for GetNonZeroCInt<16> {
54+
type Type = std::num::NonZeroI16;
55+
}
56+
impl NonZeroCIntType for GetNonZeroCInt<32> {
57+
type Type = std::num::NonZeroI32;
58+
}
59+
60+
pub type Type =
61+
<GetNonZeroCInt<{ std::mem::size_of::<std::os::raw::c_int>() * 8 }> as NonZeroCIntType>::Type;
62+
}
63+
64+
use get_nonzero_c_int::Type as NonZeroCInt;
65+
66+
#[cfg(test)]
67+
mod tests {
68+
use super::PyVisit;
69+
use static_assertions::assert_not_impl_any;
70+
71+
#[test]
72+
fn py_visit_not_send_sync() {
73+
assert_not_impl_any!(PyVisit<'_>: Send, Sync);
4774
}
4875
}

0 commit comments

Comments
 (0)