Skip to content

Commit

Permalink
refactor: use new object to mutex holder
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 22, 2024
1 parent 8a5663f commit d31a92e
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 22 deletions.
14 changes: 9 additions & 5 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,18 @@ class PyTreeIter {
const bool &none_is_leaf,
const std::string &registry_namespace)
: m_root{tree},
m_agenda{reserved_vector<std::pair<py::object, ssize_t>>(4)},
m_agenda{{{tree, 0}}},
m_leaf_predicate{leaf_predicate},
m_none_is_leaf{none_is_leaf},
m_namespace{registry_namespace},
m_is_dict_insertion_ordered{PyTreeSpec::IsDictInsertionOrdered(registry_namespace)},
m_mutex{} {
m_agenda.emplace_back(tree, 0);
};
#ifdef Py_GIL_DISABLED
// NOLINTNEXTLINE[whitespace/braces]
m_mutex_holder_object{py::handle{reinterpret_cast<PyObject *>(&PyBaseObject_Type)}()}
#else
m_mutex_holder_object{} // NOLINT[whitespace/braces]
#endif
{};

PyTreeIter() = delete;
~PyTreeIter() = default;
Expand All @@ -419,7 +423,7 @@ class PyTreeIter {
const bool m_none_is_leaf;
const std::string m_namespace;
const bool m_is_dict_insertion_ordered;
mutex m_mutex;
const py::object m_mutex_holder_object;

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
Expand Down
1 change: 1 addition & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ Py_Declare_ID(__qualname__); // type.__qualname__
Py_Declare_ID(__name__); // type.__name__
Py_Declare_ID(sort); // list.sort
Py_Declare_ID(copy); // dict.copy
Py_Declare_ID(fromkeys); // dict.fromkeys
Py_Declare_ID(default_factory); // defaultdict.default_factory
Py_Declare_ID(maxlen); // deque.maxlen
Py_Declare_ID(_fields); // namedtuple._fields
Expand Down
5 changes: 1 addition & 4 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ limitations under the License.

#include "include/critical_section.h"
#include "include/exceptions.h"
#include "include/mutex.h"
#include "include/registry.h"
#include "include/treespec.h"
#include "include/utils.h"
Expand Down Expand Up @@ -167,9 +166,7 @@ py::object PyTreeIter::NextImpl() {
}

py::object PyTreeIter::Next() {
#ifdef Py_GIL_DISABLED
const scoped_lock_guard lock{m_mutex};
#endif
const scoped_critical_section cs{m_mutex_holder_object};

if (m_none_is_leaf) [[unlikely]] {
return NextImpl<NONE_IS_LEAF>();
Expand Down
23 changes: 10 additions & 13 deletions src/treespec/treespec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@ namespace optree {
const scoped_critical_section2 cs{node.node_data, node.original_keys};
const auto keys = py::reinterpret_borrow<py::list>(node.node_data);
if (node.original_keys) [[unlikely]] {
for (ssize_t i = 0; i < node.arity; ++i) {
dict[GET_ITEM_HANDLE<py::list>(node.original_keys, i)] = py::none();
}
dict = py::getattr(py::handle{reinterpret_cast<PyObject*>(&PyDict_Type)},
Py_Get_ID(fromkeys))(node.original_keys);
}
for (ssize_t i = 0; i < node.arity; ++i) {
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
Expand All @@ -145,21 +144,20 @@ namespace optree {
}

case PyTreeKind::DefaultDict: {
const py::dict dict{};
py::dict dict{};
const scoped_critical_section2 cs{node.node_data, node.original_keys};
const py::object default_factory =
GET_ITEM_BORROW<py::tuple>(node.node_data, ssize_t(0));
const py::list keys = GET_ITEM_BORROW<py::tuple>(node.node_data, ssize_t(1));
if (node.original_keys) [[unlikely]] {
for (ssize_t i = 0; i < node.arity; ++i) {
dict[GET_ITEM_HANDLE<py::list>(node.original_keys, i)] = py::none();
}
dict = py::getattr(py::handle{reinterpret_cast<PyObject*>(&PyDict_Type)},
Py_Get_ID(fromkeys))(node.original_keys);
}
for (ssize_t i = 0; i < node.arity; ++i) {
// NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic]
dict[GET_ITEM_HANDLE<py::list>(keys, i)] = children[i];
}
return PyDefaultDictTypeObject(default_factory, dict);
return PyDefaultDictTypeObject(default_factory, std::move(dict));
}

case PyTreeKind::Custom: {
Expand Down Expand Up @@ -1568,7 +1566,7 @@ int PyTreeSpecTpTraverse(PyObject* self_base, visitproc visit, void* arg) {
// The holder is not constructed yet. Skip the traversal to avoid segfault.
return 0;
}
auto& self = py::cast<PyTreeSpec&>(py::handle(self_base));
auto& self = py::cast<PyTreeSpec&>(py::handle{self_base});
for (const auto& node : self.m_traversal) {
Py_VISIT(node.node_data.ptr());
Py_VISIT(node.node_entries.ptr());
Expand All @@ -1577,6 +1575,7 @@ int PyTreeSpecTpTraverse(PyObject* self_base, visitproc visit, void* arg) {
return 0;
}

// NOLINTNEXTLINE[readability-function-cognitive-complexity]
int PyTreeIterTpTraverse(PyObject* self_base, visitproc visit, void* arg) {
#if PY_VERSION_HEX >= 0x03090000 // Python 3.9
Py_VISIT(Py_TYPE(self_base));
Expand All @@ -1586,16 +1585,14 @@ int PyTreeIterTpTraverse(PyObject* self_base, visitproc visit, void* arg) {
// The holder is not constructed yet. Skip the traversal to avoid segfault.
return 0;
}
auto& self = py::cast<PyTreeIter&>(py::handle(self_base));
auto& self = py::cast<PyTreeIter&>(py::handle{self_base});
{
#ifdef Py_GIL_DISABLED
const scoped_lock_guard lock{self.m_mutex};
#endif
for (const auto& pair : self.m_agenda) {
Py_VISIT(pair.first.ptr());
}
}
Py_VISIT(self.m_root.ptr());
Py_VISIT(self.m_mutex_holder_object.ptr());
return 0;
}

Expand Down

0 comments on commit d31a92e

Please sign in to comment.