Skip to content

Commit

Permalink
fix: fix free-threading
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 23, 2024
1 parent 2e37af4 commit 821c02b
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ readability-*,
-readability-identifier-length,
'
CheckOptions:
misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*'
misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*;include/.*'
HeaderFilterRegex: '^include/.*$'
...
7 changes: 5 additions & 2 deletions include/mutex.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,19 @@ using recursive_mutex = std::recursive_mutex;
using scoped_lock_guard = std::lock_guard<mutex>;
using scoped_recursive_lock_guard = std::lock_guard<recursive_mutex>;

#if defined(Py_GIL_DISABLED) /* use mutex implementation from Python rather than STL */ || \
(defined(__APPLE__) /* header <shared_mutex> is not available on macOS build target */ && \
#if (defined(__APPLE__) /* header <shared_mutex> is not available on macOS build target */ && \
PY_VERSION_HEX < /* Python 3.12.0 */ 0x030C00F0)

#undef HAVE_READ_WRITE_LOCK

using read_write_mutex = mutex;
using scoped_read_lock_guard = scoped_lock_guard;
using scoped_write_lock_guard = scoped_lock_guard;

#else

#define HAVE_READ_WRITE_LOCK

#include <shared_mutex> // std::shared_mutex, std::shared_lock

using read_write_mutex = std::shared_mutex;
Expand Down
12 changes: 7 additions & 5 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,12 @@ inline void SET_ITEM<py::list>(const py::handle& container,
PyList_SET_ITEM(container.ptr(), item, value.inc_ref().ptr());
}

inline std::string PyStr(const py::handle& object) {
return EVALUATE_WITH_LOCK_HELD(static_cast<std::string>(py::str(object)), object);
}
inline std::string PyStr(const std::string& string) { return string; }
inline std::string PyRepr(const py::handle& object) {
return static_cast<std::string>(py::repr(object));
return EVALUATE_WITH_LOCK_HELD(static_cast<std::string>(py::repr(object)), object);
}
inline std::string PyRepr(const std::string& string) {
return static_cast<std::string>(py::repr(py::str(string)));
Expand Down Expand Up @@ -601,10 +605,8 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references]
// Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)`
const auto sort_key_fn = py::cpp_function([](const py::object& o) -> py::tuple {
const py::handle t = py::type::handle_of(o);
const py::str qualname{
static_cast<std::string>(py::str(py::getattr(t, Py_Get_ID(__module__)))) +
"." +
static_cast<std::string>(py::str(py::getattr(t, Py_Get_ID(__qualname__))))};
const py::str qualname{PyStr(py::getattr(t, Py_Get_ID(__module__))) + "." +
PyStr(py::getattr(t, Py_Get_ID(__qualname__)))};
return py::make_tuple(qualname, o);
});
{
Expand Down
1 change: 1 addition & 0 deletions src/treespec/constructor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ template <bool NoneIsLeaf>
}
verify_children(children, treespecs, registry_namespace);
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
const scoped_critical_section cs{handle};
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
std::move(keys));
} else [[likely]] {
Expand Down
38 changes: 25 additions & 13 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.

#include "include/critical_section.h"
#include "include/exceptions.h"
#include "include/mutex.h"
#include "include/registry.h"
#include "include/treespec.h"
#include "include/utils.h"
Expand Down Expand Up @@ -119,6 +120,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle,
}
}
if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] {
const scoped_critical_section cs{handle};
node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)),
std::move(keys));
} else [[likely]] {
Expand Down Expand Up @@ -244,10 +246,15 @@ bool PyTreeSpec::FlattenInto(const py::handle& handle,
auto leaves = reserved_vector<py::object>(4);
auto treespec = std::make_unique<PyTreeSpec>();
treespec->m_none_is_leaf = none_is_leaf;
if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace) ||
IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false))
[[unlikely]] {
treespec->m_namespace = registry_namespace;
{
#ifdef HAVE_READ_WRITE_LOCK
const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex};
#endif
if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace) ||
IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false))
[[unlikely]] {
treespec->m_namespace = registry_namespace;
}
}
treespec->m_traversal.shrink_to_fit();
return std::make_pair(std::move(leaves), std::move(treespec));
Expand Down Expand Up @@ -518,15 +525,20 @@ PyTreeSpec::FlattenWithPath(const py::object& tree,
auto paths = reserved_vector<py::tuple>(4);
auto treespec = std::make_unique<PyTreeSpec>();
treespec->m_none_is_leaf = none_is_leaf;
if (treespec->FlattenIntoWithPath(tree,
leaves,
paths,
leaf_predicate,
none_is_leaf,
registry_namespace) ||
IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false))
[[unlikely]] {
treespec->m_namespace = registry_namespace;
{
#ifdef HAVE_READ_WRITE_LOCK
const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex};
#endif
if (treespec->FlattenIntoWithPath(tree,
leaves,
paths,
leaf_predicate,
none_is_leaf,
registry_namespace) ||
IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false))
[[unlikely]] {
treespec->m_namespace = registry_namespace;
}
}
treespec->m_traversal.shrink_to_fit();
return std::make_tuple(std::move(paths), std::move(leaves), std::move(treespec));
Expand Down
5 changes: 1 addition & 4 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,11 @@ limitations under the License.

#include "include/critical_section.h"
#include "include/exceptions.h"
#include "include/mutex.h"
#include "include/registry.h"
#include "include/treespec.h"
#include "include/utils.h"

#ifdef Py_GIL_DISABLED
#include "include/mutex.h"
#endif

namespace optree {

template <bool NoneIsLeaf>
Expand Down
38 changes: 20 additions & 18 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,15 @@ namespace optree {
dict[GET_ITEM_HANDLE<py::list>(node.original_keys, i)] = py::none();
}
}
for (ssize_t i = 0; i < node.arity; ++i) {
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
dict[GET_ITEM_HANDLE<py::list>(keys, i)] = children[i];
{
const scoped_critical_section cs{keys};
for (ssize_t i = 0; i < node.arity; ++i) {
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
dict[GET_ITEM_HANDLE<py::list>(keys, i)] = children[i];
}
}
return PyDefaultDictTypeObject(default_factory, dict);
return EVALUATE_WITH_LOCK_HELD(PyDefaultDictTypeObject(default_factory, dict),
default_factory);
}

case PyTreeKind::Custom: {
Expand Down Expand Up @@ -1092,9 +1096,9 @@ bool PyTreeSpec::operator==(const PyTreeSpec& other) const {
}

// NOLINTNEXTLINE[readability-qualified-auto]
auto b = other.m_traversal.begin();
auto b = other.m_traversal.cbegin();
// NOLINTNEXTLINE[readability-qualified-auto]
for (auto a = m_traversal.begin(); a != m_traversal.end(); ++a, ++b) {
for (auto a = m_traversal.cbegin(); a != m_traversal.cend(); ++a, ++b) {
if (a->kind != b->kind || a->arity != b->arity ||
static_cast<bool>(a->node_data) != static_cast<bool>(b->node_data) ||
a->custom != b->custom) [[likely]] {
Expand Down Expand Up @@ -1194,16 +1198,15 @@ std::string PyTreeSpec::ToStringImpl() const {
EXPECT_EQ(GET_SIZE<py::tuple>(fields),
node.arity,
"Number of fields and entries does not match.");
const std::string kind =
static_cast<std::string>(py::str(py::getattr(type, Py_Get_ID(__name__))));
const std::string kind = PyStr(py::getattr(type, Py_Get_ID(__name__)));
sstream << kind << "(";
bool first = true;
auto child_iter = agenda.end() - node.arity;
for (const py::handle& field : fields) {
if (!first) [[likely]] {
sstream << ", ";
}
sstream << static_cast<std::string>(py::str(field)) << "=" << *child_iter;
sstream << PyStr(field) << "=" << *child_iter;
++child_iter;
first = false;
}
Expand Down Expand Up @@ -1241,7 +1244,7 @@ std::string PyTreeSpec::ToStringImpl() const {
case PyTreeKind::Deque: {
sstream << "deque([" << children << "]";
if (!node.node_data.is_none()) [[unlikely]] {
sstream << ", maxlen=" << static_cast<std::string>(py::str(node.node_data));
sstream << ", maxlen=" << PyRepr(node.node_data);
}
sstream << ")";
break;
Expand All @@ -1256,21 +1259,21 @@ std::string PyTreeSpec::ToStringImpl() const {
const py::object module_name =
py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__));
if (!module_name.is_none()) [[likely]] {
const std::string name = static_cast<std::string>(py::str(module_name));
const std::string name = PyStr(module_name);
if (!(name.empty() || name == "__main__" || name == "builtins" ||
name == "__builtins__")) [[likely]] {
sstream << name << ".";
}
}
const py::object qualname = py::getattr(type, Py_Get_ID(__qualname__));
sstream << static_cast<std::string>(py::str(qualname)) << "(";
sstream << PyStr(qualname) << "(";
bool first = true;
auto child_iter = agenda.end() - node.arity;
for (const py::handle& field : fields) {
if (!first) [[likely]] {
sstream << ", ";
}
sstream << static_cast<std::string>(py::str(field)) << "=" << *child_iter;
sstream << PyStr(field) << "=" << *child_iter;
++child_iter;
first = false;
}
Expand All @@ -1279,9 +1282,8 @@ std::string PyTreeSpec::ToStringImpl() const {
}

case PyTreeKind::Custom: {
const scoped_critical_section cs{node.node_data};
const std::string kind = static_cast<std::string>(
py::str(py::getattr(node.custom->type, Py_Get_ID(__name__))));
const scoped_critical_section2 cs{node.custom->type, node.node_data};
const std::string kind = PyStr(py::getattr(node.custom->type, Py_Get_ID(__name__)));
sstream << "CustomTreeNode(" << kind << "[";
if (node.node_data) [[likely]] {
sstream << PyRepr(node.node_data);
Expand Down Expand Up @@ -1369,7 +1371,7 @@ std::string PyTreeSpec::ToString() const {
"Number of auxiliary data mismatch.");
const py::object default_factory =
GET_ITEM_BORROW<py::tuple>(node.node_data, ssize_t(0));
data_hash = py::hash(default_factory);
data_hash = EVALUATE_WITH_LOCK_HELD(py::hash(default_factory), default_factory);
}
const auto keys = py::reinterpret_borrow<py::list>(
node.kind != PyTreeKind::DefaultDict
Expand All @@ -1379,7 +1381,7 @@ std::string PyTreeSpec::ToString() const {
node.arity,
"Number of keys and entries does not match.");
for (const py::handle& key : keys) {
HashCombine(data_hash, py::hash(key));
HashCombine(data_hash, EVALUATE_WITH_LOCK_HELD(py::hash(key), key));
}
break;
}
Expand Down

0 comments on commit 821c02b

Please sign in to comment.