@@ -2,17 +2,20 @@ use crate::callback::IntoPyCallbackOutput;
2
2
use crate :: exceptions:: PyStopAsyncIteration ;
3
3
use crate :: gil:: LockGIL ;
4
4
use crate :: impl_:: panic:: PanicTrap ;
5
+ use crate :: impl_:: pycell:: { PyClassObject , PyClassObjectLayout } ;
6
+ use crate :: pycell:: impl_:: PyClassBorrowChecker as _;
5
7
use crate :: pycell:: { PyBorrowError , PyBorrowMutError } ;
6
8
use crate :: pyclass:: boolean_struct:: False ;
7
9
use crate :: types:: any:: PyAnyMethods ;
8
10
#[ cfg( feature = "gil-refs" ) ]
9
11
use crate :: types:: { PyModule , PyType } ;
10
12
use crate :: {
11
- ffi, Borrowed , Bound , DowncastError , Py , PyAny , PyClass , PyClassInitializer , PyErr , PyObject ,
12
- PyRef , PyRefMut , PyResult , PyTraverseError , PyTypeCheck , PyVisit , Python ,
13
+ ffi, Bound , DowncastError , Py , PyAny , PyClass , PyClassInitializer , PyErr , PyObject , PyRef ,
14
+ PyRefMut , PyResult , PyTraverseError , PyTypeCheck , PyVisit , Python ,
13
15
} ;
14
16
use std:: ffi:: CStr ;
15
17
use std:: fmt;
18
+ use std:: marker:: PhantomData ;
16
19
use std:: os:: raw:: { c_int, c_void} ;
17
20
use std:: panic:: { catch_unwind, AssertUnwindSafe } ;
18
21
use std:: ptr:: null_mut;
@@ -234,6 +237,40 @@ impl PySetterDef {
234
237
}
235
238
236
239
/// Calls an implementation of __traverse__ for tp_traverse
240
+ ///
241
+ /// NB cannot accept `'static` visitor, this is a sanity check below:
242
+ ///
243
+ /// ```rust,compile_fail
244
+ /// use pyo3::prelude::*;
245
+ /// use pyo3::pyclass::{PyTraverseError, PyVisit};
246
+ ///
247
+ /// #[pyclass]
248
+ /// struct Foo;
249
+ ///
250
+ /// #[pymethods]
251
+ /// impl Foo {
252
+ /// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> {
253
+ /// Ok(())
254
+ /// }
255
+ /// }
256
+ /// ```
257
+ ///
258
+ /// Elided lifetime should compile ok:
259
+ ///
260
+ /// ```rust
261
+ /// use pyo3::prelude::*;
262
+ /// use pyo3::pyclass::{PyTraverseError, PyVisit};
263
+ ///
264
+ /// #[pyclass]
265
+ /// struct Foo;
266
+ ///
267
+ /// #[pymethods]
268
+ /// impl Foo {
269
+ /// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
270
+ /// Ok(())
271
+ /// }
272
+ /// }
273
+ /// ```
237
274
#[ doc( hidden) ]
238
275
pub unsafe fn _call_traverse < T > (
239
276
slf : * mut ffi:: PyObject ,
@@ -252,25 +289,43 @@ where
252
289
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
253
290
// token does not produce any owned objects thereby calling into `register_owned`.
254
291
let trap = PanicTrap :: new ( "uncaught panic inside __traverse__ handler" ) ;
292
+ let lock = LockGIL :: during_traverse ( ) ;
293
+
294
+ // SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
295
+ // traversal is running so no mutations can occur.
296
+ let class_object: & PyClassObject < T > = & * slf. cast ( ) ;
297
+
298
+ let retval =
299
+ // `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so
300
+ // do not traverse them if not on their owning thread :(
301
+ if class_object. check_threadsafe ( ) . is_ok ( )
302
+ // ... and we cannot traverse a type which might be being mutated by a Rust thread
303
+ && class_object. borrow_checker ( ) . try_borrow ( ) . is_ok ( ) {
304
+ struct TraverseGuard < ' a , T : PyClass > ( & ' a PyClassObject < T > ) ;
305
+ impl < ' a , T : PyClass > Drop for TraverseGuard < ' a , T > {
306
+ fn drop ( & mut self ) {
307
+ self . 0 . borrow_checker ( ) . release_borrow ( )
308
+ }
309
+ }
255
310
256
- let py = Python :: assume_gil_acquired ( ) ;
257
- let slf = Borrowed :: from_ptr_unchecked ( py , slf ) . downcast_unchecked :: < T > ( ) ;
258
- let borrow = PyRef :: try_borrow_threadsafe ( & slf ) ;
259
- let visit = PyVisit :: from_raw ( visit , arg , py ) ;
311
+ // `.try_borrow()` above created a borrow, we need to release it when we're done
312
+ // traversing the object. This allows us to read `instance` safely.
313
+ let _guard = TraverseGuard ( class_object ) ;
314
+ let instance = & * class_object . contents . value . get ( ) ;
260
315
261
- let retval = if let Ok ( borrow) = borrow {
262
- let _lock = LockGIL :: during_traverse ( ) ;
316
+ let visit = PyVisit { visit, arg, _guard : PhantomData } ;
263
317
264
- match catch_unwind ( AssertUnwindSafe ( move || impl_ ( & * borrow, visit) ) ) {
265
- Ok ( res) => match res {
266
- Ok ( ( ) ) => 0 ,
267
- Err ( PyTraverseError ( value) ) => value,
268
- } ,
318
+ match catch_unwind ( AssertUnwindSafe ( move || impl_ ( instance, visit) ) ) {
319
+ Ok ( Ok ( ( ) ) ) => 0 ,
320
+ Ok ( Err ( traverse_error) ) => traverse_error. into_inner ( ) ,
269
321
Err ( _err) => -1 ,
270
322
}
271
323
} else {
272
324
0
273
325
} ;
326
+
327
+ // Drop lock before trap just in case dropping lock panics
328
+ drop ( lock) ;
274
329
trap. disarm ( ) ;
275
330
retval
276
331
}
0 commit comments