diff --git a/.clang-format b/.clang-format index 1e2d55ee..2ea83640 100644 --- a/.clang-format +++ b/.clang-format @@ -15,13 +15,13 @@ BreakBeforeTernaryOperators: true FixNamespaceComments: true IncludeBlocks: Regroup IncludeCategories: - - Regex: '^(<|")Python\.h("|>)$' + - Regex: '^[<"]Python\.h[">]$' Priority: 2 CaseSensitive: true - - Regex: '^(<|")(pybind11/.*)("|>)$' + - Regex: '^[<"]pybind11/.*[">]$' Priority: 3 CaseSensitive: true - - Regex: '^<[[:alnum:]_]+(\.h)?>$' + - Regex: '^<[[:alnum:]_/]+(\.h)?>$' Priority: 1 - Regex: '^"include/' Priority: 4 diff --git a/.clang-tidy b/.clang-tidy index 0a12b0af..37ef02cd 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -22,6 +22,6 @@ readability-*, -readability-identifier-length, ' CheckOptions: - misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*' + misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*;include/.*' HeaderFilterRegex: '^include/.*$' ... diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 0f4c51a6..2c976ccb 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -147,7 +147,7 @@ jobs: - os: macos-latest python-version: "3.7" fail-fast: false - timeout-minutes: 120 + timeout-minutes: 180 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.github/workflows/set_cibw_build.py b/.github/workflows/set_cibw_build.py index ae0c35a6..5b08e4e7 100755 --- a/.github/workflows/set_cibw_build.py +++ b/.github/workflows/set_cibw_build.py @@ -7,7 +7,8 @@ MAJOR, MINOR, *_ = platform.python_version_tuple() -CIBW_BUILD = f'CIBW_BUILD=*{platform.python_implementation().lower()[0]}p{MAJOR}{MINOR}-*' +IMPLEMENTATION = platform.python_implementation() +CIBW_BUILD = f'CIBW_BUILD=*{IMPLEMENTATION.lower()[0]}p{MAJOR}{MINOR}{{,?}}-*' print(CIBW_BUILD) with open(os.getenv('GITHUB_ENV'), mode='a', encoding='utf-8') as file: diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index 39e81072..54007f41 100644 --- 
a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -37,9 +37,25 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] - python-abiflags: ["d"] + python-abiflags: ["d", "td"] + exclude: + - python-version: "3.7" + python-abiflags: "td" + - python-version: "3.8" + python-abiflags: "td" + - python-version: "3.9" + python-abiflags: "td" + - python-version: "3.10" + python-abiflags: "td" + - python-version: "3.11" + python-abiflags: "td" + - python-version: "3.12" + python-abiflags: "td" + - os: windows-latest # pyenv-win does not support Python 3.13t yet + python-version: "3.13" + python-abiflags: "td" fail-fast: false - timeout-minutes: 60 + timeout-minutes: 90 steps: - name: Checkout uses: actions/checkout@v4 @@ -65,7 +81,12 @@ jobs: shell: bash run: | pyenv install --list - if ! PYTHON_VERSION="$(pyenv latest --known "${{ matrix.python-version }}")"; then + if [[ "${{ matrix.python-abiflags }}" == *t* ]]; then + PYTHON_VERSION="$( + pyenv install --list | tr -d ' ' | grep -E "^${{ matrix.python-version }}" | + grep -vF '-' | grep -E '[0-9]t$' | sort -rV | head -n 1 + )" + elif ! 
PYTHON_VERSION="$(pyenv latest --known "${{ matrix.python-version }}")"; then PYTHON_VERSION="$( pyenv install --list | tr -d ' ' | grep -E "^${{ matrix.python-version }}" | grep -vF '-' | grep -E '[0-9]$' | sort -rV | head -n 1 @@ -109,4 +130,4 @@ jobs: - name: Test with pytest run: | - make test PYTESTOPTS="--exitfirst --verbosity=0 --durations=10" + make test PYTESTOPTS="--exitfirst" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3eb725d6..5acb82bb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -46,7 +46,7 @@ jobs: - os: macos-latest python-version: "3.7" # Python 3.7 does not support macOS ARM64 fail-fast: false - timeout-minutes: 60 + timeout-minutes: 90 steps: - name: Checkout uses: actions/checkout@v4 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16b22bf2..59ccec50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.7 + rev: v0.6.8 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/CHANGELOG.md b/CHANGELOG.md index 92e00d91..cb2f5909 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Add Python 3.13t support by [@XuehaiPan](https://github.com/XuehaiPan) in [#137](https://github.com/metaopt/optree/pull/137). - Expose Python implementation for C utilities for `namedtuple` and `PyStructSequence` by [@XuehaiPan](https://github.com/XuehaiPan) in [#157](https://github.com/metaopt/optree/pull/157). - Add `dataclasses` integration by [@XuehaiPan](https://github.com/XuehaiPan) in [#142](https://github.com/metaopt/optree/pull/142). - Add Python 3.13 support by [@XuehaiPan](https://github.com/XuehaiPan) in [#156](https://github.com/metaopt/optree/pull/156). 
diff --git a/Makefile b/Makefile index b86ab080..0ee2977f 100644 --- a/Makefile +++ b/Makefile @@ -140,7 +140,7 @@ pytest test: pytest-install $(PYTHON) -m pytest --version cd tests && $(PYTHON) -X dev -W 'always' -W 'error' -c 'import $(PROJECT_PATH)' && \ $(PYTHON) -X dev -W 'always' -W 'error' -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \ - $(PYTHON) -X dev -m pytest --verbose --color=yes --durations=0 --showlocals \ + $(PYTHON) -X dev -m pytest --verbose --color=yes --durations=10 --showlocals \ --cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \ $(PYTESTOPTS) . diff --git a/include/critical_section.h b/include/critical_section.h new file mode 100644 index 00000000..76a2edf4 --- /dev/null +++ b/include/critical_section.h @@ -0,0 +1,136 @@ +/* +Copyright 2022-2024 MetaOPT Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+================================================================================
+*/
+
+#pragma once
+
+#include <Python.h>
+
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+#ifndef Py_Is
+#define Py_Is(x, y) ((x) == (y))
+#endif
+#ifndef Py_IsNone
+#define Py_IsNone(x) Py_Is((x), Py_None)
+#endif
+#ifndef Py_IsTrue
+#define Py_IsTrue(x) Py_Is((x), Py_True)
+#endif
+#ifndef Py_IsFalse
+#define Py_IsFalse(x) Py_Is((x), Py_False)
+#endif
+
+inline bool Py_IsConstant(PyObject* x) { return Py_IsNone(x) || Py_IsTrue(x) || Py_IsFalse(x); }
+#define Py_IsConstant(x) Py_IsConstant(x)
+
+class scoped_critical_section {
+ public:
+    scoped_critical_section() = delete;
+
+#ifdef Py_GIL_DISABLED
+    explicit scoped_critical_section(const py::handle& handle) : m_ptr{handle.ptr()} {
+        if (m_ptr != nullptr && !Py_IsConstant(m_ptr)) [[likely]] {
+            PyCriticalSection_Begin(&m_critical_section, m_ptr);
+        }
+    }
+
+    ~scoped_critical_section() {
+        if (m_ptr != nullptr && !Py_IsConstant(m_ptr)) [[likely]] {
+            PyCriticalSection_End(&m_critical_section);
+        }
+    }
+#else
+    explicit scoped_critical_section(const py::handle& /*unused*/) {}
+    ~scoped_critical_section() = default;
+#endif
+
+    scoped_critical_section(const scoped_critical_section&) = delete;
+    scoped_critical_section& operator=(const scoped_critical_section&) = delete;
+    scoped_critical_section(scoped_critical_section&&) = delete;
+    scoped_critical_section& operator=(scoped_critical_section&&) = delete;
+
+ private:
+#ifdef Py_GIL_DISABLED
+    PyObject* m_ptr{nullptr};
+    PyCriticalSection m_critical_section{};
+#endif
+};
+
+class scoped_critical_section2 {
+ public:
+    scoped_critical_section2() = delete;
+
+#ifdef Py_GIL_DISABLED
+    explicit scoped_critical_section2(const py::handle& handle1, const py::handle& handle2)
+        : m_ptr1{handle1.ptr()}, m_ptr2{handle2.ptr()} {
+        if (m_ptr1 != nullptr && !Py_IsConstant(m_ptr1)) [[likely]] {
+            if (m_ptr2 != nullptr && !Py_IsConstant(m_ptr2)) [[likely]] {
+                PyCriticalSection2_Begin(&m_critical_section2, m_ptr1,
m_ptr2); + } else [[unlikely]] { + PyCriticalSection_Begin(&m_critical_section, m_ptr1); + } + } else if (m_ptr2 != nullptr && !Py_IsConstant(m_ptr2)) [[likely]] { + PyCriticalSection_Begin(&m_critical_section, m_ptr2); + } + } + + ~scoped_critical_section2() { + if (m_ptr1 != nullptr && !Py_IsConstant(m_ptr1)) [[likely]] { + if (m_ptr2 != nullptr && !Py_IsConstant(m_ptr2)) [[likely]] { + PyCriticalSection2_End(&m_critical_section2); + } else [[unlikely]] { + PyCriticalSection_End(&m_critical_section); + } + } else if (m_ptr2 != nullptr && !Py_IsConstant(m_ptr2)) [[likely]] { + PyCriticalSection_End(&m_critical_section); + } + } +#else + explicit scoped_critical_section2(const py::handle& /*unused*/, const py::handle& /*unused*/) {} + ~scoped_critical_section2() = default; +#endif + + scoped_critical_section2(const scoped_critical_section2&) = delete; + scoped_critical_section2& operator=(const scoped_critical_section2&) = delete; + scoped_critical_section2(scoped_critical_section2&&) = delete; + scoped_critical_section2& operator=(scoped_critical_section2&&) = delete; + + private: +#ifdef Py_GIL_DISABLED + PyObject* m_ptr1{nullptr}; + PyObject* m_ptr2{nullptr}; + PyCriticalSection m_critical_section{}; + PyCriticalSection2 m_critical_section2{}; +#endif +}; + +#ifdef Py_GIL_DISABLED + +#define EVALUATE_WITH_LOCK_HELD(expression, handle) \ + (((void)scoped_critical_section{(handle)}), (expression)) + +#define EVALUATE_WITH_LOCK_HELD2(expression, handle1, handle2) \ + (((void)scoped_critical_section2{(handle1), (handle2)}), (expression)) + +#else + +#define EVALUATE_WITH_LOCK_HELD(expression, handle) (expression) +#define EVALUATE_WITH_LOCK_HELD2(expression, handle1, handle2) (expression) + +#endif diff --git a/include/mutex.h b/include/mutex.h new file mode 100644 index 00000000..d01db861 --- /dev/null +++ b/include/mutex.h @@ -0,0 +1,75 @@ +/* +Copyright 2022-2024 MetaOPT Team. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+================================================================================
+*/
+
+#pragma once
+
+#include <mutex>  // std::mutex, std::recursive_mutex, std::lock_guard, std::unique_lock
+
+#include <Python.h>
+
+#ifdef Py_GIL_DISABLED
+
+class pymutex {
+ public:
+    pymutex() = default;
+    ~pymutex() = default;
+
+    pymutex(const pymutex &) = delete;
+    pymutex &operator=(const pymutex &) = delete;
+    pymutex(pymutex &&) = delete;
+    pymutex &operator=(pymutex &&) = delete;
+
+    void lock() { PyMutex_Lock(&mutex); }
+    void unlock() { PyMutex_Unlock(&mutex); }
+
+ private:
+    PyMutex mutex{0};
+};
+
+using mutex = pymutex;
+using recursive_mutex = std::recursive_mutex;
+
+#else
+
+using mutex = std::mutex;
+using recursive_mutex = std::recursive_mutex;
+
+#endif
+
+using scoped_lock_guard = std::lock_guard<mutex>;
+using scoped_recursive_lock_guard = std::lock_guard<recursive_mutex>;
+
+#if (defined(__APPLE__) /* <shared_mutex> header is not available on macOS build target */ && \
+     PY_VERSION_HEX < /* Python 3.12.0 */ 0x030C00F0)
+
+#undef HAVE_READ_WRITE_LOCK
+
+using read_write_mutex = mutex;
+using scoped_read_lock_guard = scoped_lock_guard;
+using scoped_write_lock_guard = scoped_lock_guard;
+
+#else
+
+#define HAVE_READ_WRITE_LOCK
+
+#include <shared_mutex>  // std::shared_mutex, std::shared_lock
+
+using read_write_mutex = std::shared_mutex;
+using scoped_read_lock_guard = std::shared_lock<read_write_mutex>;
+using scoped_write_lock_guard = std::unique_lock<read_write_mutex>;
+
+#endif
diff --git a/include/mutex.h b/include/mutex.h index
e7150115..7ec36625 100644 --- a/include/registry.h +++ b/include/registry.h @@ -26,6 +26,7 @@ limitations under the License. #include +#include "include/mutex.h" #include "include/utils.h" namespace optree { @@ -137,6 +138,8 @@ class PyTreeTypeRegistry { NamedTypeHash, NamedTypeEq> m_named_registrations{}; + + static inline read_write_mutex sm_mutex{}; }; } // namespace optree diff --git a/include/treespec.h b/include/treespec.h index bd5c7c4c..ee0fbfff 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -29,6 +29,7 @@ limitations under the License. #include +#include "include/mutex.h" #include "include/registry.h" #include "include/utils.h" @@ -208,6 +209,8 @@ class PyTreeSpec { // Check if should preserve the insertion order of the dictionary keys during flattening. static inline bool IsDictInsertionOrdered(const std::string ®istry_namespace, const bool &inherit_global_namespace = true) { + const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex}; + return (sm_is_dict_insertion_ordered.find(registry_namespace) != sm_is_dict_insertion_ordered.end()) || (inherit_global_namespace && @@ -217,6 +220,8 @@ class PyTreeSpec { // Set the namespace to preserve the insertion order of the dictionary keys during flattening. static inline void SetDictInsertionOrdered(const bool &mode, const std::string ®istry_namespace) { + const scoped_write_lock_guard lock{sm_is_dict_insertion_ordered_mutex}; + if (mode) [[likely]] { sm_is_dict_insertion_ordered.insert(registry_namespace); } else [[unlikely]] { @@ -236,13 +241,13 @@ class PyTreeSpec { // Arity for non-Leaf types. ssize_t arity = 0; - // Kind-specific auxiliary data. + // Kind-specific metadata. // For a NamedTuple/PyStructSequence, contains the tuple type object. // For a Dict, contains a sorted list of keys. // For a OrderedDict, contains a list of keys. // For a DefaultDict, contains a tuple of (default_factory, sorted list of keys). // For a Deque, contains the `maxlen` attribute. 
- // For a Custom type, contains the auxiliary data returned by the `flatten_func` function. + // For a Custom type, contains the metadata returned by the `flatten_func` function. py::object node_data{}; // The tuple of path entries. @@ -361,16 +366,7 @@ class PyTreeSpec { // A set of namespaces that preserve the insertion order of the dictionary keys during // flattening. static inline std::unordered_set sm_is_dict_insertion_ordered{}; - - // A set of (treespec, thread_id) pairs that are currently being represented as strings. - static inline std::unordered_set, - ThreadIndentTypeHash> - sm_repr_running{}; - - // A set of (treespec, thread_id) pairs that are currently being hashed. - static inline std::unordered_set, - ThreadIndentTypeHash> - sm_hash_running{}; + static inline read_write_mutex sm_is_dict_insertion_ordered_mutex{}; }; class PyTreeIter { @@ -379,7 +375,8 @@ class PyTreeIter { const std::optional &leaf_predicate, const bool &none_is_leaf, const std::string ®istry_namespace) - : m_agenda{{{tree, 0}}}, + : m_root{tree}, + m_agenda{{{tree, 0}}}, m_leaf_predicate{leaf_predicate}, m_none_is_leaf{none_is_leaf}, m_namespace{registry_namespace}, @@ -401,11 +398,15 @@ class PyTreeIter { static int PyTpTraverse(PyObject *self_base, visitproc visit, void *arg); private: + const py::object m_root; std::vector> m_agenda; const std::optional m_leaf_predicate; const bool m_none_is_leaf; const std::string m_namespace; const bool m_is_dict_insertion_ordered; +#ifdef Py_GIL_DISABLED + mutable mutex m_mutex{}; +#endif template [[nodiscard]] py::object NextImpl(); diff --git a/include/utils.h b/include/utils.h index dd7443d7..4a0ed03d 100644 --- a/include/utils.h +++ b/include/utils.h @@ -21,6 +21,7 @@ limitations under the License. 
#include // std::hash #include // std::ostringstream #include // std::string +#include // std::enable_if_t, std::is_base_of_v #include // std::unordered_map #include // std::move, std::pair, std::make_pair #include // std::vector @@ -34,6 +35,9 @@ limitations under the License. #include // pybind11::exec #include +#include "include/critical_section.h" +#include "include/mutex.h" + namespace py = pybind11; // The maximum size of the type cache. @@ -181,116 +185,99 @@ inline std::vector reserved_vector(const py::size_t& size) { return v; } -template -inline py::ssize_t GetSize(const py::handle& sized) { - return py::ssize_t_cast(py::len(sized)); -} -template <> -inline py::ssize_t GetSize(const py::handle& sized) { - return PyTuple_Size(sized.ptr()); -} -template <> -inline py::ssize_t GetSize(const py::handle& sized) { - return PyList_Size(sized.ptr()); -} -template <> -inline py::ssize_t GetSize(const py::handle& sized) { - return PyDict_Size(sized.ptr()); +template +inline T thread_safe_cast(const py::handle& handle) { + return EVALUATE_WITH_LOCK_HELD(py::cast(handle), handle); } -template -inline py::ssize_t GET_SIZE(const py::handle& sized) { - return py::ssize_t_cast(py::len(sized)); -} -template <> -inline py::ssize_t GET_SIZE(const py::handle& sized) { - return PyTuple_GET_SIZE(sized.ptr()); -} -template <> -inline py::ssize_t GET_SIZE(const py::handle& sized) { - return PyList_GET_SIZE(sized.ptr()); -} -#ifndef PyDict_GET_SIZE -#define PyDict_GET_SIZE PyDict_Size +inline py::ssize_t TupleGetSize(const py::handle& tuple) { return PyTuple_GET_SIZE(tuple.ptr()); } +inline py::ssize_t ListGetSize(const py::handle& list) { return PyList_GET_SIZE(list.ptr()); } +inline py::ssize_t DictGetSize(const py::handle& dict) { +#ifdef PyDict_GET_SIZE + return PyDict_GET_SIZE(dict.ptr()); +#else + return PyDict_Size(dict.ptr()); #endif -template <> -inline py::ssize_t GET_SIZE(const py::handle& sized) { - return PyDict_GET_SIZE(sized.ptr()); } -template -inline py::handle 
GET_ITEM_HANDLE(const py::handle& container, const Item& item) { - return container[item]; +template >> +inline T TupleGetItemAs(const py::handle& tuple, const py::ssize_t& index) { + return py::reinterpret_borrow(PyTuple_GET_ITEM(tuple.ptr(), index)); } -template <> -inline py::handle GET_ITEM_HANDLE(const py::handle& container, const py::ssize_t& item) { - return PyTuple_GET_ITEM(container.ptr(), item); +inline py::object TupleGetItem(const py::handle& tuple, const py::ssize_t& index) { + return TupleGetItemAs(tuple, index); } -template <> -inline py::handle GET_ITEM_HANDLE(const py::handle& container, const py::ssize_t& item) { - return PyList_GET_ITEM(container.ptr(), item); +template >> +inline T ListGetItemAs(const py::handle& list, const py::ssize_t& index) { +#if PY_VERSION_HEX >= 0x030D00A4 // Python 3.13.0a4 + PyObject* item = PyList_GetItemRef(list.ptr(), index); + if (item == nullptr) [[unlikely]] { + throw py::error_already_set(); + } + return py::reinterpret_steal(item); +#else + return py::reinterpret_borrow(PyList_GET_ITEM(list.ptr(), index)); +#endif } - -template -inline py::object GET_ITEM_BORROW(const py::handle& container, const Item& item) { - return py::reinterpret_borrow(container[item]); +inline py::object ListGetItem(const py::handle& list, const py::ssize_t& index) { + return ListGetItemAs(list, index); } -template <> -inline py::object GET_ITEM_BORROW(const py::handle& container, const py::ssize_t& item) { - return py::reinterpret_borrow(PyTuple_GET_ITEM(container.ptr(), item)); +template >> +inline T DictGetItemAs(const py::handle& dict, const py::handle& key) { +#if PY_VERSION_HEX >= 0x030D00A1 // Python 3.13.0a1 + PyObject* value = nullptr; + if (PyDict_GetItemRef(dict.ptr(), key.ptr(), &value) < 0) [[unlikely]] { + throw py::error_already_set(); + } + if (value == nullptr) [[unlikely]] { + py::set_error(PyExc_KeyError, py::make_tuple(key)); + throw py::error_already_set(); + } + return py::reinterpret_steal(value); +#else + return 
py::reinterpret_borrow(PyDict_GetItem(dict.ptr(), key.ptr())); +#endif } -template <> -inline py::object GET_ITEM_BORROW(const py::handle& container, const py::ssize_t& item) { - return py::reinterpret_borrow(PyList_GET_ITEM(container.ptr(), item)); +inline py::object DictGetItem(const py::handle& dict, const py::handle& key) { + return DictGetItemAs(dict, key); } -template -inline void SET_ITEM(const py::handle& container, const Item& item, const py::handle& value) { - container[item] = value; +inline void TupleSetItem(const py::handle& tuple, + const py::ssize_t& index, + const py::handle& value) { + PyTuple_SET_ITEM(tuple.ptr(), index, value.inc_ref().ptr()); } -template <> -inline void SET_ITEM(const py::handle& container, - const py::ssize_t& item, - const py::handle& value) { - PyTuple_SET_ITEM(container.ptr(), item, value.inc_ref().ptr()); +inline void ListSetItem(const py::handle& list, const py::ssize_t& index, const py::handle& value) { + PyList_SET_ITEM(list.ptr(), index, value.inc_ref().ptr()); } -template <> -inline void SET_ITEM(const py::handle& container, - const py::ssize_t& item, - const py::handle& value) { - PyList_SET_ITEM(container.ptr(), item, value.inc_ref().ptr()); +inline void DictSetItem(const py::handle& dict, const py::handle& key, const py::handle& value) { + if (PyDict_SetItem(dict.ptr(), key.ptr(), value.ptr()) < 0) [[unlikely]] { + throw py::error_already_set(); + } } +inline std::string PyStr(const py::handle& object) { + return EVALUATE_WITH_LOCK_HELD(static_cast(py::str(object)), object); +} +inline std::string PyStr(const std::string& string) { return string; } inline std::string PyRepr(const py::handle& object) { - return static_cast(py::repr(object)); + return EVALUATE_WITH_LOCK_HELD(static_cast(py::repr(object)), object); } inline std::string PyRepr(const std::string& string) { return static_cast(py::repr(py::str(string))); } -template -inline void AssertExact(const py::handle& object) { - if (!py::isinstance(object)) 
[[unlikely]] { - std::ostringstream oss{}; - oss << "Expected an instance of " << typeid(PyType).name() << ", got " << PyRepr(object) - << "."; - throw py::value_error(oss.str()); - } -} -template <> -inline void AssertExact(const py::handle& object) { +inline void AssertExactList(const py::handle& object) { if (!PyList_CheckExact(object.ptr())) [[unlikely]] { throw py::value_error("Expected an instance of list, got " + PyRepr(object) + "."); } } -template <> -inline void AssertExact(const py::handle& object) { +inline void AssertExactTuple(const py::handle& object) { if (!PyTuple_CheckExact(object.ptr())) [[unlikely]] { throw py::value_error("Expected an instance of tuple, got " + PyRepr(object) + "."); } } -template <> -inline void AssertExact(const py::handle& object) { +inline void AssertExactDict(const py::handle& object) { if (!PyDict_CheckExact(object.ptr())) [[unlikely]] { throw py::value_error("Expected an instance of dict, got " + PyRepr(object) + "."); } @@ -372,19 +359,28 @@ inline bool IsNamedTupleClass(const py::handle& type) { } static auto cache = std::unordered_map{}; - const auto it = cache.find(type); - if (it != cache.end()) [[likely]] { - return it->second; + static read_write_mutex mutex{}; + + { + const scoped_read_lock_guard lock{mutex}; + const auto it = cache.find(type); + if (it != cache.end()) [[likely]] { + return it->second; + } } - const bool result = IsNamedTupleClassImpl(type); - if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { - cache.emplace(type, result); - (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { - cache.erase(type); - weakref.dec_ref(); - })) - .release(); + const bool result = EVALUATE_WITH_LOCK_HELD(IsNamedTupleClassImpl(type), type); + { + const scoped_write_lock_guard lock{mutex}; + if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { + cache.emplace(type, result); + (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { + const scoped_write_lock_guard 
lock{mutex}; + cache.erase(type); + weakref.dec_ref(); + })) + .release(); + } } return result; } @@ -416,7 +412,7 @@ inline py::tuple NamedTupleGetFields(const py::handle& object) { PyRepr(object) + "."); } } - return py::getattr(type, Py_Get_ID(_fields)); + return EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(_fields)), type); } inline bool IsStructSequenceClassImpl(const py::handle& type) { @@ -462,19 +458,28 @@ inline bool IsStructSequenceClass(const py::handle& type) { } static auto cache = std::unordered_map{}; - const auto it = cache.find(type); - if (it != cache.end()) [[likely]] { - return it->second; + static read_write_mutex mutex{}; + + { + const scoped_read_lock_guard lock{mutex}; + const auto it = cache.find(type); + if (it != cache.end()) [[likely]] { + return it->second; + } } - const bool result = IsStructSequenceClassImpl(type); - if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { - cache.emplace(type, result); - (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { - cache.erase(type); - weakref.dec_ref(); - })) - .release(); + const bool result = EVALUATE_WITH_LOCK_HELD(IsStructSequenceClassImpl(type), type); + { + const scoped_write_lock_guard lock{mutex}; + if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { + cache.emplace(type, result); + (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { + const scoped_write_lock_guard lock{mutex}; + cache.erase(type); + weakref.dec_ref(); + })) + .release(); + } } return result; } @@ -509,12 +514,12 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) { return py::tuple{fields}; #else const auto n_sequence_fields = - py::cast(getattr(type, Py_Get_ID(n_sequence_fields))); + py::cast(py::getattr(type, Py_Get_ID(n_sequence_fields))); const auto* const members = reinterpret_cast(type.ptr())->tp_members; py::tuple fields{n_sequence_fields}; for (py::ssize_t i = 0; i < n_sequence_fields; ++i) { // 
NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - SET_ITEM(fields, i, py::str(members[i].name)); + TupleSetItem(fields, i, py::str(members[i].name)); } return fields; #endif @@ -535,24 +540,33 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) { } static auto cache = std::unordered_map{}; - const auto it = cache.find(type); - if (it != cache.end()) [[likely]] { - return py::reinterpret_borrow(it->second); - } - - const py::tuple fields = StructSequenceGetFieldsImpl(type); - if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { - cache.emplace(type, fields); - fields.inc_ref(); - (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { - const auto it = cache.find(type); - if (it != cache.end()) [[likely]] { - it->second.dec_ref(); - cache.erase(it); - } - weakref.dec_ref(); - })) - .release(); + static read_write_mutex mutex{}; + + { + const scoped_read_lock_guard lock{mutex}; + const auto it = cache.find(type); + if (it != cache.end()) [[likely]] { + return py::reinterpret_borrow(it->second); + } + } + + const py::tuple fields = EVALUATE_WITH_LOCK_HELD(StructSequenceGetFieldsImpl(type), type); + { + const scoped_write_lock_guard lock{mutex}; + if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { + cache.emplace(type, fields); + fields.inc_ref(); + (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { + const scoped_write_lock_guard lock{mutex}; + const auto it = cache.find(type); + if (it != cache.end()) [[likely]] { + it->second.dec_ref(); + cache.erase(it); + } + weakref.dec_ref(); + })) + .release(); + } } return fields; } @@ -560,7 +574,8 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) { inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references] try { // Sort directly if possible. 
- if (static_cast(PyList_Sort(list.ptr()))) [[unlikely]] { + if (static_cast(EVALUATE_WITH_LOCK_HELD(PyList_Sort(list.ptr()), list))) + [[unlikely]] { throw py::error_already_set(); } } catch (py::error_already_set& ex1) { @@ -571,12 +586,15 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references] const auto sort_key_fn = py::cpp_function([](const py::object& o) -> py::tuple { const py::handle t = py::type::handle_of(o); const py::str qualname{ - static_cast(py::str(py::getattr(t, Py_Get_ID(__module__)))) + - "." + - static_cast(py::str(py::getattr(t, Py_Get_ID(__qualname__))))}; + EVALUATE_WITH_LOCK_HELD(PyStr(py::getattr(t, Py_Get_ID(__module__))) + "." + + PyStr(py::getattr(t, Py_Get_ID(__qualname__))), + t)}; return py::make_tuple(qualname, o); }); - py::getattr(list, Py_Get_ID(sort))(py::arg("key") = sort_key_fn); + { + const scoped_critical_section cs{list}; + py::getattr(list, Py_Get_ID(sort))(py::arg("key") = sort_key_fn); + } } catch (py::error_already_set& ex2) { if (ex2.matches(PyExc_TypeError)) [[likely]] { // Found incomparable user-defined key types. 
@@ -593,6 +611,7 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references] } inline py::list DictKeys(const py::dict& dict) { + const scoped_critical_section cs{dict}; return py::reinterpret_steal(PyDict_Keys(dict.ptr())); } @@ -603,13 +622,14 @@ inline py::list SortedDictKeys(const py::dict& dict) { } inline bool DictKeysEqual(const py::list& /*unique*/ keys, const py::dict& dict) { - const py::ssize_t list_len = GET_SIZE(keys); - const py::ssize_t dict_len = GET_SIZE(dict); + const scoped_critical_section2 cs{keys, dict}; + const py::ssize_t list_len = ListGetSize(keys); + const py::ssize_t dict_len = DictGetSize(dict); if (list_len != dict_len) [[likely]] { // assumes keys are unique return false; } for (py::ssize_t i = 0; i < list_len; ++i) { - const py::object key = GET_ITEM_BORROW(keys, i); + const py::object key = ListGetItem(keys, i); const int result = PyDict_Contains(dict.ptr(), key.ptr()); if (result == -1) [[unlikely]] { throw py::error_already_set(); @@ -623,8 +643,8 @@ inline bool DictKeysEqual(const py::list& /*unique*/ keys, const py::dict& dict) inline std::pair DictKeysDifference(const py::list& /*unique*/ keys, const py::dict& dict) { - const py::set expected_keys{keys}; - const py::set got_keys{DictKeys(dict)}; + const py::set expected_keys = EVALUATE_WITH_LOCK_HELD(py::set{keys}, keys); + const py::set got_keys = EVALUATE_WITH_LOCK_HELD(py::set{dict}, dict); py::list missing_keys{expected_keys - got_keys}; py::list extra_keys{got_keys - expected_keys}; TotalOrderSort(missing_keys); diff --git a/optree/functools.py b/optree/functools.py index 4d34396a..1262e1b4 100644 --- a/optree/functools.py +++ b/optree/functools.py @@ -148,7 +148,7 @@ def tree_flatten(self) -> tuple[ # type: ignore[override] Callable[..., Any], tuple[str, str], ]: - """Flatten the :class:`partial` instance to children and auxiliary data.""" + """Flatten the :class:`partial` instance to children and metadata.""" return (self.args, self.keywords), 
self.func, ('args', 'keywords') @classmethod @@ -157,7 +157,7 @@ def tree_unflatten( # type: ignore[override] metadata: Callable[..., Any], children: tuple[tuple[T, ...], dict[str, T]], ) -> Self: - """Unflatten the children and auxiliary data into a :class:`partial` instance.""" + """Unflatten the children and metadata into a :class:`partial` instance.""" args, keywords = children return cls(metadata, *args, **keywords) diff --git a/optree/ops.py b/optree/ops.py index 5ae08f97..853cd3f2 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -2438,7 +2438,7 @@ def tree_flatten_one_level( tuple[Any, ...], Callable[[MetaData, list[PyTree[T]]], PyTree[T]], ]: - """Flatten the pytree one level, returning a 4-tuple of children, auxiliary data, path entries, and an unflatten function. + """Flatten the pytree one level, returning a 4-tuple of children, metadata, path entries, and an unflatten function. See also :func:`tree_flatten`, :func:`tree_flatten_with_path`. @@ -2468,9 +2468,9 @@ def tree_flatten_one_level( Returns: A 4-tuple ``(children, metadata, entries, unflatten_func)``. The first element is a list of - one-level children of the pytree node. The second element is the auxiliary data used to + one-level children of the pytree node. The second element is the metadata used to reconstruct the pytree node. The third element is a tuple of path entries to the children. - The fourth element is a function that can be used to unflatten the auxiliary data and + The fourth element is a function that can be used to unflatten the metadata and children back to the pytree node. 
""" # pylint: disable=line-too-long node_type = type(tree) @@ -3209,11 +3209,11 @@ def _prefix_error( ): yield lambda name: ValueError( f'pytree structure error: different types at key path\n' - f' {{name}}{accessor.codify("") if accessor else " tree root"}\n' - f'At that key path, the prefix pytree {{name}} has a subtree of type\n' + f' {accessor.codify(name) if accessor else name + " tree root"}\n' + f'At that key path, the prefix pytree {name} has a subtree of type\n' f' {type(prefix_tree)}\n' f'but at the same key path the full pytree has a subtree of different type\n' - f' {type(full_tree)}.'.format(name=name), + f' {type(full_tree)}.', ) return # don't look for more errors in this subtree @@ -3254,15 +3254,15 @@ def _prefix_error( key_difference += f'\nextra key(s):\n {extra_keys}' yield lambda name: ValueError( f'pytree structure error: different pytree keys at key path\n' - f' {{name}}{accessor.codify("") if accessor else " tree root"}\n' - f'At that key path, the prefix pytree {{name}} has a subtree of type\n' + f' {accessor.codify(name) if accessor else name + " tree root"}\n' + f'At that key path, the prefix pytree {name} has a subtree of type\n' f' {prefix_tree_type}\n' f'with {len(prefix_tree_keys)} key(s)\n' f' {prefix_tree_keys}\n' f'but at the same key path the full pytree has a subtree of type\n' f' {full_tree_type}\n' f'but with {len(full_tree_keys)} key(s)\n' - f' {full_tree_keys}{key_difference}'.format(name=name), + f' {full_tree_keys}{key_difference}', ) return # don't look for more errors in this subtree @@ -3272,12 +3272,12 @@ def _prefix_error( if len(prefix_tree_children) != len(full_tree_children): yield lambda name: ValueError( f'pytree structure error: different numbers of pytree children at key path\n' - f' {{name}}{accessor.codify("") if accessor else " tree root"}\n' - f'At that key path, the prefix pytree {{name}} has a subtree of type\n' + f' {accessor.codify(name) if accessor else name + " tree root"}\n' + f'At that key path, the 
prefix pytree {name} has a subtree of type\n' f' {prefix_tree_type}\n' f'with {len(prefix_tree_children)} children, ' f'but at the same key path the full pytree has a subtree of the same ' - f'type but with {len(full_tree_children)} children.'.format(name=name), + f'type but with {len(full_tree_children)} children.', ) return # don't look for more errors in this subtree @@ -3303,8 +3303,8 @@ def _prefix_error( ) yield lambda name: ValueError( f'pytree structure error: different pytree metadata at key path\n' - f' {{name}}{accessor.codify("") if accessor else " tree root"}\n' - f'At that key path, the prefix pytree {{name}} has a subtree of type\n' + f' {accessor.codify(name) if accessor else name + " tree root"}\n' + f'At that key path, the prefix pytree {name} has a subtree of type\n' f' {prefix_tree_type}\n' f'with metadata\n' f' {prefix_tree_metadata_repr}\n' @@ -3312,7 +3312,7 @@ def _prefix_error( f'type but with metadata\n' f' {full_tree_metadata_repr}\n' f'so the diff in the metadata at these pytree nodes is\n' - f'{metadata_diff}'.format(name=name), + f'{metadata_diff}', ) return # don't look for more errors in this subtree diff --git a/optree/registry.py b/optree/registry.py index 94daf4b2..904dfdd7 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -139,13 +139,13 @@ def register_pytree_node( cls (type): A Python type to treat as an internal pytree node. flatten_func (callable): A function to be used during flattening, taking an instance of ``cls`` and returning a triple or optionally a pair, with (1) an iterable for the children to be - flattened recursively, and (2) some hashable auxiliary data to be stored in the treespec - and to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree - path entries to the corresponding children. 
If the entries are not provided or given by + flattened recursively, and (2) some hashable metadata to be stored in the treespec and + to be passed to the ``unflatten_func``, and (3) (optional) an iterable for the tree path + entries to the corresponding children. If the entries are not provided or given by :data:`None`, then `range(len(children))` will be used. - unflatten_func (callable): A function taking two arguments: the auxiliary data that was - returned by ``flatten_func`` and stored in the treespec, and the unflattened children. - The function should return an instance of ``cls``. + unflatten_func (callable): A function taking two arguments: the metadata that was returned + by ``flatten_func`` and stored in the treespec, and the unflattened children. The + function should return an instance of ``cls``. path_entry_type (type, optional): The type of the path entry to be used in the treespec. (default: :class:`AutoEntry`) namespace (str): A non-empty string that uniquely identifies the namespace of the type registry. @@ -647,6 +647,10 @@ def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None]: PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace') ) + .. warning:: + The dictionary sorting mode is a global setting and is **not thread-safe**. It is + recommended to use this context manager in a single-threaded environment. + Args: mode (bool): The dictionary sorting mode to set. namespace (str): The namespace to set the dictionary sorting mode for. 
diff --git a/optree/typing.py b/optree/typing.py index 3670f577..4440b62b 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -141,11 +141,11 @@ def tree_flatten( # With optionally implemented path entries tuple[Children[T], MetaData, Iterable[Any] | None] ): - """Flatten the custom pytree node into children and auxiliary data.""" + """Flatten the custom pytree node into children and metadata.""" @classmethod def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> CustomTreeNode[T]: - """Unflatten the children and auxiliary data into the custom pytree node.""" + """Unflatten the children and metadata into the custom pytree node.""" _UnionType = type(Union[int, str]) @@ -452,10 +452,7 @@ def is_structseq_class(cls: type) -> bool: # Check the type does not allow subclassing if platform.python_implementation() == 'PyPy': try: - # pylint: disable-next=too-few-public-methods - class _(cls): # noqa: N801 - pass - + types.new_class('subclass', bases=(cls,)) except (AssertionError, TypeError): return True return False diff --git a/pyproject.toml b/pyproject.toml index 2dea85db..faab5485 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ optree = ['*.so', '*.pyd'] # Reference: https://cibuildwheel.readthedocs.io [tool.cibuildwheel] archs = ["native"] +free-threaded-support = true skip = "*musllinux*" build-frontend = "build" test-extras = ["test"] @@ -276,6 +277,7 @@ inline-quotes = "single" ban-relative-imports = "all" [tool.pytest.ini_options] +verbosity_assertions = 3 filterwarnings = [ "error", "always", diff --git a/src/registry.cpp b/src/registry.cpp index 02902362..5f50e88f 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -26,6 +26,7 @@ limitations under the License. 
#include #include "include/exceptions.h" +#include "include/mutex.h" #include "include/utils.h" namespace optree { @@ -153,6 +154,8 @@ template const py::function& unflatten_func, const py::object& path_entry_type, const std::string& registry_namespace) { + const scoped_write_lock_guard lock{sm_mutex}; + RegisterImpl(cls, flatten_func, unflatten_func, @@ -224,6 +227,8 @@ template /*static*/ void PyTreeTypeRegistry::Unregister(const py::object& cls, const std::string& registry_namespace) { + const scoped_write_lock_guard lock{sm_mutex}; + const auto registration1 = UnregisterImpl(cls, registry_namespace); const auto registration2 = UnregisterImpl(cls, registry_namespace); EXPECT_TRUE(registration1->type.is(registration2->type)); @@ -240,6 +245,8 @@ template /*static*/ PyTreeTypeRegistry::RegistrationPtr PyTreeTypeRegistry::Lookup( const py::object& cls, const std::string& registry_namespace) { + const scoped_read_lock_guard lock{sm_mutex}; + PyTreeTypeRegistry* registry = Singleton(); if (!registry_namespace.empty()) [[unlikely]] { const auto named_it = @@ -295,6 +302,8 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( // NOLINTNEXTLINE[readability-function-cognitive-complexity] /*static*/ void PyTreeTypeRegistry::Clear() { + const scoped_write_lock_guard lock{sm_mutex}; + PyTreeTypeRegistry* registry1 = PyTreeTypeRegistry::Singleton(); PyTreeTypeRegistry* registry2 = PyTreeTypeRegistry::Singleton(); diff --git a/src/treespec/constructor.cpp b/src/treespec/constructor.cpp index 4ff9a6c5..72e0459b 100644 --- a/src/treespec/constructor.cpp +++ b/src/treespec/constructor.cpp @@ -24,6 +24,7 @@ limitations under the License. 
#include // std::move #include // std::vector +#include "include/critical_section.h" #include "include/exceptions.h" #include "include/registry.h" #include "include/treespec.h" @@ -87,7 +88,7 @@ template << PyRepr(handle) << "."; throw py::value_error(oss.str()); } - treespecs.emplace_back(py::cast(child)); + treespecs.emplace_back(thread_safe_cast(child)); } std::string common_registry_namespace{}; @@ -143,18 +144,21 @@ template } case PyTreeKind::Tuple: { - node.arity = GET_SIZE(handle); + node.arity = TupleGetSize(handle); for (ssize_t i = 0; i < node.arity; ++i) { - children.emplace_back(GET_ITEM_BORROW(handle, i)); + children.emplace_back(TupleGetItem(handle, i)); } verify_children(children, treespecs, registry_namespace); break; } case PyTreeKind::List: { - node.arity = GET_SIZE(handle); - for (ssize_t i = 0; i < node.arity; ++i) { - children.emplace_back(GET_ITEM_BORROW(handle, i)); + { + const scoped_critical_section cs{handle}; + node.arity = ListGetSize(handle); + for (ssize_t i = 0; i < node.arity; ++i) { + children.emplace_back(ListGetItem(handle, i)); + } } verify_children(children, treespecs, registry_namespace); break; @@ -163,20 +167,25 @@ template case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { - const auto dict = py::reinterpret_borrow(handle); - node.arity = GET_SIZE(dict); - py::list keys = DictKeys(dict); - if (node.kind != PyTreeKind::OrderedDict) [[likely]] { - node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { - TotalOrderSort(keys); + py::list keys; + { + const scoped_critical_section cs{handle}; + const auto dict = py::reinterpret_borrow(handle); + node.arity = DictGetSize(dict); + keys = DictKeys(dict); + if (node.kind != PyTreeKind::OrderedDict) [[likely]] { + node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); + if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { + TotalOrderSort(keys); + } + } + for (const 
py::handle& key : keys) { + children.emplace_back(DictGetItem(dict, key)); } - } - for (const py::handle& key : keys) { - children.emplace_back(dict[key]); } verify_children(children, treespecs, registry_namespace); if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + const scoped_critical_section cs{handle}; node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), std::move(keys)); } else [[likely]] { @@ -188,29 +197,33 @@ template case PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: { const auto tuple = py::reinterpret_borrow(handle); - node.arity = GET_SIZE(tuple); + node.arity = TupleGetSize(tuple); node.node_data = py::type::of(tuple); for (ssize_t i = 0; i < node.arity; ++i) { - children.emplace_back(GET_ITEM_BORROW(tuple, i)); + children.emplace_back(TupleGetItem(tuple, i)); } verify_children(children, treespecs, registry_namespace); break; } case PyTreeKind::Deque: { - const auto list = py::cast(handle); - node.arity = GET_SIZE(list); - node.node_data = py::getattr(handle, Py_Get_ID(maxlen)); + const auto list = thread_safe_cast(handle); + node.arity = ListGetSize(list); + node.node_data = + EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle); for (ssize_t i = 0; i < node.arity; ++i) { - children.emplace_back(GET_ITEM_BORROW(list, i)); + children.emplace_back(ListGetItem(list, i)); } verify_children(children, treespecs, registry_namespace); break; } case PyTreeKind::Custom: { - const py::tuple out = py::cast(node.custom->flatten_func(handle)); - const ssize_t num_out = GET_SIZE(out); + const py::tuple out = EVALUATE_WITH_LOCK_HELD2( + thread_safe_cast(node.custom->flatten_func(handle)), + handle, + node.custom->flatten_func); + const ssize_t num_out = TupleGetSize(out); if (num_out != 2 && num_out != 3) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " << PyRepr(node.custom->type) @@ -218,19 +231,21 @@ template throw std::runtime_error(oss.str()); } 
node.arity = 0; - node.node_data = GET_ITEM_BORROW(out, ssize_t(1)); - auto children_iterator = - py::cast(GET_ITEM_BORROW(out, ssize_t(0))); - for (const py::handle& child : children_iterator) { - ++node.arity; - children.emplace_back(py::reinterpret_borrow(child)); + node.node_data = TupleGetItem(out, 1); + { + auto children_iterable = thread_safe_cast(TupleGetItem(out, 0)); + const scoped_critical_section cs{children_iterable}; + for (const py::handle& child : children_iterable) { + ++node.arity; + children.emplace_back(py::reinterpret_borrow(child)); + } } verify_children(children, treespecs, registry_namespace); if (num_out == 3) [[likely]] { - py::object node_entries = GET_ITEM_BORROW(out, ssize_t(2)); + const py::object node_entries = TupleGetItem(out, 2); if (!node_entries.is_none()) [[likely]] { - node.node_entries = py::cast(std::move(node_entries)); - const ssize_t num_entries = GET_SIZE(node.node_entries); + node.node_entries = thread_safe_cast(node_entries); + const ssize_t num_entries = TupleGetSize(node.node_entries); if (num_entries != node.arity) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index 39ec5f3e..ca455674 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -24,7 +24,9 @@ limitations under the License. 
#include // std::move, std::pair, std::make_pair #include // std::vector +#include "include/critical_section.h" #include "include/exceptions.h" +#include "include/mutex.h" #include "include/registry.h" #include "include/treespec.h" #include "include/utils.h" @@ -49,7 +51,10 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, const ssize_t start_num_nodes = py::ssize_t_cast(m_traversal.size()); const ssize_t start_num_leaves = py::ssize_t_cast(leaves.size()); - if (leaf_predicate && py::cast((*leaf_predicate)(handle))) [[unlikely]] { + if (leaf_predicate && + EVALUATE_WITH_LOCK_HELD2(thread_safe_cast((*leaf_predicate)(handle)), + handle, + *leaf_predicate)) [[unlikely]] { leaves.emplace_back(py::reinterpret_borrow(handle)); } else [[likely]] { node.kind = @@ -80,17 +85,18 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, } case PyTreeKind::Tuple: { - node.arity = GET_SIZE(handle); + node.arity = TupleGetSize(handle); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(handle, i)); + recurse(TupleGetItem(handle, i)); } break; } case PyTreeKind::List: { - node.arity = GET_SIZE(handle); + const scoped_critical_section cs{handle}; + node.arity = ListGetSize(handle); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(handle, i)); + recurse(ListGetItem(handle, i)); } break; } @@ -98,19 +104,24 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { - const auto dict = py::reinterpret_borrow(handle); - node.arity = GET_SIZE(dict); - py::list keys = DictKeys(dict); - if (node.kind != PyTreeKind::OrderedDict) [[likely]] { - node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); - if constexpr (DictShouldBeSorted) { - TotalOrderSort(keys); + py::list keys; + { + const scoped_critical_section cs{handle}; + const auto dict = py::reinterpret_borrow(handle); + node.arity = DictGetSize(dict); + keys = DictKeys(dict); + if 
(node.kind != PyTreeKind::OrderedDict) [[likely]] { + node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); + if constexpr (DictShouldBeSorted) { + TotalOrderSort(keys); + } + } + for (const py::handle& key : keys) { + recurse(DictGetItem(dict, key)); } - } - for (const py::handle& key : keys) { - recurse(dict[key]); } if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + const scoped_critical_section cs{handle}; node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), std::move(keys)); } else [[likely]] { @@ -122,28 +133,32 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle& handle, case PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: { const auto tuple = py::reinterpret_borrow(handle); - node.arity = GET_SIZE(tuple); + node.arity = TupleGetSize(tuple); node.node_data = py::type::of(tuple); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(tuple, i)); + recurse(TupleGetItem(tuple, i)); } break; } case PyTreeKind::Deque: { - const auto list = py::cast(handle); - node.arity = GET_SIZE(list); - node.node_data = py::getattr(handle, Py_Get_ID(maxlen)); + const auto list = thread_safe_cast(handle); + node.arity = ListGetSize(list); + node.node_data = + EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(list, i)); + recurse(ListGetItem(list, i)); } break; } case PyTreeKind::Custom: { found_custom = true; - const py::tuple out = py::cast(node.custom->flatten_func(handle)); - const ssize_t num_out = GET_SIZE(out); + const py::tuple out = EVALUATE_WITH_LOCK_HELD2( + thread_safe_cast(node.custom->flatten_func(handle)), + handle, + node.custom->flatten_func); + const ssize_t num_out = TupleGetSize(out); if (num_out != 2 && num_out != 3) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " << PyRepr(node.custom->type) @@ -151,17 +166,20 @@ bool PyTreeSpec::FlattenIntoImpl(const 
py::handle& handle, throw std::runtime_error(oss.str()); } node.arity = 0; - node.node_data = GET_ITEM_BORROW(out, ssize_t(1)); - auto children = py::cast(GET_ITEM_BORROW(out, ssize_t(0))); - for (const py::handle& child : children) { - ++node.arity; - recurse(child); + node.node_data = TupleGetItem(out, 1); + { + auto children = thread_safe_cast(TupleGetItem(out, 0)); + const scoped_critical_section cs{children}; + for (const py::handle& child : children) { + ++node.arity; + recurse(child); + } } if (num_out == 3) [[likely]] { - py::object node_entries = GET_ITEM_BORROW(out, ssize_t(2)); + const py::object node_entries = TupleGetItem(out, 2); if (!node_entries.is_none()) [[likely]] { - node.node_entries = py::cast(std::move(node_entries)); - const ssize_t num_entries = GET_SIZE(node.node_entries); + node.node_entries = thread_safe_cast(node_entries); + const ssize_t num_entries = TupleGetSize(node.node_entries); if (num_entries != node.arity) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " @@ -190,35 +208,52 @@ bool PyTreeSpec::FlattenInto(const py::handle& handle, const std::optional& leaf_predicate, const bool& none_is_leaf, const std::string& registry_namespace) { + bool found_custom = false; + bool is_dict_insertion_ordered = false; + bool is_dict_insertion_ordered_in_current_namespace = false; + { +#ifdef HAVE_READ_WRITE_LOCK + const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex}; +#endif + is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); + is_dict_insertion_ordered_in_current_namespace = + IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false); + } + if (none_is_leaf) [[unlikely]] { - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { - return FlattenIntoImpl(handle, - leaves, - 0, - leaf_predicate, - registry_namespace); + if (!is_dict_insertion_ordered) [[likely]] { + found_custom = + FlattenIntoImpl(handle, + leaves, + 0, + leaf_predicate, 
+ registry_namespace); } else [[unlikely]] { - return FlattenIntoImpl(handle, - leaves, - 0, - leaf_predicate, - registry_namespace); + found_custom = + FlattenIntoImpl(handle, + leaves, + 0, + leaf_predicate, + registry_namespace); } } else [[likely]] { - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { - return FlattenIntoImpl(handle, - leaves, - 0, - leaf_predicate, - registry_namespace); + if (!is_dict_insertion_ordered) [[likely]] { + found_custom = + FlattenIntoImpl(handle, + leaves, + 0, + leaf_predicate, + registry_namespace); } else [[unlikely]] { - return FlattenIntoImpl(handle, - leaves, - 0, - leaf_predicate, - registry_namespace); + found_custom = + FlattenIntoImpl(handle, + leaves, + 0, + leaf_predicate, + registry_namespace); } } + return found_custom || is_dict_insertion_ordered_in_current_namespace; } /*static*/ std::pair, std::unique_ptr> PyTreeSpec::Flatten( @@ -229,8 +264,7 @@ bool PyTreeSpec::FlattenInto(const py::handle& handle, auto leaves = reserved_vector(4); auto treespec = std::make_unique(); treespec->m_none_is_leaf = none_is_leaf; - if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace) || - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false)) + if (treespec->FlattenInto(tree, leaves, leaf_predicate, none_is_leaf, registry_namespace)) [[unlikely]] { treespec->m_namespace = registry_namespace; } @@ -262,10 +296,13 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, const ssize_t start_num_nodes = py::ssize_t_cast(m_traversal.size()); const ssize_t start_num_leaves = py::ssize_t_cast(leaves.size()); - if (leaf_predicate && py::cast((*leaf_predicate)(handle))) [[unlikely]] { + if (leaf_predicate && + EVALUATE_WITH_LOCK_HELD2(thread_safe_cast((*leaf_predicate)(handle)), + handle, + *leaf_predicate)) [[unlikely]] { py::tuple path{depth}; for (ssize_t d = 0; d < depth; ++d) { - SET_ITEM(path, d, stack[d]); + TupleSetItem(path, d, stack[d]); } 
leaves.emplace_back(py::reinterpret_borrow(handle)); paths.emplace_back(std::move(path)); @@ -296,7 +333,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, case PyTreeKind::Leaf: { py::tuple path{depth}; for (ssize_t d = 0; d < depth; ++d) { - SET_ITEM(path, d, stack[d]); + TupleSetItem(path, d, stack[d]); } leaves.emplace_back(py::reinterpret_borrow(handle)); paths.emplace_back(std::move(path)); @@ -313,17 +350,18 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, } case PyTreeKind::Tuple: { - node.arity = GET_SIZE(handle); + node.arity = TupleGetSize(handle); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(handle, i), py::int_(i)); + recurse(TupleGetItem(handle, i), py::int_(i)); } break; } case PyTreeKind::List: { - node.arity = GET_SIZE(handle); + const scoped_critical_section cs{handle}; + node.arity = ListGetSize(handle); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(handle, i), py::int_(i)); + recurse(ListGetItem(handle, i), py::int_(i)); } break; } @@ -331,7 +369,9 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { + const scoped_critical_section cs{handle}; const auto dict = py::reinterpret_borrow(handle); + node.arity = DictGetSize(dict); py::list keys = DictKeys(dict); if (node.kind != PyTreeKind::OrderedDict) [[likely]] { node.original_keys = py::getattr(keys, Py_Get_ID(copy))(); @@ -340,9 +380,8 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, } } for (const py::handle& key : keys) { - recurse(dict[key], key); + recurse(DictGetItem(dict, key), key); } - node.arity = GET_SIZE(dict); if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { node.node_data = py::make_tuple(py::getattr(handle, Py_Get_ID(default_factory)), std::move(keys)); @@ -355,28 +394,32 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, case 
PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: { const auto tuple = py::reinterpret_borrow(handle); - node.arity = GET_SIZE(tuple); + node.arity = TupleGetSize(tuple); node.node_data = py::type::of(tuple); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(tuple, i), py::int_(i)); + recurse(TupleGetItem(tuple, i), py::int_(i)); } break; } case PyTreeKind::Deque: { - const auto list = py::cast(handle); - node.arity = GET_SIZE(list); - node.node_data = py::getattr(handle, Py_Get_ID(maxlen)); + const auto list = thread_safe_cast(handle); + node.arity = ListGetSize(list); + node.node_data = + EVALUATE_WITH_LOCK_HELD(py::getattr(handle, Py_Get_ID(maxlen)), handle); for (ssize_t i = 0; i < node.arity; ++i) { - recurse(GET_ITEM_HANDLE(list, i), py::int_(i)); + recurse(ListGetItem(list, i), py::int_(i)); } break; } case PyTreeKind::Custom: { found_custom = true; - const py::tuple out = py::cast(node.custom->flatten_func(handle)); - const ssize_t num_out = GET_SIZE(out); + const py::tuple out = EVALUATE_WITH_LOCK_HELD2( + thread_safe_cast(node.custom->flatten_func(handle)), + handle, + node.custom->flatten_func); + const ssize_t num_out = TupleGetSize(out); if (num_out != 2 && num_out != 3) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " << PyRepr(node.custom->type) @@ -384,25 +427,25 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, throw std::runtime_error(oss.str()); } node.arity = 0; - node.node_data = GET_ITEM_BORROW(out, ssize_t(1)); + node.node_data = TupleGetItem(out, 1); py::object node_entries; if (num_out == 3) [[likely]] { - node_entries = GET_ITEM_BORROW(out, ssize_t(2)); + node_entries = TupleGetItem(out, 2); } else [[unlikely]] { node_entries = py::none(); } if (node_entries.is_none()) [[unlikely]] { - auto children = - py::cast(GET_ITEM_BORROW(out, ssize_t(0))); + auto children = thread_safe_cast(TupleGetItem(out, 0)); + const scoped_critical_section cs{children}; for 
(const py::handle& child : children) { recurse(child, py::int_(node.arity++)); } } else [[likely]] { - node.node_entries = py::cast(std::move(node_entries)); - node.arity = GET_SIZE(node.node_entries); + node.node_entries = thread_safe_cast(node_entries); + node.arity = TupleGetSize(node.node_entries); ssize_t num_children = 0; - auto children = - py::cast(GET_ITEM_BORROW(out, ssize_t(0))); + auto children = thread_safe_cast(TupleGetItem(out, 0)); + const scoped_critical_section cs{children}; for (const py::handle& child : children) { if (num_children >= node.arity) [[unlikely]] { throw std::runtime_error( @@ -410,8 +453,7 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle& handle, PyRepr(node.custom->type) + " returned inconsistent number of children and number of entries."); } - recurse(child, - GET_ITEM_BORROW(node.node_entries, num_children++)); + recurse(child, TupleGetItem(node.node_entries, num_children++)); } if (num_children != node.arity) [[unlikely]] { std::ostringstream oss{}; @@ -441,10 +483,22 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, const std::optional& leaf_predicate, const bool& none_is_leaf, const std::string& registry_namespace) { + bool found_custom = false; + bool is_dict_insertion_ordered = false; + bool is_dict_insertion_ordered_in_current_namespace = false; + { +#ifdef HAVE_READ_WRITE_LOCK + const scoped_read_lock_guard lock{sm_is_dict_insertion_ordered_mutex}; +#endif + is_dict_insertion_ordered = IsDictInsertionOrdered(registry_namespace); + is_dict_insertion_ordered_in_current_namespace = + IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false); + } + auto stack = reserved_vector(4); if (none_is_leaf) [[unlikely]] { - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { - return FlattenIntoWithPathImpl( + if (!is_dict_insertion_ordered) [[likely]] { + found_custom = FlattenIntoWithPathImpl( handle, leaves, paths, @@ -453,7 +507,7 @@ bool 
PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, leaf_predicate, registry_namespace); } else [[unlikely]] { - return FlattenIntoWithPathImpl( + found_custom = FlattenIntoWithPathImpl( handle, leaves, paths, @@ -463,8 +517,8 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, registry_namespace); } } else [[likely]] { - if (!IsDictInsertionOrdered(registry_namespace)) [[likely]] { - return FlattenIntoWithPathImpl( + if (!is_dict_insertion_ordered) [[likely]] { + found_custom = FlattenIntoWithPathImpl( handle, leaves, paths, @@ -473,7 +527,7 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, leaf_predicate, registry_namespace); } else [[unlikely]] { - return FlattenIntoWithPathImpl( + found_custom = FlattenIntoWithPathImpl( handle, leaves, paths, @@ -483,6 +537,7 @@ bool PyTreeSpec::FlattenIntoWithPath(const py::handle& handle, registry_namespace); } } + return found_custom || is_dict_insertion_ordered_in_current_namespace; } /*static*/ std::tuple, std::vector, std::unique_ptr> @@ -499,9 +554,7 @@ PyTreeSpec::FlattenWithPath(const py::object& tree, paths, leaf_predicate, none_is_leaf, - registry_namespace) || - IsDictInsertionOrdered(registry_namespace, /*inherit_global_namespace=*/false)) - [[unlikely]] { + registry_namespace)) [[unlikely]] { treespec->m_namespace = registry_namespace; } treespec->m_traversal.shrink_to_fit(); @@ -533,7 +586,7 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { switch (node.kind) { case PyTreeKind::Leaf: { EXPECT_GE(leaf, 0, "Leaf count mismatch."); - SET_ITEM(leaves, leaf, object); + ListSetItem(leaves, leaf, object); --leaf; break; } @@ -553,33 +606,32 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { } case PyTreeKind::Tuple: { - AssertExact(object); + AssertExactTuple(object); const auto tuple = py::reinterpret_borrow(object); - if (GET_SIZE(tuple) != node.arity) [[unlikely]] { + if (TupleGetSize(tuple) != node.arity) [[unlikely]] { 
std::ostringstream oss{}; oss << "tuple arity mismatch; expected: " << node.arity - << ", got: " << GET_SIZE(tuple) << "; tuple: " << PyRepr(object) - << "."; + << ", got: " << TupleGetSize(tuple) << "; tuple: " << PyRepr(object) << "."; throw py::value_error(oss.str()); } for (ssize_t i = 0; i < node.arity; ++i) { - agenda.emplace_back(GET_ITEM_BORROW(tuple, i)); + agenda.emplace_back(TupleGetItem(tuple, i)); } break; } case PyTreeKind::List: { - AssertExact(object); + AssertExactList(object); + const scoped_critical_section cs{object}; const auto list = py::reinterpret_borrow(object); - if (GET_SIZE(list) != node.arity) [[unlikely]] { + if (ListGetSize(list) != node.arity) [[unlikely]] { std::ostringstream oss{}; oss << "list arity mismatch; expected: " << node.arity - << ", got: " << GET_SIZE(list) << "; list: " << PyRepr(object) - << "."; + << ", got: " << ListGetSize(list) << "; list: " << PyRepr(object) << "."; throw py::value_error(oss.str()); } for (ssize_t i = 0; i < node.arity; ++i) { - agenda.emplace_back(GET_ITEM_BORROW(list, i)); + agenda.emplace_back(ListGetItem(list, i)); } break; } @@ -588,19 +640,20 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { AssertExactStandardDict(object); + const scoped_critical_section2 cs{object, node.node_data}; const auto dict = py::reinterpret_borrow(object); const py::list expected_keys = (node.kind != PyTreeKind::DefaultDict - ? node.node_data - : GET_ITEM_BORROW(node.node_data, ssize_t(1))); + ? 
py::reinterpret_borrow(node.node_data) + : TupleGetItemAs(node.node_data, 1)); if (!DictKeysEqual(expected_keys, dict)) [[unlikely]] { const py::list keys = SortedDictKeys(dict); const auto [missing_keys, extra_keys] = DictKeysDifference(expected_keys, dict); std::ostringstream key_difference_sstream{}; - if (GET_SIZE(missing_keys) != 0) [[likely]] { + if (ListGetSize(missing_keys) != 0) [[likely]] { key_difference_sstream << ", missing key(s): " << PyRepr(missing_keys); } - if (GET_SIZE(extra_keys) != 0) [[likely]] { + if (ListGetSize(extra_keys) != 0) [[likely]] { key_difference_sstream << ", extra key(s): " << PyRepr(extra_keys); } std::ostringstream oss{}; @@ -617,7 +670,7 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { throw py::value_error(oss.str()); } for (const py::handle& key : expected_keys) { - agenda.emplace_back(dict[key]); + agenda.emplace_back(DictGetItem(dict, key)); } break; } @@ -625,11 +678,10 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { case PyTreeKind::NamedTuple: { AssertExactNamedTuple(object); const auto tuple = py::reinterpret_borrow(object); - if (GET_SIZE(tuple) != node.arity) [[unlikely]] { + if (TupleGetSize(tuple) != node.arity) [[unlikely]] { std::ostringstream oss{}; oss << "namedtuple arity mismatch; expected: " << node.arity - << ", got: " << GET_SIZE(tuple) << "; tuple: " << PyRepr(object) - << "."; + << ", got: " << TupleGetSize(tuple) << "; tuple: " << PyRepr(object) << "."; throw py::value_error(oss.str()); } if (py::type::handle_of(object).not_equal(node.node_data)) [[unlikely]] { @@ -640,23 +692,22 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { throw py::value_error(oss.str()); } for (ssize_t i = 0; i < node.arity; ++i) { - agenda.emplace_back(GET_ITEM_BORROW(tuple, i)); + agenda.emplace_back(TupleGetItem(tuple, i)); } break; } case PyTreeKind::Deque: { AssertExactDeque(object); - const auto list = py::cast(object); - if (GET_SIZE(list) != 
node.arity) [[unlikely]] { + const auto list = thread_safe_cast(object); + if (ListGetSize(list) != node.arity) [[unlikely]] { std::ostringstream oss{}; oss << "deque arity mismatch; expected: " << node.arity - << ", got: " << GET_SIZE(list) << "; deque: " << PyRepr(object) - << "."; + << ", got: " << ListGetSize(list) << "; deque: " << PyRepr(object) << "."; throw py::value_error(oss.str()); } for (ssize_t i = 0; i < node.arity; ++i) { - agenda.emplace_back(GET_ITEM_BORROW(list, i)); + agenda.emplace_back(ListGetItem(list, i)); } break; } @@ -664,11 +715,10 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { case PyTreeKind::StructSequence: { AssertExactStructSequence(object); const auto tuple = py::reinterpret_borrow(object); - if (GET_SIZE(tuple) != node.arity) [[unlikely]] { + if (TupleGetSize(tuple) != node.arity) [[unlikely]] { std::ostringstream oss{}; oss << "PyStructSequence arity mismatch; expected: " << node.arity - << ", got: " << GET_SIZE(tuple) << "; tuple: " << PyRepr(object) - << "."; + << ", got: " << TupleGetSize(tuple) << "; tuple: " << PyRepr(object) << "."; throw py::value_error(oss.str()); } if (py::type::handle_of(object).not_equal(node.node_data)) [[unlikely]] { @@ -680,7 +730,7 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { throw py::value_error(oss.str()); } for (ssize_t i = 0; i < node.arity; ++i) { - agenda.emplace_back(GET_ITEM_BORROW(tuple, i)); + agenda.emplace_back(TupleGetItem(tuple, i)); } break; } @@ -701,26 +751,36 @@ py::list PyTreeSpec::FlattenUpTo(const py::object& full_tree) const { << "; value: " << PyRepr(object) << "."; throw py::value_error(oss.str()); } - const py::tuple out = py::cast(node.custom->flatten_func(object)); - const ssize_t num_out = GET_SIZE(out); + const py::tuple out = EVALUATE_WITH_LOCK_HELD2( + thread_safe_cast(node.custom->flatten_func(object)), + object, + node.custom->flatten_func); + const ssize_t num_out = TupleGetSize(out); if (num_out != 2 && num_out 
!= 3) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " << PyRepr(node.custom->type) << " should return a 2- or 3-tuple, got " << num_out << "."; throw std::runtime_error(oss.str()); } - const py::object node_data = GET_ITEM_BORROW(out, ssize_t(1)); - if (node.node_data.not_equal(node_data)) [[unlikely]] { - std::ostringstream oss{}; - oss << "Mismatch custom node data; expected: " << PyRepr(node.node_data) - << ", got: " << PyRepr(node_data) << "; value: " << PyRepr(object) << "."; - throw py::value_error(oss.str()); + { + const py::object node_data = TupleGetItem(out, 1); + const scoped_critical_section2 cs{node.node_data, node_data}; + if (node.node_data.not_equal(node_data)) [[unlikely]] { + std::ostringstream oss{}; + oss << "Mismatch custom node data; expected: " << PyRepr(node.node_data) + << ", got: " << PyRepr(node_data) << "; value: " << PyRepr(object) + << "."; + throw py::value_error(oss.str()); + } } ssize_t arity = 0; - auto children = py::cast(GET_ITEM_BORROW(out, ssize_t(0))); - for (const py::handle& child : children) { - ++arity; - agenda.emplace_back(py::reinterpret_borrow(child)); + { + auto children = thread_safe_cast(TupleGetItem(out, 0)); + const scoped_critical_section cs{children}; + for (const py::handle& child : children) { + ++arity; + agenda.emplace_back(py::reinterpret_borrow(child)); + } } if (arity != node.arity) [[unlikely]] { std::ostringstream oss{}; @@ -748,7 +808,10 @@ template bool IsLeafImpl(const py::handle& handle, const std::optional& leaf_predicate, const std::string& registry_namespace) { - if (leaf_predicate && py::cast((*leaf_predicate)(handle))) [[unlikely]] { + if (leaf_predicate && + EVALUATE_WITH_LOCK_HELD2(thread_safe_cast((*leaf_predicate)(handle)), + handle, + *leaf_predicate)) [[unlikely]] { return true; } PyTreeTypeRegistry::RegistrationPtr custom{nullptr}; @@ -773,7 +836,10 @@ bool AllLeavesImpl(const py::iterable& iterable, const std::string& registry_namespace) { 
PyTreeTypeRegistry::RegistrationPtr custom{nullptr}; for (const py::handle& handle : iterable) { - if (leaf_predicate && py::cast((*leaf_predicate)(handle))) [[unlikely]] { + if (leaf_predicate && + EVALUATE_WITH_LOCK_HELD2(thread_safe_cast((*leaf_predicate)(handle)), + handle, + *leaf_predicate)) [[unlikely]] { continue; } if (PyTreeTypeRegistry::GetKind(handle, custom, registry_namespace) != @@ -788,6 +854,7 @@ bool AllLeaves(const py::iterable& iterable, const std::optional& leaf_predicate, const bool& none_is_leaf, const std::string& registry_namespace) { + const scoped_critical_section cs{iterable}; if (none_is_leaf) [[unlikely]] { return AllLeavesImpl(iterable, leaf_predicate, registry_namespace); } else [[likely]] { diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 6ffa1df8..f02e43e8 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -21,7 +21,9 @@ limitations under the License. #include // std::string #include // std::move +#include "include/critical_section.h" #include "include/exceptions.h" +#include "include/mutex.h" #include "include/registry.h" #include "include/treespec.h" #include "include/utils.h" @@ -41,7 +43,10 @@ py::object PyTreeIter::NextImpl() { throw py::error_already_set(); } - if (m_leaf_predicate && py::cast((*m_leaf_predicate)(object))) [[unlikely]] { + if (m_leaf_predicate && + EVALUATE_WITH_LOCK_HELD2(thread_safe_cast((*m_leaf_predicate)(object)), + object, + *m_leaf_predicate)) [[unlikely]] { return object; } @@ -65,17 +70,18 @@ py::object PyTreeIter::NextImpl() { } case PyTreeKind::Tuple: { - const ssize_t arity = GET_SIZE(object); + const ssize_t arity = TupleGetSize(object); for (ssize_t i = arity - 1; i >= 0; --i) { - m_agenda.emplace_back(GET_ITEM_BORROW(object, i), depth); + m_agenda.emplace_back(TupleGetItem(object, i), depth); } break; } case PyTreeKind::List: { - const ssize_t arity = GET_SIZE(object); + const scoped_critical_section cs{object}; + const ssize_t arity = 
ListGetSize(object); for (ssize_t i = arity - 1; i >= 0; --i) { - m_agenda.emplace_back(GET_ITEM_BORROW(object, i), depth); + m_agenda.emplace_back(ListGetItem(object, i), depth); } break; } @@ -83,6 +89,7 @@ py::object PyTreeIter::NextImpl() { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { + const scoped_critical_section cs{object}; const auto dict = py::reinterpret_borrow(object); py::list keys = DictKeys(dict); if (kind != PyTreeKind::OrderedDict && !m_is_dict_insertion_ordered) [[likely]] { @@ -92,7 +99,7 @@ py::object PyTreeIter::NextImpl() { throw py::error_already_set(); } for (const py::handle& key : keys) { - m_agenda.emplace_back(dict[key], depth); + m_agenda.emplace_back(DictGetItem(dict, key), depth); } break; } @@ -100,38 +107,41 @@ py::object PyTreeIter::NextImpl() { case PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: { const auto tuple = py::reinterpret_borrow(object); - const ssize_t arity = GET_SIZE(tuple); + const ssize_t arity = TupleGetSize(tuple); for (ssize_t i = arity - 1; i >= 0; --i) { - m_agenda.emplace_back(GET_ITEM_BORROW(tuple, i), depth); + m_agenda.emplace_back(TupleGetItem(tuple, i), depth); } break; } case PyTreeKind::Deque: { - const auto list = py::cast(object); - const ssize_t arity = GET_SIZE(list); + const auto list = thread_safe_cast(object); + const ssize_t arity = ListGetSize(list); for (ssize_t i = arity - 1; i >= 0; --i) { - m_agenda.emplace_back(GET_ITEM_BORROW(list, i), depth); + m_agenda.emplace_back(ListGetItem(list, i), depth); } break; } case PyTreeKind::Custom: { - const py::tuple out = py::cast(custom->flatten_func(object)); - const ssize_t num_out = GET_SIZE(out); + const py::tuple out = EVALUATE_WITH_LOCK_HELD2( + thread_safe_cast(custom->flatten_func(object)), + object, + custom->flatten_func); + const ssize_t num_out = TupleGetSize(out); if (num_out != 2 && num_out != 3) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " << 
PyRepr(custom->type) << " should return a 2- or 3-tuple, got " << num_out << "."; throw std::runtime_error(oss.str()); } - auto children = py::cast(GET_ITEM_BORROW(out, ssize_t(0))); - const ssize_t arity = GET_SIZE(children); + auto children = thread_safe_cast(TupleGetItem(out, 0)); + const ssize_t arity = TupleGetSize(children); if (num_out == 3) [[likely]] { - py::object node_entries = GET_ITEM_BORROW(out, ssize_t(2)); + const py::object node_entries = TupleGetItem(out, 2); if (!node_entries.is_none()) [[likely]] { const ssize_t num_entries = - GET_SIZE(py::cast(std::move(node_entries))); + TupleGetSize(thread_safe_cast(node_entries)); if (num_entries != arity) [[unlikely]] { std::ostringstream oss{}; oss << "PyTree custom flatten function for type " @@ -143,7 +153,7 @@ py::object PyTreeIter::NextImpl() { } } for (ssize_t i = arity - 1; i >= 0; --i) { - m_agenda.emplace_back(GET_ITEM_BORROW(children, i), depth); + m_agenda.emplace_back(TupleGetItem(children, i), depth); } break; } @@ -157,6 +167,10 @@ py::object PyTreeIter::NextImpl() { } py::object PyTreeIter::Next() { +#ifdef Py_GIL_DISABLED + const scoped_lock_guard lock{m_mutex}; +#endif + if (m_none_is_leaf) [[unlikely]] { return NextImpl(); } else [[likely]] { @@ -167,6 +181,7 @@ py::object PyTreeIter::Next() { py::object PyTreeSpec::Walk(const py::function& f_node, const std::optional& f_leaf, const py::iterable& leaves) const { + const scoped_critical_section cs{leaves}; auto agenda = reserved_vector(4); auto it = leaves.begin(); for (const Node& node : m_traversal) { @@ -177,11 +192,8 @@ py::object PyTreeSpec::Walk(const py::function& f_node, } const auto leaf = py::reinterpret_borrow(*it); - if (f_leaf) [[likely]] { - agenda.emplace_back((*f_leaf)(leaf)); - } else [[unlikely]] { - agenda.emplace_back(leaf); - } + agenda.emplace_back( + f_leaf ? 
EVALUATE_WITH_LOCK_HELD2((*f_leaf)(leaf), leaf, *f_leaf) : leaf); ++it; break; } @@ -201,10 +213,13 @@ py::object PyTreeSpec::Walk(const py::function& f_node, "Too few elements for custom type."); const py::tuple tuple{node.arity}; for (ssize_t i = node.arity - 1; i >= 0; --i) { - SET_ITEM(tuple, i, agenda.back()); + TupleSetItem(tuple, i, agenda.back()); agenda.pop_back(); } - agenda.emplace_back(f_node(tuple, (node.node_data ? node.node_data : py::none()))); + agenda.emplace_back(EVALUATE_WITH_LOCK_HELD2( + f_node(tuple, (node.node_data ? node.node_data : py::none())), + node.node_data, + f_node)); break; } diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index 93a30327..d32e6743 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -28,10 +28,13 @@ limitations under the License. #include // std::thread::id, std::this_thread #include // std::tuple #include // std::unordered_map +#include // std::unordered_set #include // std::move, std::pair #include // std::vector +#include "include/critical_section.h" #include "include/exceptions.h" +#include "include/mutex.h" #include "include/registry.h" #include "include/utils.h" @@ -86,12 +89,14 @@ namespace optree { py::tuple tuple{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - SET_ITEM(tuple, i, children[i]); + TupleSetItem(tuple, i, children[i]); } if (node.kind == PyTreeKind::NamedTuple) [[unlikely]] { + const scoped_critical_section cs{node.node_data}; return node.node_data(*tuple); } if (node.kind == PyTreeKind::StructSequence) [[unlikely]] { + const scoped_critical_section cs{node.node_data}; return node.node_data(std::move(tuple)); } return std::move(tuple); @@ -102,7 +107,7 @@ namespace optree { py::list list{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - SET_ITEM(list, i, children[i]); + ListSetItem(list, i, children[i]); } 
if (node.kind == PyTreeKind::Deque) [[unlikely]] { return PyDequeTypeObject(std::move(list), py::arg("maxlen") = node.node_data); @@ -110,58 +115,47 @@ namespace optree { return std::move(list); } - case PyTreeKind::Dict: { + case PyTreeKind::Dict: + case PyTreeKind::OrderedDict: + case PyTreeKind::DefaultDict: { py::dict dict{}; - const auto keys = py::reinterpret_borrow(node.node_data); + const scoped_critical_section2 cs{node.node_data, node.original_keys}; + if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch."); + } + const auto keys = (node.kind != PyTreeKind::DefaultDict + ? py::reinterpret_borrow(node.node_data) + : TupleGetItemAs(node.node_data, 1)); if (node.original_keys) [[unlikely]] { for (ssize_t i = 0; i < node.arity; ++i) { - dict[GET_ITEM_HANDLE(node.original_keys, i)] = py::none(); + DictSetItem(dict, ListGetItem(node.original_keys, i), py::none()); } } for (ssize_t i = 0; i < node.arity; ++i) { // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - dict[GET_ITEM_HANDLE(keys, i)] = children[i]; + DictSetItem(dict, ListGetItem(keys, i), children[i]); } - return std::move(dict); - } - - case PyTreeKind::OrderedDict: { - const py::list items{node.arity}; - const auto keys = py::reinterpret_borrow(node.node_data); - for (ssize_t i = 0; i < node.arity; ++i) { - SET_ITEM( - items, - i, - // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - py::make_tuple(GET_ITEM_HANDLE(keys, i), children[i])); + if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + return PyOrderedDictTypeObject(std::move(dict)); } - return PyOrderedDictTypeObject(items); - } - - case PyTreeKind::DefaultDict: { - const py::dict dict{}; - const py::object default_factory = - GET_ITEM_BORROW(node.node_data, ssize_t(0)); - const py::list keys = GET_ITEM_BORROW(node.node_data, ssize_t(1)); - if (node.original_keys) [[unlikely]] { - for (ssize_t i = 0; i < node.arity; ++i) { - 
dict[GET_ITEM_HANDLE(node.original_keys, i)] = py::none(); - } - } - for (ssize_t i = 0; i < node.arity; ++i) { - // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - dict[GET_ITEM_HANDLE(keys, i)] = children[i]; + if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { + const py::object default_factory = TupleGetItem(node.node_data, 0); + return EVALUATE_WITH_LOCK_HELD( + PyDefaultDictTypeObject(default_factory, std::move(dict)), + default_factory); } - return PyDefaultDictTypeObject(default_factory, dict); + return std::move(dict); } case PyTreeKind::Custom: { const py::tuple tuple{node.arity}; for (ssize_t i = 0; i < node.arity; ++i) { // NOLINTNEXTLINE[cppcoreguidelines-pro-bounds-pointer-arithmetic] - SET_ITEM(tuple, i, children[i]); + TupleSetItem(tuple, i, children[i]); } - return node.custom->unflatten_func(node.node_data, tuple); + return EVALUATE_WITH_LOCK_HELD2(node.custom->unflatten_func(node.node_data, tuple), + node.node_data, + node.custom->unflatten_func); } default: @@ -309,26 +303,25 @@ namespace optree { throw py::value_error(oss.str()); } - const auto expected_keys = py::reinterpret_borrow( - root.kind != PyTreeKind::DefaultDict - ? root.node_data - : GET_ITEM_BORROW(root.node_data, ssize_t(1))); - auto other_keys = py::reinterpret_borrow( - other_root.kind != PyTreeKind::DefaultDict - ? other_root.node_data - : GET_ITEM_BORROW(other_root.node_data, ssize_t(1))); + const scoped_critical_section2 cs{root.node_data, other_root.node_data}; + const auto expected_keys = (root.kind != PyTreeKind::DefaultDict + ? py::reinterpret_borrow(root.node_data) + : TupleGetItemAs(root.node_data, 1)); + auto other_keys = (other_root.kind != PyTreeKind::DefaultDict + ? 
py::reinterpret_borrow(other_root.node_data) + : TupleGetItemAs(other_root.node_data, 1)); const py::dict dict{}; for (ssize_t i = 0; i < other_root.arity; ++i) { - dict[GET_ITEM_HANDLE(other_keys, i)] = py::int_(i); + DictSetItem(dict, ListGetItem(other_keys, i), py::int_(i)); } if (!DictKeysEqual(expected_keys, dict)) [[unlikely]] { TotalOrderSort(other_keys); const auto [missing_keys, extra_keys] = DictKeysDifference(expected_keys, dict); std::ostringstream key_difference_sstream{}; - if (GET_SIZE(missing_keys) != 0) [[likely]] { + if (ListGetSize(missing_keys) != 0) [[likely]] { key_difference_sstream << ", missing key(s): " << PyRepr(missing_keys); } - if (GET_SIZE(extra_keys) != 0) [[likely]] { + if (ListGetSize(extra_keys) != 0) [[likely]] { key_difference_sstream << ", extra key(s): " << PyRepr(extra_keys); } std::ostringstream oss{}; @@ -347,8 +340,8 @@ namespace optree { std::reverse(other_curs.begin(), other_curs.end()); const ssize_t last_other_cur = other_cur; for (ssize_t i = root.arity - 1; i >= 0; --i) { - const py::object key = GET_ITEM_BORROW(expected_keys, i); - other_cur = other_curs[py::cast(dict[key])]; + const py::object key = ListGetItem(expected_keys, i); + other_cur = other_curs[py::cast(DictGetItem(dict, key))]; const auto [num_nodes, other_num_nodes, new_num_nodes, new_num_leaves] = // NOLINTNEXTLINE[misc-no-recursion] BroadcastToCommonSuffixImpl(nodes, traversal, cur, other_traversal, other_cur); @@ -406,11 +399,14 @@ namespace optree { << ", got: " << other_root.arity << "."; throw py::value_error(oss.str()); } - if (root.node_data.not_equal(other_root.node_data)) [[unlikely]] { - std::ostringstream oss{}; - oss << "Mismatch custom node data; expected: " << PyRepr(root.node_data) - << ", got: " << PyRepr(other_root.node_data) << "."; - throw py::value_error(oss.str()); + { + const scoped_critical_section2 cs{root.node_data, other_root.node_data}; + if (root.node_data.not_equal(other_root.node_data)) [[unlikely]] { + std::ostringstream 
oss{}; + oss << "Mismatch custom node data; expected: " << PyRepr(root.node_data) + << ", got: " << PyRepr(other_root.node_data) << "."; + throw py::value_error(oss.str()); + } } break; } @@ -555,14 +551,14 @@ ssize_t PyTreeSpec::PathsImpl(Span& paths, // NOLINT[misc-no-recursion] if (root.node_entries) [[unlikely]] { for (ssize_t i = root.arity - 1; i >= 0; --i) { - cur -= recurse(cur, GET_ITEM_HANDLE(root.node_entries, i)); + cur -= recurse(cur, TupleGetItem(root.node_entries, i)); } } else [[likely]] { switch (root.kind) { case PyTreeKind::Leaf: { py::tuple path{depth}; for (ssize_t d = 0; d < depth; ++d) { - SET_ITEM(path, d, stack[d]); + TupleSetItem(path, d, stack[d]); } paths.emplace_back(std::move(path)); break; @@ -586,12 +582,12 @@ ssize_t PyTreeSpec::PathsImpl(Span& paths, // NOLINT[misc-no-recursion] case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { - const auto keys = py::reinterpret_borrow( - root.kind != PyTreeKind::DefaultDict - ? root.node_data - : GET_ITEM_BORROW(root.node_data, ssize_t(1))); + const scoped_critical_section cs{root.node_data}; + const auto keys = (root.kind != PyTreeKind::DefaultDict + ? 
py::reinterpret_borrow(root.node_data) + : TupleGetItemAs(root.node_data, 1)); for (ssize_t i = root.arity - 1; i >= 0; --i) { - cur -= recurse(cur, GET_ITEM_HANDLE(keys, i)); + cur -= recurse(cur, ListGetItem(keys, i)); } break; } @@ -647,7 +643,9 @@ ssize_t PyTreeSpec::AccessorsImpl(Span& accessors, // NOLINT[misc-no-recursion] const ssize_t& cur, const py::handle& entry, const py::handle& path_entry_type) -> ssize_t { - stack.emplace_back(path_entry_type(entry, node_type, node_kind)); + stack.emplace_back(EVALUATE_WITH_LOCK_HELD2(path_entry_type(entry, node_type, node_kind), + path_entry_type, + node_type)); const ssize_t num_nodes = AccessorsImpl(accessors, stack, cur, depth + 1); stack.pop_back(); return num_nodes; @@ -660,16 +658,17 @@ ssize_t PyTreeSpec::AccessorsImpl(Span& accessors, // NOLINT[misc-no-recursion] "Node entries are only supported for custom nodes."); EXPECT_NE(root.custom, nullptr, "The custom registration is null."); for (ssize_t i = root.arity - 1; i >= 0; --i) { - cur -= recurse(cur, GET_ITEM_HANDLE(root.node_entries, i), path_entry_type); + cur -= recurse(cur, TupleGetItem(root.node_entries, i), path_entry_type); } } else [[likely]] { switch (root.kind) { case PyTreeKind::Leaf: { const py::tuple typed_path{depth}; for (ssize_t d = 0; d < depth; ++d) { - SET_ITEM(typed_path, d, stack[d]); + TupleSetItem(typed_path, d, stack[d]); } - accessors.emplace_back(PyTreeAccessor(typed_path)); + accessors.emplace_back( + EVALUATE_WITH_LOCK_HELD(PyTreeAccessor(typed_path), PyTreeAccessor)); break; } @@ -691,12 +690,12 @@ ssize_t PyTreeSpec::AccessorsImpl(Span& accessors, // NOLINT[misc-no-recursion] case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { - const auto keys = py::reinterpret_borrow( - root.kind != PyTreeKind::DefaultDict - ? root.node_data - : GET_ITEM_BORROW(root.node_data, ssize_t(1))); + const scoped_critical_section cs{root.node_data}; + const auto keys = (root.kind != PyTreeKind::DefaultDict + ? 
py::reinterpret_borrow(root.node_data) + : TupleGetItemAs(root.node_data, 1)); for (ssize_t i = root.arity - 1; i >= 0; --i) { - cur -= recurse(cur, GET_ITEM_HANDLE(keys, i), path_entry_type); + cur -= recurse(cur, ListGetItem(keys, i), path_entry_type); } break; } @@ -747,18 +746,19 @@ py::list PyTreeSpec::Entries() const { case PyTreeKind::Custom: { py::list entries{root.arity}; for (ssize_t i = 0; i < root.arity; ++i) { - SET_ITEM(entries, i, py::int_(i)); + ListSetItem(entries, i, py::int_(i)); } return entries; } case PyTreeKind::Dict: case PyTreeKind::OrderedDict: { + const scoped_critical_section cs{root.node_data}; return py::getattr(root.node_data, Py_Get_ID(copy))(); } case PyTreeKind::DefaultDict: { - return py::getattr(GET_ITEM_BORROW(root.node_data, ssize_t(1)), - Py_Get_ID(copy))(); + const scoped_critical_section cs{root.node_data}; + return py::getattr(TupleGetItem(root.node_data, 1), Py_Get_ID(copy))(); } default: @@ -777,7 +777,7 @@ py::object PyTreeSpec::Entry(ssize_t index) const { } if (root.node_entries) [[unlikely]] { - return GET_ITEM_BORROW(root.node_entries, index); + return TupleGetItem(root.node_entries, index); } switch (root.kind) { case PyTreeKind::Tuple: @@ -791,11 +791,12 @@ py::object PyTreeSpec::Entry(ssize_t index) const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: { - return GET_ITEM_BORROW(root.node_data, index); + const scoped_critical_section cs{root.node_data}; + return ListGetItem(root.node_data, index); } case PyTreeKind::DefaultDict: { - return GET_ITEM_BORROW(GET_ITEM_BORROW(root.node_data, ssize_t(1)), - index); + const scoped_critical_section cs{root.node_data}; + return ListGetItem(TupleGetItemAs(root.node_data, 1), index); } case PyTreeKind::None: @@ -931,7 +932,7 @@ bool PyTreeSpec::IsPrefix(const PyTreeSpec& other, const bool& strict) const { } bool all_leaves_match = true; - std::vector other_traversal{other.m_traversal.begin(), other.m_traversal.end()}; + std::vector other_traversal{other.m_traversal}; 
// NOLINTNEXTLINE[readability-qualified-auto] auto b = other_traversal.rbegin(); // NOLINTNEXTLINE[readability-qualified-auto] @@ -969,17 +970,16 @@ bool PyTreeSpec::IsPrefix(const PyTreeSpec& other, const bool& strict) const { b->kind != PyTreeKind::DefaultDict) [[likely]] { return false; } - const auto expected_keys = py::reinterpret_borrow( - a->kind != PyTreeKind::DefaultDict - ? a->node_data - : GET_ITEM_BORROW(a->node_data, ssize_t(1))); - const auto other_keys = py::reinterpret_borrow( - b->kind != PyTreeKind::DefaultDict - ? b->node_data - : GET_ITEM_BORROW(b->node_data, ssize_t(1))); + const scoped_critical_section2 cs(a->node_data, b->node_data); + const auto expected_keys = (a->kind != PyTreeKind::DefaultDict + ? py::reinterpret_borrow(a->node_data) + : TupleGetItemAs(a->node_data, 1)); + const auto other_keys = (b->kind != PyTreeKind::DefaultDict + ? py::reinterpret_borrow(b->node_data) + : TupleGetItemAs(b->node_data, 1)); const py::dict dict{}; for (ssize_t i = 0; i < b->arity; ++i) { - dict[GET_ITEM_HANDLE(other_keys, i)] = py::int_(i); + DictSetItem(dict, ListGetItem(other_keys, i), py::int_(i)); } if (!DictKeysEqual(expected_keys, dict)) [[likely]] { return false; @@ -1002,8 +1002,9 @@ bool PyTreeSpec::IsPrefix(const PyTreeSpec& other, const bool& strict) const { "PyTreeSpec traversal out of range."); auto reordered_index_to_index = std::unordered_map{}; for (ssize_t i = a->arity - 1; i >= 0; --i) { - const py::object key = GET_ITEM_BORROW(expected_keys, i); - reordered_index_to_index.emplace(i, py::cast(dict[key])); + const py::object key = ListGetItem(expected_keys, i); + reordered_index_to_index.emplace(i, + py::cast(DictGetItem(dict, key))); } auto reordered_other_num_nodes = reserved_vector(b->arity); reordered_other_num_nodes.resize(b->arity); @@ -1033,6 +1034,7 @@ bool PyTreeSpec::IsPrefix(const PyTreeSpec& other, const bool& strict) const { case PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: case PyTreeKind::Custom: { + const 
scoped_critical_section2 cs(a->node_data, b->node_data); if (a->kind != b->kind || (a->node_data && a->node_data.not_equal(b->node_data))) [[likely]] { return false; @@ -1064,14 +1066,15 @@ bool PyTreeSpec::operator==(const PyTreeSpec& other) const { } // NOLINTNEXTLINE[readability-qualified-auto] - auto b = other.m_traversal.begin(); + auto b = other.m_traversal.cbegin(); // NOLINTNEXTLINE[readability-qualified-auto] - for (auto a = m_traversal.begin(); a != m_traversal.end(); ++a, ++b) { + for (auto a = m_traversal.cbegin(); a != m_traversal.cend(); ++a, ++b) { if (a->kind != b->kind || a->arity != b->arity || static_cast(a->node_data) != static_cast(b->node_data) || a->custom != b->custom) [[likely]] { return false; } + const scoped_critical_section2 cs(a->node_data, b->node_data); if (a->node_data && a->node_data.not_equal(b->node_data)) [[likely]] { return false; } @@ -1090,7 +1093,7 @@ std::string PyTreeSpec::ToStringImpl() const { std::ostringstream children_sstream{}; { bool first = true; - for (auto it = agenda.end() - node.arity; it != agenda.end(); ++it) { + for (auto it = agenda.cend() - node.arity; it != agenda.cend(); ++it) { if (!first) [[likely]] { children_sstream << ", "; } @@ -1103,8 +1106,8 @@ std::string PyTreeSpec::ToStringImpl() const { std::ostringstream sstream{}; switch (node.kind) { case PyTreeKind::Leaf: { - agenda.emplace_back("*"); - continue; + sstream << "*"; + break; } case PyTreeKind::None: { @@ -1129,7 +1132,8 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: { - EXPECT_EQ(GET_SIZE(node.node_data), + const scoped_critical_section cs{node.node_data}; + EXPECT_EQ(ListGetSize(node.node_data), node.arity, "Number of keys and entries does not match."); if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { @@ -1139,7 +1143,7 @@ std::string PyTreeSpec::ToStringImpl() const { sstream << "{"; } bool first = true; - auto child_iter = agenda.end() - node.arity; + auto child_iter = 
agenda.cend() - node.arity; for (const py::handle& key : node.node_data) { if (!first) [[likely]] { sstream << ", "; @@ -1159,21 +1163,20 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::NamedTuple: { const py::object type = node.node_data; - const auto fields = - py::reinterpret_borrow(py::getattr(type, Py_Get_ID(_fields))); - EXPECT_EQ(GET_SIZE(fields), + const auto fields = NamedTupleGetFields(type); + EXPECT_EQ(TupleGetSize(fields), node.arity, "Number of fields and entries does not match."); const std::string kind = - static_cast(py::str(py::getattr(type, Py_Get_ID(__name__)))); + PyStr(EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(__name__)), type)); sstream << kind << "("; bool first = true; - auto child_iter = agenda.end() - node.arity; + auto child_iter = agenda.cend() - node.arity; for (const py::handle& field : fields) { if (!first) [[likely]] { sstream << ", "; } - sstream << static_cast(py::str(field)) << "=" << *child_iter; + sstream << PyStr(field) << "=" << *child_iter; ++child_iter; first = false; } @@ -1182,25 +1185,22 @@ std::string PyTreeSpec::ToStringImpl() const { } case PyTreeKind::DefaultDict: { - EXPECT_EQ(GET_SIZE(node.node_data), - 2, - "Number of auxiliary data mismatch."); - const py::object default_factory = - GET_ITEM_BORROW(node.node_data, ssize_t(0)); - const auto keys = py::reinterpret_borrow( - GET_ITEM_BORROW(node.node_data, ssize_t(1))); - EXPECT_EQ(GET_SIZE(keys), + const scoped_critical_section cs(node.node_data); + EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch."); + const py::object default_factory = TupleGetItem(node.node_data, 0); + const auto keys = TupleGetItemAs(node.node_data, 1); + EXPECT_EQ(ListGetSize(keys), node.arity, "Number of keys and entries does not match."); sstream << "defaultdict(" << PyRepr(default_factory) << ", {"; bool first = true; - auto child_iter = agenda.end() - node.arity; + auto child_it = agenda.cend() - node.arity; for (const py::handle& key : 
keys) { if (!first) [[likely]] { sstream << ", "; } - sstream << PyRepr(key) << ": " << *child_iter; - ++child_iter; + sstream << PyRepr(key) << ": " << *child_it; + ++child_it; first = false; } sstream << "})"; @@ -1210,7 +1210,7 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::Deque: { sstream << "deque([" << children << "]"; if (!node.node_data.is_none()) [[unlikely]] { - sstream << ", maxlen=" << static_cast(py::str(node.node_data)); + sstream << ", maxlen=" << PyRepr(node.node_data); } sstream << ")"; break; @@ -1219,27 +1219,29 @@ std::string PyTreeSpec::ToStringImpl() const { case PyTreeKind::StructSequence: { const py::object type = node.node_data; const auto fields = StructSequenceGetFields(type); - EXPECT_EQ(GET_SIZE(fields), + EXPECT_EQ(TupleGetSize(fields), node.arity, "Number of fields and entries does not match."); - const py::object module_name = - py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)); + const py::object module_name = EVALUATE_WITH_LOCK_HELD( + py::getattr(type, Py_Get_ID(__module__), Py_Get_ID(__main__)), + type); if (!module_name.is_none()) [[likely]] { - const std::string name = static_cast(py::str(module_name)); + const std::string name = PyStr(module_name); if (!(name.empty() || name == "__main__" || name == "builtins" || name == "__builtins__")) [[likely]] { sstream << name << "."; } } - const py::object qualname = py::getattr(type, Py_Get_ID(__qualname__)); - sstream << static_cast(py::str(qualname)) << "("; + const py::object qualname = + EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(__qualname__)), type); + sstream << PyStr(qualname) << "("; bool first = true; - auto child_iter = agenda.end() - node.arity; + auto child_iter = agenda.cend() - node.arity; for (const py::handle& field : fields) { if (!first) [[likely]] { sstream << ", "; } - sstream << static_cast(py::str(field)) << "=" << *child_iter; + sstream << PyStr(field) << "=" << *child_iter; ++child_iter; first = false; } @@ -1248,8 
+1250,9 @@ std::string PyTreeSpec::ToStringImpl() const { } case PyTreeKind::Custom: { - const std::string kind = static_cast( - py::str(py::getattr(node.custom->type, Py_Get_ID(__name__)))); + const std::string kind = PyStr( + EVALUATE_WITH_LOCK_HELD(py::getattr(node.custom->type, Py_Get_ID(__name__)), + node.custom->type)); sstream << "CustomTreeNode(" << kind << "["; if (node.node_data) [[likely]] { sstream << PyRepr(node.node_data); @@ -1262,7 +1265,7 @@ std::string PyTreeSpec::ToStringImpl() const { INTERNAL_ERROR(); } - agenda.erase(agenda.end() - node.arity, agenda.end()); + agenda.resize(agenda.size() - node.arity); agenda.emplace_back(sstream.str()); } @@ -1280,18 +1283,34 @@ std::string PyTreeSpec::ToStringImpl() const { } std::string PyTreeSpec::ToString() const { - const std::pair indent{this, std::this_thread::get_id()}; - if (sm_repr_running.find(indent) != sm_repr_running.end()) [[unlikely]] { - return "..."; + using ThreadIndent = std::pair; + static std::unordered_set running{}; + static read_write_mutex mutex{}; + + const ThreadIndent indent{this, std::this_thread::get_id()}; + { + const scoped_read_lock_guard lock{mutex}; + if (running.find(indent) != running.end()) [[unlikely]] { + return "..."; + } } - sm_repr_running.insert(indent); + { + const scoped_write_lock_guard lock{mutex}; + running.insert(indent); + } try { std::string representation = ToStringImpl(); - sm_repr_running.erase(indent); + { + const scoped_write_lock_guard lock{mutex}; + running.erase(indent); + } return representation; } catch (...) { - sm_repr_running.erase(indent); + { + const scoped_write_lock_guard lock{mutex}; + running.erase(indent); + } std::rethrow_exception(std::current_exception()); } } @@ -1310,6 +1329,7 @@ std::string PyTreeSpec::ToString() const { case PyTreeKind::NamedTuple: case PyTreeKind::Deque: case PyTreeKind::StructSequence: { + const scoped_critical_section cs{node.node_data}; data_hash = py::hash(node.node_data ? 
node.node_data : py::none()); break; } @@ -1317,23 +1337,18 @@ std::string PyTreeSpec::ToString() const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: case PyTreeKind::DefaultDict: { + const scoped_critical_section cs{node.node_data}; if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { - EXPECT_EQ(GET_SIZE(node.node_data), - 2, - "Number of auxiliary data mismatch."); - const py::object default_factory = - GET_ITEM_BORROW(node.node_data, ssize_t(0)); - data_hash = py::hash(default_factory); + EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch."); + const py::object default_factory = TupleGetItem(node.node_data, 0); + data_hash = EVALUATE_WITH_LOCK_HELD(py::hash(default_factory), default_factory); } - const auto keys = py::reinterpret_borrow( - node.kind != PyTreeKind::DefaultDict - ? node.node_data - : GET_ITEM_BORROW(node.node_data, ssize_t(1))); - EXPECT_EQ(GET_SIZE(keys), - node.arity, - "Number of keys and entries does not match."); + const auto keys = (node.kind != PyTreeKind::DefaultDict + ? 
py::reinterpret_borrow(node.node_data) + : TupleGetItemAs(node.node_data, 1)); + EXPECT_EQ(ListGetSize(keys), node.arity, "Number of keys and entries does not match."); for (const py::handle& key : keys) { - HashCombine(data_hash, py::hash(key)); + HashCombine(data_hash, EVALUATE_WITH_LOCK_HELD(py::hash(key), key)); } break; } @@ -1361,18 +1376,34 @@ ssize_t PyTreeSpec::HashValueImpl() const { } ssize_t PyTreeSpec::HashValue() const { - const std::pair indent{this, std::this_thread::get_id()}; - if (sm_hash_running.find(indent) != sm_hash_running.end()) [[unlikely]] { - return 0; + using ThreadIndent = std::pair; + static std::unordered_set running{}; + static read_write_mutex mutex{}; + + const ThreadIndent indent{this, std::this_thread::get_id()}; + { + const scoped_read_lock_guard lock{mutex}; + if (running.find(indent) != running.end()) [[unlikely]] { + return 0; + } } - sm_hash_running.insert(indent); + { + const scoped_write_lock_guard lock{mutex}; + running.insert(indent); + } try { const ssize_t result = HashValueImpl(); - sm_hash_running.erase(indent); + { + const scoped_write_lock_guard lock{mutex}; + running.erase(indent); + } return result; } catch (...) { - sm_hash_running.erase(indent); + { + const scoped_write_lock_guard lock{mutex}; + running.erase(indent); + } std::rethrow_exception(std::current_exception()); } } @@ -1381,16 +1412,19 @@ py::object PyTreeSpec::ToPickleable() const { const py::tuple node_states{GetNumNodes()}; ssize_t i = 0; for (const auto& node : m_traversal) { - SET_ITEM(node_states, - i++, - py::make_tuple(static_cast(node.kind), - node.arity, - node.node_data ? node.node_data : py::none(), - node.node_entries ? node.node_entries : py::none(), - node.custom != nullptr ? node.custom->type : py::none(), - node.num_leaves, - node.num_nodes, - node.original_keys ? node.original_keys : py::none())); + const scoped_critical_section2 cs{ + node.custom != nullptr ? 
py::handle{node.custom->type.ptr()} : py::handle{}, + node.node_data}; + TupleSetItem(node_states, + i++, + py::make_tuple(py::int_(static_cast(node.kind)), + py::int_(node.arity), + node.node_data ? node.node_data : py::none(), + node.node_entries ? node.node_entries : py::none(), + node.custom != nullptr ? node.custom->type : py::none(), + py::int_(node.num_leaves), + py::int_(node.num_nodes), + node.original_keys ? node.original_keys : py::none())); } return py::make_tuple(node_states, py::bool_(m_none_is_leaf), py::str(m_namespace)); } @@ -1398,20 +1432,20 @@ py::object PyTreeSpec::ToPickleable() const { // NOLINTBEGIN[cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers] // NOLINTNEXTLINE[readability-function-cognitive-complexity] /*static*/ std::unique_ptr PyTreeSpec::FromPickleable(const py::object& pickleable) { - const auto state = py::reinterpret_borrow(pickleable); + const auto state = thread_safe_cast(pickleable); if (state.size() != 3) [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); } bool none_is_leaf = false; std::string registry_namespace{}; auto out = std::make_unique(); - out->m_none_is_leaf = none_is_leaf = py::cast(state[1]); - out->m_namespace = registry_namespace = py::cast(state[2]); - const auto node_states = py::reinterpret_borrow(state[0]); + out->m_none_is_leaf = none_is_leaf = thread_safe_cast(state[1]); + out->m_namespace = registry_namespace = thread_safe_cast(state[2]); + const auto node_states = thread_safe_cast(state[0]); for (const auto& item : node_states) { - const auto t = py::cast(item); + const auto t = thread_safe_cast(item); Node& node = out->m_traversal.emplace_back(); - node.kind = static_cast(py::cast(t[0])); + node.kind = static_cast(thread_safe_cast(t[0])); if (t.size() != 7) [[unlikely]] { if (t.size() == 8) [[likely]] { if (t[7].is_none()) [[likely]] { @@ -1422,7 +1456,7 @@ py::object PyTreeSpec::ToPickleable() const { } else [[unlikely]] { if (node.kind == PyTreeKind::Dict || 
node.kind == PyTreeKind::DefaultDict) [[likely]] { - node.original_keys = py::cast(t[7]); + node.original_keys = thread_safe_cast(t[7]); } else [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); } @@ -1431,7 +1465,7 @@ py::object PyTreeSpec::ToPickleable() const { throw std::runtime_error("Malformed pickled PyTreeSpec."); } } - node.arity = py::cast(t[1]); + node.arity = thread_safe_cast(t[1]); switch (node.kind) { case PyTreeKind::Leaf: case PyTreeKind::None: @@ -1445,13 +1479,13 @@ py::object PyTreeSpec::ToPickleable() const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: { - node.node_data = py::cast(t[2]); + node.node_data = thread_safe_cast(t[2]); break; } case PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: { - node.node_data = py::cast(t[2]); + node.node_data = thread_safe_cast(t[2]); break; } @@ -1467,7 +1501,7 @@ py::object PyTreeSpec::ToPickleable() const { } if (node.kind == PyTreeKind::Custom) [[unlikely]] { // NOLINT if (!t[3].is_none()) [[unlikely]] { - node.node_entries = py::cast(t[3]); + node.node_entries = thread_safe_cast(t[3]); } if (t[4].is_none()) [[unlikely]] { node.custom = nullptr; @@ -1494,8 +1528,8 @@ py::object PyTreeSpec::ToPickleable() const { } else if (!t[3].is_none() || !t[4].is_none()) [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); } - node.num_leaves = py::cast(t[5]); - node.num_nodes = py::cast(t[6]); + node.num_leaves = thread_safe_cast(t[5]); + node.num_nodes = thread_safe_cast(t[6]); } out->m_traversal.shrink_to_fit(); return out; @@ -1520,7 +1554,7 @@ size_t PyTreeSpec::ThreadIndentTypeHash::operator()( // The holder is not constructed yet. Skip the traversal to avoid segfault. 
return 0; } - auto& self = py::cast(py::handle{self_base}); + auto& self = thread_safe_cast(py::handle{self_base}); for (const auto& node : self.m_traversal) { Py_VISIT(node.node_data.ptr()); Py_VISIT(node.node_entries.ptr()); @@ -1539,10 +1573,11 @@ size_t PyTreeSpec::ThreadIndentTypeHash::operator()( // The holder is not constructed yet. Skip the traversal to avoid segfault. return 0; } - auto& self = py::cast(py::handle{self_base}); + auto& self = thread_safe_cast(py::handle{self_base}); for (const auto& pair : self.m_agenda) { Py_VISIT(pair.first.ptr()); } + Py_VISIT(self.m_root.ptr()); return 0; } diff --git a/src/treespec/unflatten.cpp b/src/treespec/unflatten.cpp index f0300720..74e979f9 100644 --- a/src/treespec/unflatten.cpp +++ b/src/treespec/unflatten.cpp @@ -18,6 +18,7 @@ limitations under the License. #include // std::ostringstream #include // std::move +#include "include/critical_section.h" #include "include/exceptions.h" #include "include/registry.h" #include "include/treespec.h" @@ -79,6 +80,9 @@ py::object PyTreeSpec::UnflattenImpl(const Span& leaves) const { return std::move(agenda.back()); } -py::object PyTreeSpec::Unflatten(const py::iterable& leaves) const { return UnflattenImpl(leaves); } +py::object PyTreeSpec::Unflatten(const py::iterable& leaves) const { + const scoped_critical_section cs{leaves}; + return UnflattenImpl(leaves); +} } // namespace optree diff --git a/tests/conftest.py b/tests/conftest.py index 3d3e622f..90a9c541 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,13 @@ import os import random +import threading + + +thread = threading.Thread(target=object) # no-op +thread.start() +thread.join() +del threading, thread os.environ['PYTHONHASHSEED'] = '0' diff --git a/tests/helpers.py b/tests/helpers.py index c6ad6f07..3d9d9140 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -20,6 +20,7 @@ import itertools import platform import sys +import sysconfig import time from collections import OrderedDict, 
UserDict, defaultdict, deque, namedtuple from typing import Any, NamedTuple @@ -27,6 +28,7 @@ import pytest import optree +from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE PYPY = platform.python_implementation() == 'PyPy' @@ -35,12 +37,12 @@ reason='PyPy does not support weakref and refcount correctly', ) - -GLOBAL_NAMESPACE = optree.registry.__GLOBAL_NAMESPACE # pylint: disable=protected-access +Py_GIL_DISABLED = sysconfig.get_config_var('Py_GIL_DISABLED') is not None +NUM_GC_REPEAT = 10 if Py_GIL_DISABLED else 5 def gc_collect(): - for _ in range(3): + for _ in range(NUM_GC_REPEAT): gc.collect() diff --git a/tests/test_concurrent.py b/tests/test_concurrent.py new file mode 100644 index 00000000..ac5ecdba --- /dev/null +++ b/tests/test_concurrent.py @@ -0,0 +1,337 @@ +# Copyright 2022-2024 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# pylint: disable=missing-function-docstring,invalid-name,wrong-import-order + +import atexit +import itertools +import pickle +import weakref +from collections import OrderedDict, defaultdict +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + +import optree +from helpers import GLOBAL_NAMESPACE, PYPY, TREES, Py_GIL_DISABLED, gc_collect, parametrize + + +if PYPY: + pytest.skip('Test for CPython only', allow_module_level=True) + + +if Py_GIL_DISABLED: + NUM_WORKERS = 32 + NUM_FUTURES = 128 +else: + NUM_WORKERS = 4 + NUM_FUTURES = 16 + + +EXECUTOR = ThreadPoolExecutor(max_workers=NUM_WORKERS) +atexit.register(EXECUTOR.shutdown) + + +def concurrent_run(func): + futures = [EXECUTOR.submit(func) for _ in range(NUM_FUTURES)] + future2index = {future: i for i, future in enumerate(futures)} + completed_futures = sorted(as_completed(futures), key=future2index.get) + first_exception = next(filter(None, (future.exception() for future in completed_futures)), None) + if first_exception is not None: + raise first_exception + return [future.result() for future in completed_futures] + + +concurrent_run(object) # warm-up + + +@parametrize( + tree=TREES, + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_tree_flatten_unflatten_thread_safe( + tree, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + def test_fn(): + return optree.tree_flatten(tree, namespace=namespace) + + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + leaves, treespec = expected = test_fn() + for result in concurrent_run(test_fn): + assert result == expected + + for result in concurrent_run(lambda: optree.tree_unflatten(treespec, leaves)): + assert result == tree + + +@parametrize( + tree=TREES, + 
namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_tree_flatten_with_path_thread_safe( + tree, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + def test_fn(): + return optree.tree_flatten_with_path(tree, namespace=namespace) + + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + expected = test_fn() + for result in concurrent_run(test_fn): + assert result == expected + + +@parametrize( + tree=TREES, + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_tree_flatten_with_accessor_thread_safe( + tree, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + def test_fn(): + return optree.tree_flatten_with_accessor(tree, namespace=namespace) + + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + expected = test_fn() + for result in concurrent_run(test_fn): + assert result == expected + + +@parametrize(tree=TREES) +def test_treespec_string_representation(tree): + expected_string = repr(optree.tree_structure(tree)) + + def check1(): + treespec = optree.tree_structure(tree) + assert str(treespec) == expected_string + assert repr(treespec) == expected_string + + concurrent_run(check1) + + treespec = optree.tree_structure(tree) + + def check2(): + assert str(treespec) == expected_string + assert repr(treespec) == expected_string + + concurrent_run(check2) + + +def test_treespec_self_referential(): # noqa: C901 + class Holder: + def __init__(self, value): + self.value = value + + def __eq__(self, other): + return isinstance(other, Holder) and self.value == other.value + + def __hash__(self): + return hash(self.value) + + def __repr__(self): + return f'Holder({self.value!r})' + + hashes = set() 
+ key = Holder('a') + + treespec = optree.tree_structure({key: 0}) + + def check1(): + assert str(treespec) == "PyTreeSpec({Holder('a'): *})" # noqa: F821 + assert hash(treespec) == hash(treespec) # noqa: F821 + + concurrent_run(check1) + + hashes.add(hash(treespec)) + + key.value = 'b' + + def check2(): + assert str(treespec) == "PyTreeSpec({Holder('b'): *})" # noqa: F821 + assert hash(treespec) == hash(treespec) # noqa: F821 + assert hash(treespec) not in hashes # noqa: F821 + + concurrent_run(check2) + + hashes.add(hash(treespec)) + + key.value = treespec + + def check3(): + assert str(treespec) == 'PyTreeSpec({Holder(...): *})' # noqa: F821 + assert hash(treespec) == hash(treespec) # noqa: F821 + assert hash(treespec) not in hashes # noqa: F821 + + concurrent_run(check3) + + hashes.add(hash(treespec)) + + key.value = ('a', treespec, treespec) + + def check4(): + assert str(treespec) == "PyTreeSpec({Holder(('a', ..., ...)): *})" # noqa: F821 + assert hash(treespec) == hash(treespec) # noqa: F821 + assert hash(treespec) not in hashes # noqa: F821 + + concurrent_run(check4) + + hashes.add(hash(treespec)) + + other = optree.tree_structure({Holder(treespec): 1}) + + def check5(): + assert ( + str(other) # noqa: F821 + == "PyTreeSpec({Holder(PyTreeSpec({Holder(('a', ..., ...)): *})): *})" + ) + assert hash(other) == hash(other) # noqa: F821 + assert hash(other) not in hashes # noqa: F821 + + concurrent_run(check5) + + hashes.add(hash(other)) + + key.value = other + + def check6(): + assert ( + str(treespec) == 'PyTreeSpec({Holder(PyTreeSpec({Holder(...): *})): *})' # noqa: F821 + ) + assert str(other) == 'PyTreeSpec({Holder(PyTreeSpec({Holder(...): *})): *})' # noqa: F821 + assert hash(treespec) == hash(treespec) # noqa: F821 + assert hash(treespec) not in hashes # noqa: F821 + + concurrent_run(check6) + + hashes.add(hash(treespec)) + + def check7(): + assert hash(other) == hash(other) # noqa: F821 + assert hash(treespec) == hash(other) # noqa: F821 + + 
concurrent_run(check7) + + with pytest.raises(RecursionError): + assert treespec != other + + wr = weakref.ref(treespec) + del treespec, key, other + gc_collect() + assert wr() is None + + +@parametrize( + tree=TREES, + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_treespec_pickle_round_trip( + tree, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + def check1(): + assert pickle.loads(pickle.dumps(expected)) == expected + + def check2(): + assert pickle.dumps(expected) == expected_serialized + assert pickle.loads(expected_serialized) == expected + + def check3(): + assert list(optree.tree_unflatten(actual, range(len(actual)))) == list( + optree.tree_unflatten(expected, range(len(expected))), + ) + + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + expected = optree.tree_structure(tree, namespace=namespace) + expected_serialized = b'' + try: + pickle.loads(pickle.dumps(tree)) + except pickle.PicklingError: + with pytest.raises(pickle.PicklingError, match=r"Can't pickle .*:"): + pickle.loads(pickle.dumps(expected)) + else: + expected_serialized = pickle.dumps(expected) + actual = pickle.loads(expected_serialized) + concurrent_run(check1) + concurrent_run(check2) + if expected.type in {dict, OrderedDict, defaultdict}: + concurrent_run(check3) + + +@parametrize( + tree=TREES, + none_is_leaf=[False, True], + namespace=['', 'undefined', 'namespace'], + dict_should_be_sorted=[False, True], + dict_session_namespace=['', 'undefined', 'namespace'], +) +def test_tree_iter_thread_safe( + tree, + none_is_leaf, + namespace, + dict_should_be_sorted, + dict_session_namespace, +): + counter = itertools.count() + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + new_tree = optree.tree_map( + lambda x: 
next(counter), + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + num_leaves = next(counter) + it = optree.tree_iter( + new_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + + results = concurrent_run(lambda: list(it)) + for seq in results: + assert sorted(seq) == seq + assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) diff --git a/tests/test_registry.py b/tests/test_registry.py index 5477a408..869994a9 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -23,7 +23,7 @@ import optree import optree._C -from helpers import GLOBAL_NAMESPACE, gc_collect, skipif_pypy +from helpers import GLOBAL_NAMESPACE, Py_GIL_DISABLED, gc_collect, skipif_pypy def test_register_pytree_node_class_with_no_namespace(): @@ -767,8 +767,7 @@ def test_unregister_pytree_node_namedtuple(): @skipif_pypy -def test_unregister_pytree_node_memory_leak(): # noqa: C901 - +def test_unregister_pytree_node_no_reference_leak(): # noqa: C901 @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyList1(UserList): def tree_flatten(self): @@ -784,7 +783,8 @@ def tree_unflatten(cls, metadata, children): optree.unregister_pytree_node(MyList1, namespace=GLOBAL_NAMESPACE) del MyList1 gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyList2(UserList): @@ -811,7 +811,8 @@ def tree_unflatten(cls, metadata, children): del treespec gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None @optree.register_pytree_node_class(namespace=GLOBAL_NAMESPACE) class MyList3(UserList): @@ -841,7 +842,8 @@ def tree_unflatten(cls, metadata, children): del treespec gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None @optree.register_pytree_node_class(namespace='mylist') class MyList4(UserList): @@ -858,7 +860,8 @@ def tree_unflatten(cls, metadata, children): 
optree.unregister_pytree_node(MyList4, namespace='mylist') del MyList4 gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None @optree.register_pytree_node_class(namespace='mylist') class MyList5(UserList): @@ -887,7 +890,8 @@ def tree_unflatten(cls, metadata, children): del treespec gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None def test_dict_insertion_order_with_invalid_namespace(): diff --git a/tests/test_treespec.py b/tests/test_treespec.py index d39503b7..d43aaa99 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -248,7 +248,6 @@ def __repr__(self): wr = weakref.ref(treespec) del treespec, key, other gc_collect() - if not PYPY: assert wr() is None @@ -279,7 +278,6 @@ def test_treeiter_self_referential(): del it, d gc_collect() - if not PYPY: assert wr() is None diff --git a/tests/test_typing.py b/tests/test_typing.py index 4fa85a78..c3d8998b 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -27,6 +27,7 @@ from helpers import ( CustomNamedTupleSubclass, CustomTuple, + Py_GIL_DISABLED, Vector2D, gc_collect, getrefcount, @@ -133,7 +134,8 @@ def test_is_namedtuple_cache(): assert wr() is Point del Point gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None refcount = getrefcount(time.struct_time) weakrefcount = weakref.getweakrefcount(time.struct_time) @@ -170,8 +172,9 @@ class Foo(metaclass=FooMeta): assert wr() is Foo del Foo gc_collect() - assert called_with == 'Foo' - assert wr() is None + if not Py_GIL_DISABLED: + assert called_with == 'Foo' + assert wr() is None def test_namedtuple_fields(): @@ -264,7 +267,8 @@ def test_namedtuple_fields_cache(): del Point gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None with pytest.raises( TypeError, @@ -299,8 +303,9 @@ class Foo(metaclass=FooMeta): assert wr() is Foo del Foo gc_collect() - assert called_with == 'Foo' - assert wr() is None + if not Py_GIL_DISABLED: + 
assert called_with == 'Foo' + assert wr() is None def test_is_structseq(): @@ -380,7 +385,8 @@ def test_is_structseq_cache(): assert wr() is Point del Point gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None refcount = getrefcount(time.struct_time) weakrefcount = weakref.getweakrefcount(time.struct_time) @@ -417,8 +423,9 @@ class Foo(metaclass=FooMeta): assert wr() is Foo del Foo gc_collect() - assert called_with == 'Foo' - assert wr() is None + if not Py_GIL_DISABLED: + assert called_with == 'Foo' + assert wr() is None def test_structseq_fields(): @@ -547,7 +554,8 @@ def test_structseq_fields_cache(): assert wr() is Point del Point gc_collect() - assert wr() is None + if not Py_GIL_DISABLED: + assert wr() is None refcount = getrefcount(time.struct_time) weakrefcount = weakref.getweakrefcount(time.struct_time) @@ -583,5 +591,6 @@ class Foo(metaclass=FooMeta): assert wr() is Foo del Foo gc_collect() - assert called_with == 'Foo' - assert wr() is None + if not Py_GIL_DISABLED: + assert called_with == 'Foo' + assert wr() is None