Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid creating PyRef inside __traverse__ handler #4479

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ serde_json = "1.0.61"
rayon = "1.6.1"
futures = "0.3.28"
tempfile = "3.12.0"
static_assertions = "1.1.0"

[build-dependencies]
pyo3-build-config = { path = "pyo3-build-config", version = "=0.23.0-dev", features = ["resolve-config"] }
Expand Down
1 change: 1 addition & 0 deletions newsfragments/4479.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Remove illegal reference counting op inside implementation of `__traverse__` handlers.
81 changes: 68 additions & 13 deletions src/impl_/pymethods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ use crate::callback::IntoPyCallbackOutput;
use crate::exceptions::PyStopAsyncIteration;
use crate::gil::LockGIL;
use crate::impl_::panic::PanicTrap;
use crate::impl_::pycell::{PyClassObject, PyClassObjectLayout};
use crate::pycell::impl_::PyClassBorrowChecker as _;
use crate::pycell::{PyBorrowError, PyBorrowMutError};
use crate::pyclass::boolean_struct::False;
use crate::types::any::PyAnyMethods;
use crate::{
ffi, Borrowed, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject,
PyRef, PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
ffi, Bound, DowncastError, Py, PyAny, PyClass, PyClassInitializer, PyErr, PyObject, PyRef,
PyRefMut, PyResult, PyTraverseError, PyTypeCheck, PyVisit, Python,
};
use std::ffi::CStr;
use std::fmt;
use std::marker::PhantomData;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::ptr::null_mut;
Expand Down Expand Up @@ -232,6 +235,40 @@ impl PySetterDef {
}

/// Calls an implementation of __traverse__ for tp_traverse
///
/// NB cannot accept `'static` visitor, this is a sanity check below:
///
/// ```rust,compile_fail
/// use pyo3::prelude::*;
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
///
/// #[pyclass]
/// struct Foo;
///
/// #[pymethods]
/// impl Foo {
/// fn __traverse__(&self, _visit: PyVisit<'static>) -> Result<(), PyTraverseError> {
/// Ok(())
/// }
/// }
/// ```
///
/// Elided lifetime should compile ok:
///
/// ```rust
/// use pyo3::prelude::*;
/// use pyo3::pyclass::{PyTraverseError, PyVisit};
///
/// #[pyclass]
/// struct Foo;
///
/// #[pymethods]
/// impl Foo {
/// fn __traverse__(&self, _visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
/// Ok(())
/// }
/// }
/// ```
#[doc(hidden)]
pub unsafe fn _call_traverse<T>(
slf: *mut ffi::PyObject,
Expand All @@ -250,25 +287,43 @@ where
// Since we do not create a `GILPool` at all, it is important that our usage of the GIL
// token does not produce any owned objects thereby calling into `register_owned`.
let trap = PanicTrap::new("uncaught panic inside __traverse__ handler");
let lock = LockGIL::during_traverse();

// SAFETY: `slf` is a valid Python object pointer to a class object of type T, and
// traversal is running so no mutations can occur.
let class_object: &PyClassObject<T> = &*slf.cast();

let retval =
// `#[pyclass(unsendable)]` types can only be deallocated by their own thread, so
// do not traverse them if not on their owning thread :(
if class_object.check_threadsafe().is_ok()
// ... and we cannot traverse a type which might be being mutated by a Rust thread
&& class_object.borrow_checker().try_borrow().is_ok() {
struct TraverseGuard<'a, T: PyClass>(&'a PyClassObject<T>);
impl<'a, T: PyClass> Drop for TraverseGuard<'a, T> {
fn drop(&mut self) {
self.0.borrow_checker().release_borrow()
}
}

let py = Python::assume_gil_acquired();
let slf = Borrowed::from_ptr_unchecked(py, slf).downcast_unchecked::<T>();
let borrow = PyRef::try_borrow_threadsafe(&slf);
let visit = PyVisit::from_raw(visit, arg, py);
// `.try_borrow()` above created a borrow, we need to release it when we're done
// traversing the object. This allows us to read `instance` safely.
let _guard = TraverseGuard(class_object);
let instance = &*class_object.contents.value.get();

let retval = if let Ok(borrow) = borrow {
let _lock = LockGIL::during_traverse();
let visit = PyVisit { visit, arg, _guard: PhantomData };

match catch_unwind(AssertUnwindSafe(move || impl_(&*borrow, visit))) {
Ok(res) => match res {
Ok(()) => 0,
Err(PyTraverseError(value)) => value,
},
match catch_unwind(AssertUnwindSafe(move || impl_(instance, visit))) {
Ok(Ok(())) => 0,
Ok(Err(traverse_error)) => traverse_error.into_inner(),
Err(_err) => -1,
}
} else {
0
};

// Drop lock before trap just in case dropping lock panics
drop(lock);
trap.disarm();
retval
}
Expand Down
8 changes: 0 additions & 8 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,14 +312,6 @@ impl<'py, T: PyClass> PyRef<'py, T> {
.try_borrow()
.map(|_| Self { inner: obj.clone() })
}

pub(crate) fn try_borrow_threadsafe(obj: &Bound<'py, T>) -> Result<Self, PyBorrowError> {
let cell = obj.get_class_object();
cell.check_threadsafe()?;
cell.borrow_checker()
.try_borrow()
.map(|_| Self { inner: obj.clone() })
}
}

impl<'p, T, U> PyRef<'p, T>
Expand Down
1 change: 1 addition & 0 deletions src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ mod create_type_object;
mod gc;

pub(crate) use self::create_type_object::{create_type_object, PyClassTypeObject};

pub use self::gc::{PyTraverseError, PyVisit};

/// Types that can be used as Python classes.
Expand Down
71 changes: 49 additions & 22 deletions src/pyclass/gc.rs
Original file line number Diff line number Diff line change
@@ -1,48 +1,75 @@
use std::os::raw::{c_int, c_void};
use std::{
marker::PhantomData,
os::raw::{c_int, c_void},
};

use crate::{ffi, AsPyPointer, Python};
use crate::{ffi, AsPyPointer};

/// Error returned by a `__traverse__` visitor implementation.
#[repr(transparent)]
pub struct PyTraverseError(pub(crate) c_int);
pub struct PyTraverseError(NonZeroCInt);

impl PyTraverseError {
/// Returns the error code.
pub(crate) fn into_inner(self) -> c_int {
self.0.into()
}
}

/// Object visitor for GC.
#[derive(Clone)]
pub struct PyVisit<'p> {
pub struct PyVisit<'a> {
pub(crate) visit: ffi::visitproc,
pub(crate) arg: *mut c_void,
/// VisitProc contains a Python instance to ensure that
/// 1) it is cannot be moved out of the traverse() call
/// 2) it cannot be sent to other threads
pub(crate) _py: Python<'p>,
/// Prevents the `PyVisit` from outliving the `__traverse__` call.
pub(crate) _guard: PhantomData<&'a ()>,
}
Comment on lines +24 to +25
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I completely removed the Python::assume_gil_acquired call from _call_traverse, so I needed to place a different reference back in for the lifetime. This is sufficient because the *mut c_void already prevents Send/Sync.


impl<'p> PyVisit<'p> {
impl<'a> PyVisit<'a> {
/// Visit `obj`.
pub fn call<T>(&self, obj: &T) -> Result<(), PyTraverseError>
where
T: AsPyPointer,
{
let ptr = obj.as_ptr();
if !ptr.is_null() {
let r = unsafe { (self.visit)(ptr, self.arg) };
if r == 0 {
Ok(())
} else {
Err(PyTraverseError(r))
match NonZeroCInt::new(unsafe { (self.visit)(ptr, self.arg) }) {
None => Ok(()),
Some(r) => Err(PyTraverseError(r)),
}
} else {
Ok(())
}
}
}

/// Creates the PyVisit from the arguments to tp_traverse
#[doc(hidden)]
pub unsafe fn from_raw(visit: ffi::visitproc, arg: *mut c_void, py: Python<'p>) -> Self {
Self {
visit,
arg,
_py: py,
}
/// Workaround for `NonZero<c_int>` not being available until MSRV 1.79
mod get_nonzero_c_int {
pub struct GetNonZeroCInt<const WIDTH: usize>();

pub trait NonZeroCIntType {
type Type;
}
impl NonZeroCIntType for GetNonZeroCInt<16> {
type Type = std::num::NonZeroI16;
}
impl NonZeroCIntType for GetNonZeroCInt<32> {
type Type = std::num::NonZeroI32;
}

pub type Type =
<GetNonZeroCInt<{ std::mem::size_of::<std::os::raw::c_int>() * 8 }> as NonZeroCIntType>::Type;
}

use get_nonzero_c_int::Type as NonZeroCInt;

#[cfg(test)]
mod tests {
use super::PyVisit;
use static_assertions::assert_not_impl_any;

#[test]
fn py_visit_not_send_sync() {
assert_not_impl_any!(PyVisit<'_>: Send, Sync);
}
}
Loading