Skip to content

Commit

Permalink
fix: fix free-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 24, 2024
1 parent 804a784 commit b0791d1
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 168 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ readability-*,
-readability-identifier-length,
'
CheckOptions:
misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*'
misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*;include/.*'
HeaderFilterRegex: '^include/.*$'
...
7 changes: 5 additions & 2 deletions include/mutex.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,19 @@ using recursive_mutex = std::recursive_mutex;
using scoped_lock_guard = std::lock_guard<mutex>;
using scoped_recursive_lock_guard = std::lock_guard<recursive_mutex>;

#if defined(Py_GIL_DISABLED) /* use mutex implementation from Python rather than STL */ || \
(defined(__APPLE__) /* header <shared_mutex> is not available on macOS build target */ && \
#if (defined(__APPLE__) /* header <shared_mutex> is not available on macOS build target */ && \
PY_VERSION_HEX < /* Python 3.12.0 */ 0x030C00F0)

#undef HAVE_READ_WRITE_LOCK

using read_write_mutex = mutex;
using scoped_read_lock_guard = scoped_lock_guard;
using scoped_write_lock_guard = scoped_lock_guard;

#else

#define HAVE_READ_WRITE_LOCK

#include <shared_mutex> // std::shared_mutex, std::shared_lock

using read_write_mutex = std::shared_mutex;
Expand Down
17 changes: 12 additions & 5 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ inline std::vector<T> reserved_vector(const py::size_t& size) {
return v;
}

// Casts a Python handle to `T` with `py::cast<T>`, evaluating the cast through the
// `EVALUATE_WITH_LOCK_HELD` macro with `handle` passed as the macro's second argument.
// NOTE(review): presumably the macro holds a lock / critical section on `handle` while the
// expression runs (for free-threaded builds) — confirm against the macro definition.
template <typename T>
inline T thread_safe_cast(const py::handle& handle) {
return EVALUATE_WITH_LOCK_HELD(py::cast<T>(handle), handle);
}

template <typename Sized = py::object>
inline py::ssize_t GetSize(const py::handle& sized) {
return py::ssize_t_cast(py::len(sized));
Expand Down Expand Up @@ -264,8 +269,12 @@ inline void SET_ITEM<py::list>(const py::handle& container,
PyList_SET_ITEM(container.ptr(), item, value.inc_ref().ptr());
}

// Returns the `py::str(object)` conversion of a Python object as a `std::string`.
// The conversion expression is evaluated through `EVALUATE_WITH_LOCK_HELD`, with `object`
// passed as the macro's second argument.
// NOTE(review): presumably the macro keeps `object` locked while `py::str` runs — confirm
// against the macro definition.
inline std::string PyStr(const py::handle& object) {
return EVALUATE_WITH_LOCK_HELD(static_cast<std::string>(py::str(object)), object);
}
// Overload for values that are already C++ strings: returns a copy of the argument unchanged.
inline std::string PyStr(const std::string& string) {
    return std::string{string};
}
inline std::string PyRepr(const py::handle& object) {
return static_cast<std::string>(py::repr(object));
return EVALUATE_WITH_LOCK_HELD(static_cast<std::string>(py::repr(object)), object);
}
inline std::string PyRepr(const std::string& string) {
return static_cast<std::string>(py::repr(py::str(string)));
Expand Down Expand Up @@ -601,10 +610,8 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references]
// Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)`
const auto sort_key_fn = py::cpp_function([](const py::object& o) -> py::tuple {
const py::handle t = py::type::handle_of(o);
const py::str qualname{
static_cast<std::string>(py::str(py::getattr(t, Py_Get_ID(__module__)))) +
"." +
static_cast<std::string>(py::str(py::getattr(t, Py_Get_ID(__qualname__))))};
const py::str qualname{PyStr(py::getattr(t, Py_Get_ID(__module__))) + "." +
PyStr(py::getattr(t, Py_Get_ID(__qualname__)))};
return py::make_tuple(qualname, o);
});
{
Expand Down
2 changes: 2 additions & 0 deletions optree/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""

# pylint: disable=invalid-name

__version__ = '0.12.1'
__license__ = 'Apache License, Version 2.0'
__author__ = 'OpTree Contributors'
Expand Down
28 changes: 15 additions & 13 deletions src/treespec/constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ template <bool NoneIsLeaf>
<< PyRepr(handle) << ".";
throw py::value_error(oss.str());
}
treespecs.emplace_back(py::cast<PyTreeSpec&>(child));
treespecs.emplace_back(thread_safe_cast<PyTreeSpec&>(child));
}

std::string common_registry_namespace{};
Expand Down Expand Up @@ -185,6 +185,7 @@ template <bool NoneIsLeaf>
}
verify_children(children, treespecs, registry_namespace);
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
const scoped_critical_section cs{handle};
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
std::move(keys));
} else [[likely]] {
Expand All @@ -206,9 +207,10 @@ template <bool NoneIsLeaf>
}

case PyTreeKind::Deque: {
const auto list = EVALUATE_WITH_LOCK_HELD(py::cast<py::list>(handle), handle);
const auto list = thread_safe_cast<py::list>(handle);
node.arity = GET_SIZE<py::list>(list);
node.node_data = py::getattr(handle, Py_Get_ID(maxlen));
node.node_data =
EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle);
for (ssize_t i = 0; i < node.arity; ++i) {
children.emplace_back(GET_ITEM_BORROW<py::list>(list, i));
}
Expand All @@ -217,10 +219,10 @@ template <bool NoneIsLeaf>
}

case PyTreeKind::Custom: {
const py::tuple out =
EVALUATE_WITH_LOCK_HELD2(py::cast<py::tuple>(node.custom->flatten_func(handle)),
handle,
node.custom->flatten_func);
const py::tuple out = EVALUATE_WITH_LOCK_HELD2(
thread_safe_cast<py::tuple>(node.custom->flatten_func(handle)),
handle,
node.custom->flatten_func);
const ssize_t num_out = GET_SIZE<py::tuple>(out);
if (num_out != 2 && num_out != 3) [[unlikely]] {
std::ostringstream oss{};
Expand All @@ -231,19 +233,19 @@ template <bool NoneIsLeaf>
node.arity = 0;
node.node_data = GET_ITEM_BORROW<py::tuple>(out, ssize_t(1));
{
auto children_iterator =
py::cast<py::iterable>(GET_ITEM_BORROW<py::tuple>(out, ssize_t(0)));
const scoped_critical_section cs{children_iterator};
for (const py::handle& child : children_iterator) {
auto children_iterable =
thread_safe_cast<py::iterable>(GET_ITEM_BORROW<py::tuple>(out, ssize_t(0)));
const scoped_critical_section cs{children_iterable};
for (const py::handle& child : children_iterable) {
++node.arity;
children.emplace_back(py::reinterpret_borrow<py::object>(child));
}
}
verify_children(children, treespecs, registry_namespace);
if (num_out == 3) [[likely]] {
py::object node_entries = GET_ITEM_BORROW<py::tuple>(out, ssize_t(2));
const py::object node_entries = GET_ITEM_BORROW<py::tuple>(out, ssize_t(2));
if (!node_entries.is_none()) [[likely]] {
node.node_entries = py::cast<py::tuple>(std::move(node_entries));
node.node_entries = thread_safe_cast<py::tuple>(node_entries);
const ssize_t num_entries = GET_SIZE<py::tuple>(node.node_entries);
if (num_entries != node.arity) [[unlikely]] {
std::ostringstream oss{};
Expand Down
Loading

0 comments on commit b0791d1

Please sign in to comment.