@@ -2,15 +2,18 @@ use crate::callback::IntoPyCallbackOutput;
2
2
use crate :: exceptions:: PyStopAsyncIteration ;
3
3
use crate :: gil:: LockGIL ;
4
4
use crate :: impl_:: panic:: PanicTrap ;
5
+ use crate :: impl_:: pycell:: { PyClassObject , PyClassObjectLayout } ;
6
+ use crate :: pycell:: impl_:: PyClassBorrowChecker as _;
5
7
use crate :: pycell:: { PyBorrowError , PyBorrowMutError } ;
6
8
use crate :: pyclass:: boolean_struct:: False ;
7
9
use crate :: types:: any:: PyAnyMethods ;
8
10
use crate :: {
9
- ffi, Borrowed , Bound , DowncastError , Py , PyAny , PyClass , PyClassInitializer , PyErr , PyObject ,
10
- PyRef , PyRefMut , PyResult , PyTraverseError , PyTypeCheck , PyVisit , Python ,
11
+ ffi, Bound , DowncastError , Py , PyAny , PyClass , PyClassInitializer , PyErr , PyObject , PyRef ,
12
+ PyRefMut , PyResult , PyTraverseError , PyTypeCheck , PyVisit , Python ,
11
13
} ;
12
14
use std:: ffi:: CStr ;
13
15
use std:: fmt;
16
+ use std:: marker:: PhantomData ;
14
17
use std:: os:: raw:: { c_int, c_void} ;
15
18
use std:: panic:: { catch_unwind, AssertUnwindSafe } ;
16
19
use std:: ptr:: null_mut;
@@ -232,6 +235,40 @@ impl PySetterDef {
232
235
}
233
236
234
237
/// Calls an implementation of __traverse__ for tp_traverse
238
+ ///
239
+ /// NB cannot accept `'static` visitor, this is a sanity check below:
240
+ ///
241
+ /// ```rust,compile_fail
242
+ /// use pyo3::prelude::*;
243
+ /// use pyo3::pyclass::{PyTraverseError, PyVisit};
244
+ ///
245
+ /// #[pyclass]
246
+ /// struct Foo;
247
+ ///
248
+ /// #[pymethods]
249
+ /// impl Foo {
250
+ /// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> {
251
+ /// Ok(())
252
+ /// }
253
+ /// }
254
+ /// ```
255
+ ///
256
+ /// Elided lifetime should compile ok:
257
+ ///
258
+ /// ```rust
259
+ /// use pyo3::prelude::*;
260
+ /// use pyo3::pyclass::{PyTraverseError, PyVisit};
261
+ ///
262
+ /// #[pyclass]
263
+ /// struct Foo;
264
+ ///
265
+ /// #[pymethods]
266
+ /// impl Foo {
267
+ /// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
268
+ /// Ok(())
269
+ /// }
270
+ /// }
271
+ /// ```
235
272
#[ doc( hidden) ]
236
273
pub unsafe fn _call_traverse < T > (
237
274
slf : * mut ffi:: PyObject ,
@@ -250,25 +287,43 @@ where
250
287
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
251
288
// token does not produce any owned objects thereby calling into `register_owned`.
252
289
let trap = PanicTrap :: new ( "uncaught panic inside __traverse__ handler" ) ;
290
+ let lock = LockGIL :: during_traverse ( ) ;
291
+
292
+ // SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
293
+ // traversal is running so no mutations can occur.
294
+ let class_object: & PyClassObject < T > = & * slf. cast ( ) ;
295
+
296
+ let retval =
297
+ // `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so
298
+ // do not traverse them if not on their owning thread :(
299
+ if class_object. check_threadsafe ( ) . is_ok ( )
300
+ // ... and we cannot traverse a type which might be being mutated by a Rust thread
301
+ && class_object. borrow_checker ( ) . try_borrow ( ) . is_ok ( ) {
302
+ struct TraverseGuard < ' a , T : PyClass > ( & ' a PyClassObject < T > ) ;
303
+ impl < ' a , T : PyClass > Drop for TraverseGuard < ' a , T > {
304
+ fn drop ( & mut self ) {
305
+ self . 0 . borrow_checker ( ) . release_borrow ( )
306
+ }
307
+ }
253
308
254
- let py = Python :: assume_gil_acquired ( ) ;
255
- let slf = Borrowed :: from_ptr_unchecked ( py , slf ) . downcast_unchecked :: < T > ( ) ;
256
- let borrow = PyRef :: try_borrow_threadsafe ( & slf ) ;
257
- let visit = PyVisit :: from_raw ( visit , arg , py ) ;
309
+ // `.try_borrow()` above created a borrow, we need to release it when we're done
310
+ // traversing the object. This allows us to read `instance` safely.
311
+ let _guard = TraverseGuard ( class_object ) ;
312
+ let instance = & * class_object . contents . value . get ( ) ;
258
313
259
- let retval = if let Ok ( borrow) = borrow {
260
- let _lock = LockGIL :: during_traverse ( ) ;
314
+ let visit = PyVisit { visit, arg, _guard : PhantomData } ;
261
315
262
- match catch_unwind ( AssertUnwindSafe ( move || impl_ ( & * borrow, visit) ) ) {
263
- Ok ( res) => match res {
264
- Ok ( ( ) ) => 0 ,
265
- Err ( PyTraverseError ( value) ) => value,
266
- } ,
316
+ match catch_unwind ( AssertUnwindSafe ( move || impl_ ( instance, visit) ) ) {
317
+ Ok ( Ok ( ( ) ) ) => 0 ,
318
+ Ok ( Err ( traverse_error) ) => traverse_error. into_inner ( ) ,
267
319
Err ( _err) => -1 ,
268
320
}
269
321
} else {
270
322
0
271
323
} ;
324
+
325
+ // Drop lock before trap just in case dropping lock panics
326
+ drop ( lock) ;
272
327
trap. disarm ( ) ;
273
328
retval
274
329
}
0 commit comments