Skip to content

Commit 1da0de8

Browse files
committed
avoid creating PyRef inside __traverse__ handler
1 parent 15e00ba commit 1da0de8

File tree

6 files changed

+120
-43
lines changed

6 files changed

+120
-43
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ serde_json = "1.0.61"
6262
rayon = "1.6.1"
6363
futures = "0.3.28"
6464
tempfile = "3.12.0"
65+
static_assertions = "1.1.0"
6566

6667
[build-dependencies]
6768
pyo3-build-config = { path = "pyo3-build-config", version = "=0.23.0-dev", features = ["resolve-config"] }

newsfragments/4479.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Remove illegal reference counting op inside implementation of `__traverse__` handlers.

src/impl_/pymethods.rs

Lines changed: 68 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@ use crate::callback::IntoPyCallbackOutput;
22
use crate::exceptions::PyStopAsyncIteration;
33
use crate::gil::LockGIL;
44
use crate::impl_::panic::PanicTrap;
5+
use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout};
6+
use crate::pycell::impl_::PyClassBorrowChecker as _;
57
use crate::pycell::{PyBorrowError, PyBorrowMutError};
68
use crate::pyclass::boolean_struct::False;
79
use crate::types::any::PyAnyMethods;
810
use crate::{
9-
ffi, Borrowed, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject,
10-
PyRef, PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
11+
ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef,
12+
PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
1113
};
1214
use std::ffi::CStr;
1315
use std::fmt;
16+
use std::marker::PhantomData;
1417
use std::os::raw::{c_int, c_void};
1518
use std::panic::{catch_unwind, AssertUnwindSafe};
1619
use std::ptr::null_mut;
@@ -232,6 +235,40 @@ impl PySetterDef {
232235
}
233236

234237
/// Calls an implementation of __traverse__ for tp_traverse
238+
///
239+
/// NB cannot accept `'static` visitor, this is a sanity check below:
240+
///
241+
/// ```rust,compile_fail
242+
/// use pyo3::prelude::*;
243+
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
244+
///
245+
/// #[pyclass]
246+
/// struct Foo;
247+
///
248+
/// #[pymethods]
249+
/// impl Foo {
250+
/// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> {
251+
/// Ok(())
252+
/// }
253+
/// }
254+
/// ```
255+
///
256+
/// Elided lifetime should compile ok:
257+
///
258+
/// ```rust
259+
/// use pyo3::prelude::*;
260+
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
261+
///
262+
/// #[pyclass]
263+
/// struct Foo;
264+
///
265+
/// #[pymethods]
266+
/// impl Foo {
267+
/// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
268+
/// Ok(())
269+
/// }
270+
/// }
271+
/// ```
235272
#[doc(hidden)]
236273
pub unsafe fn _call_traverse<T>(
237274
slf: *mut ffi::PyObject,
@@ -250,25 +287,43 @@ where
250287
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
251288
// token does not produce any owned objects thereby calling into `register_owned`.
252289
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
290+
let lock = LockGIL::during_traverse();
291+
292+
// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
293+
// traversal is running so no mutations can occur.
294+
let class_object: &PyClassObject<T> = &*slf.cast();
295+
296+
let retval =
297+
// `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so
298+
// do not traverse them if not on their owning thread :(
299+
if class_object.check_threadsafe().is_ok()
300+
// ... and we cannot traverse a type which might be being mutated by a Rust thread
301+
&& class_object.borrow_checker().try_borrow().is_ok() {
302+
struct TraverseGuard<'a, T: PyClass>(&'a PyClassObject<T>);
303+
impl<'a, T: PyClass> Drop for TraverseGuard<'a, T> {
304+
fn drop(&mut self) {
305+
self.0.borrow_checker().release_borrow()
306+
}
307+
}
253308

254-
let py = Python::assume_gil_acquired();
255-
let slf = Borrowed::from_ptr_unchecked(py, slf).downcast_unchecked::<T>();
256-
let borrow = PyRef::try_borrow_threadsafe(&slf);
257-
let visit = PyVisit::from_raw(visit, arg, py);
309+
// `.try_borrow()` above created a borrow, we need to release it when we're done
310+
// traversing the object. This allows us to read `instance` safely.
311+
let _guard = TraverseGuard(class_object);
312+
let instance = &*class_object.contents.value.get();
258313

259-
let retval = if let Ok(borrow) = borrow {
260-
let _lock = LockGIL::during_traverse();
314+
let visit = PyVisit { visit, arg, _guard: PhantomData };
261315

262-
match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) {
263-
Ok(res) => match res {
264-
Ok(()) => 0,
265-
Err(PyTraverseError(value)) => value,
266-
},
316+
match catch_unwind(AssertUnwindSafe(move || impl_(instance, visit))) {
317+
Ok(Ok(())) => 0,
318+
Ok(Err(traverse_error)) => traverse_error.into_inner(),
267319
Err(_err) => -1,
268320
}
269321
} else {
270322
0
271323
};
324+
325+
// Drop lock before trap just in case dropping lock panics
326+
drop(lock);
272327
trap.disarm();
273328
retval
274329
}

src/pycell.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -312,14 +312,6 @@ impl<'py, T: PyClass> PyRef<'py, T> {
312312
.try_borrow()
313313
.map(|_| Self { inner: obj.clone() })
314314
}
315-
316-
pub(crate) fn try_borrow_threadsafe(obj: &Bound<'py, T>) -> Result<Self, PyBorrowError> {
317-
let cell = obj.get_class_object();
318-
cell.check_threadsafe()?;
319-
cell.borrow_checker()
320-
.try_borrow()
321-
.map(|_| Self { inner: obj.clone() })
322-
}
323315
}
324316

325317
impl<'p, T, U> PyRef<'p, T>

src/pyclass.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ mod create_type_object;
66
mod gc;
77

88
pub(crate) use self::create_type_object::{create_type_object, PyClassTypeObject};
9+
910
pub use self::gc::{PyTraverseError, PyVisit};
1011

1112
/// Types that can be used as Python classes.

src/pyclass/gc.rs

Lines changed: 49 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,75 @@
1-
use std::os::raw::{c_int, c_void};
1+
use std::{
2+
marker::PhantomData,
3+
os::raw::{c_int, c_void},
4+
};
25

3-
use crate::{ffi, AsPyPointer, Python};
6+
use crate::{ffi, AsPyPointer};
47

58
/// Error returned by a `__traverse__` visitor implementation.
69
#[repr(transparent)]
7-
pub struct PyTraverseError(pub(crate) c_int);
10+
pub struct PyTraverseError(NonZeroCInt);
11+
12+
impl PyTraverseError {
13+
/// Returns the error code.
14+
pub(crate) fn into_inner(self) -> c_int {
15+
self.0.into()
16+
}
17+
}
818

919
/// Object visitor for GC.
1020
#[derive(Clone)]
11-
pub struct PyVisit<'p> {
21+
pub struct PyVisit<'a> {
1222
pub(crate) visit: ffi::visitproc,
1323
pub(crate) arg: *mut c_void,
14-
/// VisitProc contains a Python instance to ensure that
15-
/// 1) it is cannot be moved out of the traverse() call
16-
/// 2) it cannot be sent to other threads
17-
pub(crate) _py: Python<'p>,
24+
/// Prevents the `PyVisit` from outliving the `__traverse__` call.
25+
pub(crate) _guard: PhantomData<&'a ()>,
1826
}
1927

20-
impl<'p> PyVisit<'p> {
28+
impl<'a> PyVisit<'a> {
2129
/// Visit `obj`.
2230
pub fn call<T>(&self, obj: &T) -> Result<(), PyTraverseError>
2331
where
2432
T: AsPyPointer,
2533
{
2634
let ptr = obj.as_ptr();
2735
if !ptr.is_null() {
28-
let r = unsafe { (self.visit)(ptr, self.arg) };
29-
if r == 0 {
30-
Ok(())
31-
} else {
32-
Err(PyTraverseError(r))
36+
match NonZeroCInt::new(unsafe { (self.visit)(ptr, self.arg) }) {
37+
None => Ok(()),
38+
Some(r) => Err(PyTraverseError(r)),
3339
}
3440
} else {
3541
Ok(())
3642
}
3743
}
44+
}
3845

39-
/// Creates the PyVisit from the arguments to tp_traverse
40-
#[doc(hidden)]
41-
pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, py: Python<'p>) -> Self {
42-
Self {
43-
visit,
44-
arg,
45-
_py: py,
46-
}
46+
/// Workaround for `NonZero<c_int>` not being available until MSRV 1.79
47+
mod get_nonzero_c_int {
48+
pub struct GetNonZeroCInt<const WIDTH: usize>();
49+
50+
pub trait NonZeroCIntType {
51+
type Type;
52+
}
53+
impl NonZeroCIntType for GetNonZeroCInt<16> {
54+
type Type = std::num::NonZeroI16;
55+
}
56+
impl NonZeroCIntType for GetNonZeroCInt<32> {
57+
type Type = std::num::NonZeroI32;
58+
}
59+
60+
pub type Type =
61+
<GetNonZeroCInt<{ std::mem::size_of::<std::os::raw::c_int>() * 8 }> as NonZeroCIntType>::Type;
62+
}
63+
64+
use get_nonzero_c_int::Type as NonZeroCInt;
65+
66+
#[cfg(test)]
67+
mod tests {
68+
use super::PyVisit;
69+
use static_assertions::assert_not_impl_any;
70+
71+
#[test]
72+
fn py_visit_not_send_sync() {
73+
assert_not_impl_any!(PyVisit<'_>: Send, Sync);
4774
}
4875
}

0 commit comments

Comments
 (0)