diff --git a/.clang-tidy b/.clang-tidy index 0a12b0af..37ef02cd 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -22,6 +22,6 @@ readability-*, -readability-identifier-length, ' CheckOptions: - misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*' + misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*;include/.*' HeaderFilterRegex: '^include/.*$' ... diff --git a/include/mutex.h b/include/mutex.h index e3066790..d01db861 100644 --- a/include/mutex.h +++ b/include/mutex.h @@ -53,16 +53,19 @@ using recursive_mutex = std::recursive_mutex; using scoped_lock_guard = std::lock_guard; using scoped_recursive_lock_guard = std::lock_guard; -#if defined(Py_GIL_DISABLED) /* use mutex implementation from Python rather than STL */ || \ - (defined(__APPLE__) /* header is not available on macOS build target */ && \ +#if (defined(__APPLE__) /* header is not available on macOS build target */ && \ PY_VERSION_HEX < /* Python 3.12.0 */ 0x030C00F0) +#undef HAVE_READ_WRITE_LOCK + using read_write_mutex = mutex; using scoped_read_lock_guard = scoped_lock_guard; using scoped_write_lock_guard = scoped_lock_guard; #else +#define HAVE_READ_WRITE_LOCK + #include // std::shared_mutex, std::shared_lock using read_write_mutex = std::shared_mutex; diff --git a/include/utils.h b/include/utils.h index aea68ce5..ef2ac43d 100644 --- a/include/utils.h +++ b/include/utils.h @@ -264,8 +264,12 @@ inline void SET_ITEM(const py::handle& container, PyList_SET_ITEM(container.ptr(), item, value.inc_ref().ptr()); } +inline std::string PyStr(const py::handle& object) { + return EVALUATE_WITH_LOCK_HELD(static_cast(py::str(object)), object); +} +inline std::string PyStr(const std::string& string) { return string; } inline std::string PyRepr(const py::handle& object) { - return static_cast(py::repr(object)); + return EVALUATE_WITH_LOCK_HELD(static_cast(py::repr(object)), object); } inline std::string PyRepr(const std::string& string) { return static_cast(py::repr(py::str(string))); @@ -601,10 +605,8 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references] // Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)` const auto sort_key_fn = py::cpp_function([](const py::object& o) -> py::tuple { const py::handle t = py::type::handle_of(o); - const py::str qualname{ - static_cast(py::str(py::getattr(t, Py_Get_ID(__module__)))) + - "." + - static_cast(py::str(py::getattr(t, Py_Get_ID(__qualname__))))}; + const py::str qualname{PyStr(py::getattr(t, Py_Get_ID(__module__))) + "." + + PyStr(py::getattr(t, Py_Get_ID(__qualname__)))}; return py::make_tuple(qualname, o); }); { diff --git a/src/treespec/constructor.cpp b/src/treespec/constructor.cpp index 2c242e45..4e5687c0 100644 --- a/src/treespec/constructor.cpp +++ b/src/treespec/constructor.cpp @@ -185,6 +185,7 @@ template } verify_children(children, treespecs, registry_namespace); if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + const scoped_critical_section cs{handle}; node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), std::move(keys)); } else [[likely]] { diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index f7e58ab6..1db9d805 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -26,6 +26,7 @@ limitations under the License. #include "include/critical_section.h" #include "include/exceptions.h" +#include "include/mutex.h" #include "include/registry.h" #include "include/treespec.h" #include "include/utils.h" @@ -119,6 +120,7 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, } } if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + const scoped_critical_section cs{handle}; node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), std::move(keys)); } else [[likely]] { @@ -244,10 +246,15 @@ bool PyTreeSpec::FlattenInto(const py::handle& handle, auto leaves = reserved_vector(4); auto treespec = std::make_unique(); treespec->m_none_is_leaf = none_is_leaf; - if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace) || - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false)) - [[unlikely]] { - treespec->m_namespace = registry_namespace; + { +#ifdef HAVE_READ_WRITE_LOCK + const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex}; +#endif + if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace) || + IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false)) + [[unlikely]] { + treespec->m_namespace = registry_namespace; + } } treespec->m_traversal.shrink_to_fit(); return std::make_pair(std::move(leaves), std::move(treespec)); @@ -518,15 +525,20 @@ PyTreeSpec::FlattenWithPath(const py::object& tree, auto paths = reserved_vector(4); auto treespec = std::make_unique(); treespec->m_none_is_leaf = none_is_leaf; - if (treespec->FlattenIntoWithPath(tree, - leaves, - paths, - leaf_predicate, - none_is_leaf, - registry_namespace) || - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false)) - [[unlikely]] { - treespec->m_namespace = registry_namespace; + { +#ifdef HAVE_READ_WRITE_LOCK + const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex}; +#endif + if (treespec->FlattenIntoWithPath(tree, + leaves, + paths, + leaf_predicate, + none_is_leaf, + registry_namespace) || + IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false)) + [[unlikely]] { + treespec->m_namespace = registry_namespace; + } } treespec->m_traversal.shrink_to_fit(); return std::make_tuple(std::move(paths), std::move(leaves), std::move(treespec)); diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 3bb8615f..d74e7cfe 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -23,14 +23,11 @@ limitations under the License. #include "include/critical_section.h" #include "include/exceptions.h" +#include "include/mutex.h" #include "include/registry.h" #include "include/treespec.h" #include "include/utils.h" -#ifdef Py_GIL_DISABLED -#include "include/mutex.h" -#endif - namespace optree { template diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index c93746a1..ec66e0b9 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -155,11 +155,15 @@ namespace optree { dict[GET_ITEM_HANDLE(node.original_keys, i)] = py::none(); } } - for (ssize_t i = 0; i < node.arity; ++i) { - // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - dict[GET_ITEM_HANDLE(keys, i)] = children[i]; + { + const scoped_critical_section cs{keys}; + for (ssize_t i = 0; i < node.arity; ++i) { + // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] + dict[GET_ITEM_HANDLE(keys, i)] = children[i]; + } } - return PyDefaultDictTypeObject(default_factory, dict); + return EVALUATE_WITH_LOCK_HELD(PyDefaultDictTypeObject(default_factory, dict), + default_factory); } case PyTreeKind::Custom: { @@ -1092,9 +1096,9 @@ bool PyTreeSpec::operator==(const PyTreeSpec& other) const { } // NOLINTNEXTLINE[readability-qualified-auto] - auto b = other.m_traversal.begin(); + auto b = other.m_traversal.cbegin(); // NOLINTNEXTLINE[readability-qualified-auto] - for (auto a = m_traversal.begin(); a != m_traversal.end(); ++a, ++b) { + for (auto a = m_traversal.cbegin(); a != m_traversal.cend(); ++a, ++b) { if (a->kind != b->kind || a->arity != b->arity || static_cast(a->node_data) != static_cast(b->node_data) || a->custom != b->custom) [[likely]] { @@ -1194,8 +1198,7 @@ std::string PyTreeSpec::ToStringImpl() const { EXPECT_EQ(GET_SIZE(fields), node.arity, "Number of fields and entries does not match."); - const std::string kind = - static_cast(py::str(py::getattr(type, Py_Get_ID(__name__)))); + const std::string kind = PyStr(py::getattr(type, Py_Get_ID(__name__))); sstream << kind << "("; bool first = true; auto child_iter = agenda.end() - node.arity; @@ -1203,7 +1206,7 @@ std::string PyTreeSpec::ToStringImpl() const { if (!first) [[likely]] { sstream << ", "; } - sstream << static_cast(py::str(field)) << "=" << *child_iter; + sstream << PyStr(field) << "=" << *child_iter; ++child_iter; first = false; } @@ -1241,7 +1244,7 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::Deque: { sstream << "deque([" << children << "]"; if (!node.node_data.is_none()) [[unlikely]] { - sstream << ", maxlen=" << static_cast(py::str(node.node_data)); + sstream << ", maxlen=" << PyRepr(node.node_data); } sstream << ")"; break; @@ -1256,21 +1259,21 @@ std::string PyTreeSpec::ToStringImpl() const { const py::object module_name = py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)); if (!module_name.is_none()) [[likely]] { - const std::string name = static_cast(py::str(module_name)); + const std::string name = PyStr(module_name); if (!(name.empty() || name == "__main__" || name == "builtins" || name == "__builtins__")) [[likely]] { sstream << name << "."; } } const py::object qualname = py::getattr(type, Py_Get_ID(__qualname__)); - sstream << static_cast(py::str(qualname)) << "("; + sstream << PyStr(qualname) << "("; bool first = true; auto child_iter = agenda.end() - node.arity; for (const py::handle& field : fields) { if (!first) [[likely]] { sstream << ", "; } - sstream << static_cast(py::str(field)) << "=" << *child_iter; + sstream << PyStr(field) << "=" << *child_iter; ++child_iter; first = false; } @@ -1279,9 +1282,8 @@ std::string PyTreeSpec::ToStringImpl() const { } case PyTreeKind::Custom: { - const scoped_critical_section cs{node.node_data}; - const std::string kind = static_cast( - py::str(py::getattr(node.custom->type, Py_Get_ID(__name__)))); + const scoped_critical_section2 cs{node.custom->type, node.node_data}; + const std::string kind = PyStr(py::getattr(node.custom->type, Py_Get_ID(__name__))); sstream << "CustomTreeNode(" << kind << "["; if (node.node_data) [[likely]] { sstream << PyRepr(node.node_data); @@ -1369,7 +1371,7 @@ std::string PyTreeSpec::ToString() const { "Number of auxiliary data mismatch."); const py::object default_factory = GET_ITEM_BORROW(node.node_data, ssize_t(0)); - data_hash = py::hash(default_factory); + data_hash = EVALUATE_WITH_LOCK_HELD(py::hash(default_factory), default_factory); } const auto keys = py::reinterpret_borrow( node.kind != PyTreeKind::DefaultDict @@ -1379,7 +1381,7 @@ std::string PyTreeSpec::ToString() const { node.arity, "Number of keys and entries does not match."); for (const py::handle& key : keys) { - HashCombine(data_hash, py::hash(key)); + HashCombine(data_hash, EVALUATE_WITH_LOCK_HELD(py::hash(key), key)); } break; }