diff --git a/.clang-format b/.clang-format index 2ea83640..cac3b9eb 100644 --- a/.clang-format +++ b/.clang-format @@ -1,9 +1,13 @@ +--- BasedOnStyle: Google UseTab: Never IndentWidth: 4 ContinuationIndentWidth: 4 -AccessModifierOffset: -3 +AccessModifierOffset: -4 ColumnLimit: 100 + +Language: Cpp +Standard: c++17 AlignAfterOpenBracket: Align AlignEscapedNewlines: Right AllowAllArgumentsOnNextLine: false @@ -12,6 +16,7 @@ AllowShortIfStatementsOnASingleLine: Never BinPackArguments: false BinPackParameters: false BreakBeforeTernaryOperators: true +CommentPragmas: 'NOLINT(NEXTLINE|BEGIN|END)?\[.*\]' FixNamespaceComments: true IncludeBlocks: Regroup IncludeCategories: @@ -34,13 +39,17 @@ IndentPPDirectives: None InsertBraces: true InsertTrailingCommas: Wrapped LambdaBodyIndentation: Signature +MacroBlockBegin: '^Py_BEGIN_(ALLOW_THREADS|CRITICAL_SECTION(2)?(_MUT)?)$' +MacroBlockEnd: '^Py_END_(ALLOW_THREADS|CRITICAL_SECTION(2)?(_MUT)?)$' PackConstructorInitializers: NextLine PointerAlignment: Right QualifierAlignment: Custom -QualifierOrder: [friend, static, inline, const, constexpr, volatile, type, restrict] +QualifierOrder: + [friend, static, inline, const, constexpr, volatile, type, restrict] ReferenceAlignment: Right -RemoveParentheses: MultipleParentheses +RemoveParentheses: ReturnStatement RemoveSemicolon: true SeparateDefinitionBlocks: Leave +SkipMacroDefinitionBody: false SortIncludes: CaseSensitive SpaceAroundPointerQualifiers: Both diff --git a/.clang-tidy b/.clang-tidy index 37ef02cd..8e613a20 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -4,24 +4,24 @@ InheritParentConfig: true FormatStyle: file UseColor: true WarningsAsErrors: '*' -Checks: ' -bugprone-*, --bugprone-easily-swappable-parameters, -clang-analyzer-*, -cppcoreguidelines-*, --cppcoreguidelines-macro-usage, --cppcoreguidelines-pro-type-reinterpret-cast, -hicpp-*, -misc-*, -modernize-*, --modernize-use-trailing-return-type, -performance-*, -readability-*, --readability-redundant-inline-specifier, --readability-redundant-member-init, --readability-identifier-length, -' -CheckOptions: - misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*;include/.*' HeaderFilterRegex: '^include/.*$' -... + +Checks: | + bugprone-*, + -bugprone-easily-swappable-parameters, + clang-analyzer-*, + cppcoreguidelines-*, + -cppcoreguidelines-macro-usage, + -cppcoreguidelines-pro-type-reinterpret-cast, + hicpp-*, + misc-*, + modernize-*, + -modernize-use-trailing-return-type, + performance-*, + readability-*, + -readability-redundant-inline-specifier, + -readability-redundant-member-init, + -readability-identifier-length, + +CheckOptions: + misc-include-cleaner.IgnoreHeaders: 'python.*/.*;pybind11/.*;include/.*' diff --git a/CMakeLists.txt b/CMakeLists.txt index 65b537b7..61631ff2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -88,7 +88,7 @@ if(OPTREE_CXX_WERROR) endif() string(LENGTH "${CMAKE_SOURCE_DIR}/" SOURCE_PATH_PREFIX_SIZE) -add_definitions("-DSOURCE_PATH_PREFIX_SIZE=${SOURCE_PATH_PREFIX_SIZE}") +add_definitions("-DSOURCE_PATH_PREFIX_SIZE=(${SOURCE_PATH_PREFIX_SIZE})") function(system) set(options STRIP) diff --git a/CPPLINT.cfg b/CPPLINT.cfg index f4e9cd50..78a93f3d 100644 --- a/CPPLINT.cfg +++ b/CPPLINT.cfg @@ -1,6 +1,8 @@ linelength=100 filter=-readability/nolint filter=-readability/braces +filter=-whitespace/indent filter=-whitespace/newline +filter=-whitespace/parens filter=-build/c++11 filter=-build/include_order diff --git a/include/critical_section.h b/include/critical_section.h index 76a2edf4..649b166a 100644 --- a/include/critical_section.h +++ b/include/critical_section.h @@ -40,7 +40,7 @@ inline bool Py_IsConstant(PyObject* x) { return Py_IsNone(x) || Py_IsTrue(x) || #define Py_IsConstant(x) Py_IsConstant(x) class scoped_critical_section { - public: +public: scoped_critical_section() = delete; #ifdef Py_GIL_DISABLED @@ -65,7 +65,7 @@ class scoped_critical_section { scoped_critical_section(scoped_critical_section&&) = delete; scoped_critical_section& operator=(scoped_critical_section&&) = delete; - private: +private: #ifdef Py_GIL_DISABLED PyObject* m_ptr{nullptr}; PyCriticalSection m_critical_section{}; @@ -73,7 +73,7 @@ class scoped_critical_section { }; class scoped_critical_section2 { - public: +public: scoped_critical_section2() = delete; #ifdef Py_GIL_DISABLED @@ -111,7 +111,7 @@ class scoped_critical_section2 { scoped_critical_section2(scoped_critical_section2&&) = delete; scoped_critical_section2& operator=(scoped_critical_section2&&) = delete; - private: +private: #ifdef Py_GIL_DISABLED PyObject* m_ptr1{nullptr}; PyObject* m_ptr2{nullptr}; diff --git a/include/exceptions.h b/include/exceptions.h index a7e016b5..7c7a78fe 100644 --- a/include/exceptions.h +++ b/include/exceptions.h @@ -23,7 +23,7 @@ limitations under the License. #include // std::string #ifndef SOURCE_PATH_PREFIX_SIZE -#define SOURCE_PATH_PREFIX_SIZE 0 +#define SOURCE_PATH_PREFIX_SIZE (0) #endif #ifndef FILE_RELPATH @@ -36,9 +36,11 @@ limitations under the License. namespace optree { class InternalError : public std::logic_error { - public: +public: explicit InternalError(const std::string& msg) : std::logic_error{msg} {} - InternalError(const std::string& msg, const std::string& file, const std::size_t& lineno) + explicit InternalError(const std::string& msg, + const std::string& file, + const std::size_t& lineno) : InternalError([&msg, &file, &lineno]() -> std::string { std::ostringstream oss{}; oss << msg << " (at file " << file << ":" << lineno << ")\n\n" @@ -49,32 +51,25 @@ class InternalError : public std::logic_error { } // namespace optree -#define INTERNAL_ERROR1_(message) throw optree::InternalError(message, FILE_RELPATH, __LINE__) +#define INTERNAL_ERROR1_(message) throw optree::InternalError((message), FILE_RELPATH, __LINE__) #define INTERNAL_ERROR0_() INTERNAL_ERROR1_("Unreachable code.") -#define INTERNAL_ERROR(...) /* NOLINTNEXTLINE[whitespace/parens] */ \ +#define INTERNAL_ERROR(...) \ VA_FUNC2_(__0 __VA_OPT__(, ) __VA_ARGS__, INTERNAL_ERROR1_, INTERNAL_ERROR0_)(__VA_ARGS__) #define EXPECT2_(condition, message) \ - if (!(condition)) [[unlikely]] \ - INTERNAL_ERROR1_(message) + if (!(condition)) [[unlikely]] { \ + INTERNAL_ERROR1_(message); \ + } #define EXPECT0_() INTERNAL_ERROR0_() -#define EXPECT1_(condition) EXPECT2_(condition, "`" #condition "` failed.") -#define EXPECT_(...) /* NOLINTNEXTLINE[whitespace/parens] */ \ +#define EXPECT1_(condition) EXPECT2_((condition), "`" #condition "` failed.") +#define EXPECT_(...) \ VA_FUNC3_(__0 __VA_OPT__(, ) __VA_ARGS__, EXPECT2_, EXPECT1_, EXPECT0_)(__VA_ARGS__) -#define EXPECT_TRUE(condition, ...) \ - EXPECT_(condition __VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_FALSE(condition, ...) \ - EXPECT_(!(condition)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_EQ(a, b, ...) \ - EXPECT_((a) == (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_NE(a, b, ...) \ - EXPECT_((a) != (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_LT(a, b, ...) \ - EXPECT_((a) < (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_LE(a, b, ...) \ - EXPECT_((a) <= (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_GT(a, b, ...) \ - EXPECT_((a) > (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] -#define EXPECT_GE(a, b, ...) \ - EXPECT_((a) >= (b)__VA_OPT__(, ) __VA_ARGS__) // NOLINT[whitespace/parens] +#define EXPECT_TRUE(condition, ...) EXPECT_((condition)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_FALSE(condition, ...) EXPECT_(!(condition)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_EQ(a, b, ...) EXPECT_((a) == (b)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_NE(a, b, ...) EXPECT_((a) != (b)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_LT(a, b, ...) EXPECT_((a) < (b)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_LE(a, b, ...) EXPECT_((a) <= (b)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_GT(a, b, ...) EXPECT_((a) > (b)__VA_OPT__(, ) __VA_ARGS__) +#define EXPECT_GE(a, b, ...) EXPECT_((a) >= (b)__VA_OPT__(, ) __VA_ARGS__) diff --git a/include/mutex.h b/include/mutex.h index d01db861..03dfbeae 100644 --- a/include/mutex.h +++ b/include/mutex.h @@ -24,7 +24,7 @@ limitations under the License. #ifdef Py_GIL_DISABLED class pymutex { - public: +public: pymutex() = default; ~pymutex() = default; @@ -36,7 +36,7 @@ class pymutex { void lock() { PyMutex_Lock(&mutex); } void unlock() { PyMutex_Unlock(&mutex); } - private: +private: PyMutex mutex{0}; }; diff --git a/include/registry.h b/include/registry.h index 7ec36625..eec08e55 100644 --- a/include/registry.h +++ b/include/registry.h @@ -63,7 +63,7 @@ constexpr PyTreeKind kStructSequence = PyTreeKind::StructSequence; // Registry of custom node types. class PyTreeTypeRegistry { - public: +public: PyTreeTypeRegistry() = default; ~PyTreeTypeRegistry() = default; @@ -116,7 +116,7 @@ class PyTreeTypeRegistry { // Clear the registry on cleanup. static void Clear(); - private: +private: template static PyTreeTypeRegistry *Singleton(); diff --git a/include/treespec.h b/include/treespec.h index ee0fbfff..de73ac7c 100644 --- a/include/treespec.h +++ b/include/treespec.h @@ -41,7 +41,7 @@ using ssize_t = py::ssize_t; // The maximum depth of a pytree. #ifndef Py_C_RECURSION_LIMIT -#define Py_C_RECURSION_LIMIT 1000 +#define Py_C_RECURSION_LIMIT (1000) #endif #ifndef PYPY_VERSION constexpr ssize_t MAX_RECURSION_DEPTH = std::min(1000, Py_C_RECURSION_LIMIT); @@ -77,10 +77,10 @@ py::module_ GetCxxModule(const std::optional &module = std::nullopt // the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves // are other objects. class PyTreeSpec { - private: +private: struct Node; - public: +public: PyTreeSpec() = default; ~PyTreeSpec() = default; @@ -232,7 +232,7 @@ class PyTreeSpec { // Used in tp_traverse for GC support. static int PyTpTraverse(PyObject *self_base, visitproc visit, void *arg); - private: +private: using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr; struct Node { @@ -358,7 +358,7 @@ class PyTreeSpec { std::string registry_namespace); class ThreadIndentTypeHash { - public: + public: using is_transparent = void; size_t operator()(const std::pair &p) const; }; @@ -370,7 +370,7 @@ class PyTreeSpec { }; class PyTreeIter { - public: +public: explicit PyTreeIter(const py::object &tree, const std::optional &leaf_predicate, const bool &none_is_leaf, @@ -397,7 +397,7 @@ class PyTreeIter { // Used in tp_traverse for GC support. static int PyTpTraverse(PyObject *self_base, visitproc visit, void *arg); - private: +private: const py::object m_root; std::vector> m_agenda; const std::optional m_leaf_predicate; diff --git a/include/utils.h b/include/utils.h index 4a0ed03d..4b0c79e4 100644 --- a/include/utils.h +++ b/include/utils.h @@ -58,13 +58,13 @@ inline void HashCombine(py::ssize_t& seed, const T& v) { // NOLINT[runtime/refe } class TypeHash { - public: +public: using is_transparent = void; py::size_t operator()(const py::object& t) const { return std::hash{}(t.ptr()); } py::size_t operator()(const py::handle& t) const { return std::hash{}(t.ptr()); } }; class TypeEq { - public: +public: using is_transparent = void; bool operator()(const py::object& a, const py::object& b) const { return a.ptr() == b.ptr(); } bool operator()(const py::object& a, const py::handle& b) const { return a.ptr() == b.ptr(); } @@ -73,7 +73,7 @@ class TypeEq { }; class NamedTypeHash { - public: +public: using is_transparent = void; py::size_t operator()(const std::pair& p) const { py::size_t seed = 0; @@ -89,7 +89,7 @@ class NamedTypeHash { } }; class NamedTypeEq { - public: +public: using is_transparent = void; bool operator()(const std::pair& a, const std::pair& b) const { @@ -114,11 +114,11 @@ constexpr bool NONE_IS_NODE = false; #define Py_Declare_ID(name) \ namespace { \ - static inline PyObject* Py_ID_##name() { \ + inline PyObject* Py_ID_##name() { \ PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; \ return storage \ .call_once_and_store_result([]() -> PyObject* { \ - PyObject* ptr = PyUnicode_InternFromString(#name); \ + PyObject* const ptr = PyUnicode_InternFromString(#name); \ if (ptr == nullptr) [[unlikely]] { \ throw py::error_already_set(); \ } \ @@ -147,6 +147,14 @@ Py_Declare_ID(n_fields); // structseq.n_fields Py_Declare_ID(n_sequence_fields); // structseq.n_sequence_fields Py_Declare_ID(n_unnamed_fields); // structseq.n_unnamed_fields +#define PyNoneTypeObject \ + (py::reinterpret_borrow(reinterpret_cast(Py_TYPE(Py_None)))) +#define PyTupleTypeObject \ + (py::reinterpret_borrow(reinterpret_cast(&PyTuple_Type))) +#define PyListTypeObject \ + (py::reinterpret_borrow(reinterpret_cast(&PyList_Type))) +#define PyDictTypeObject \ + (py::reinterpret_borrow(reinterpret_cast(&PyDict_Type))) #define PyOrderedDictTypeObject (ImportOrderedDict()) #define PyDefaultDictTypeObject (ImportDefaultDict()) #define PyDequeTypeObject (ImportDeque()) @@ -210,7 +218,7 @@ inline py::object TupleGetItem(const py::handle& tuple, const py::ssize_t& index template >> inline T ListGetItemAs(const py::handle& list, const py::ssize_t& index) { #if PY_VERSION_HEX >= 0x030D00A4 // Python 3.13.0a4 - PyObject* item = PyList_GetItemRef(list.ptr(), index); + PyObject* const item = PyList_GetItemRef(list.ptr(), index); if (item == nullptr) [[unlikely]] { throw py::error_already_set(); } @@ -418,7 +426,7 @@ inline py::tuple NamedTupleGetFields(const py::handle& object) { inline bool IsStructSequenceClassImpl(const py::handle& type) { // We can only identify PyStructSequences heuristically, here by the presence of // n_fields, n_sequence_fields, n_unnamed_fields attributes. - auto* type_object = reinterpret_cast(type.ptr()); + auto* const type_object = reinterpret_cast(type.ptr()); if (PyType_FastSubclass(type_object, Py_TPFLAGS_TUPLE_SUBCLASS) && type_object->tp_bases != nullptr && static_cast(PyTuple_CheckExact(type_object->tp_bases)) && @@ -426,9 +434,9 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { PyTuple_GET_ITEM(type_object->tp_bases, 0) == reinterpret_cast(&PyTuple_Type)) [[unlikely]] { // NOLINTNEXTLINE[readability-use-anyofallof] - for (PyObject* name : + for (PyObject* const name : {Py_Get_ID(n_fields), Py_Get_ID(n_sequence_fields), Py_Get_ID(n_unnamed_fields)}) { - if (PyObject* attr = PyObject_GetAttr(type.ptr(), name)) [[unlikely]] { + if (const PyObject* const attr = PyObject_GetAttr(type.ptr(), name)) [[unlikely]] { const bool result = static_cast(PyLong_CheckExact(attr)); Py_DECREF(attr); if (!result) [[unlikely]] { @@ -443,11 +451,11 @@ inline bool IsStructSequenceClassImpl(const py::handle& type) { try { py::exec("class _(cls): pass", py::dict(py::arg("cls") = type)); } catch (py::error_already_set& ex) { - return (ex.matches(PyExc_AssertionError) || ex.matches(PyExc_TypeError)); + return ex.matches(PyExc_AssertionError) || ex.matches(PyExc_TypeError); } return false; #else - return (!static_cast(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE))); + return !static_cast(PyType_HasFeature(type_object, Py_TPFLAGS_BASETYPE)); #endif } return false; @@ -513,8 +521,8 @@ inline py::tuple StructSequenceGetFieldsImpl(const py::handle& type) { py::dict(py::arg("cls") = type, py::arg("fields") = fields)); return py::tuple{fields}; #else - const auto n_sequence_fields = - py::cast(py::getattr(type, Py_Get_ID(n_sequence_fields))); + const auto n_sequence_fields = thread_safe_cast( + EVALUATE_WITH_LOCK_HELD(py::getattr(type, Py_Get_ID(n_sequence_fields)), type)); const auto* const members = reinterpret_cast(type.ptr())->tp_members; py::tuple fields{n_sequence_fields}; for (py::ssize_t i = 0; i < n_sequence_fields; ++i) { @@ -584,11 +592,11 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references] try { // Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)` const auto sort_key_fn = py::cpp_function([](const py::object& o) -> py::tuple { - const py::handle t = py::type::handle_of(o); - const py::str qualname{ - EVALUATE_WITH_LOCK_HELD(PyStr(py::getattr(t, Py_Get_ID(__module__))) + "." + - PyStr(py::getattr(t, Py_Get_ID(__qualname__))), - t)}; + const py::handle cls = py::type::handle_of(o); + const py::str qualname{EVALUATE_WITH_LOCK_HELD( + PyStr(py::getattr(cls, Py_Get_ID(__module__))) + "." + + PyStr(py::getattr(cls, Py_Get_ID(__qualname__))), + cls)}; return py::make_tuple(qualname, o); }); { diff --git a/src/optree.cpp b/src/optree.cpp index e996f3d3..11e8ddd9 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -168,7 +168,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] .value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.") .value("DEQUE", PyTreeKind::Deque, "A collections.deque.") .value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence."); - auto* PyTreeKind_Type = reinterpret_cast(PyTreeKindTypeObject.ptr()); + auto* const PyTreeKind_Type = reinterpret_cast(PyTreeKindTypeObject.ptr()); PyTreeKind_Type->tp_name = "optree.PyTreeKind"; py::setattr(PyTreeKindTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); @@ -178,13 +178,13 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] "Representing the structure of the pytree.", // NOLINTBEGIN[readability-function-cognitive-complexity,cppcoreguidelines-avoid-do-while] py::custom_type_setup([](PyHeapTypeObject* heap_type) -> void { - auto* type = &heap_type->ht_type; + auto* const type = &heap_type->ht_type; type->tp_flags |= Py_TPFLAGS_HAVE_GC; type->tp_traverse = &PyTreeSpec::PyTpTraverse; }), // NOLINTEND[readability-function-cognitive-complexity,cppcoreguidelines-avoid-do-while] py::module_local()); - auto* PyTreeSpec_Type = reinterpret_cast(PyTreeSpecTypeObject.ptr()); + auto* const PyTreeSpec_Type = reinterpret_cast(PyTreeSpecTypeObject.ptr()); PyTreeSpec_Type->tp_name = "optree.PyTreeSpec"; py::setattr(PyTreeSpecTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); @@ -316,13 +316,13 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] "Iterator over the leaves of a pytree.", // NOLINTBEGIN[readability-function-cognitive-complexity,cppcoreguidelines-avoid-do-while] py::custom_type_setup([](PyHeapTypeObject* heap_type) -> void { - auto* type = &heap_type->ht_type; + auto* const type = &heap_type->ht_type; type->tp_flags |= Py_TPFLAGS_HAVE_GC; type->tp_traverse = &PyTreeIter::PyTpTraverse; }), // NOLINTEND[readability-function-cognitive-complexity,cppcoreguidelines-avoid-do-while] py::module_local()); - auto* PyTreeIter_Type = reinterpret_cast(PyTreeIterTypeObject.ptr()); + auto* const PyTreeIter_Type = reinterpret_cast(PyTreeIterTypeObject.ptr()); PyTreeIter_Type->tp_name = "optree.PyTreeIter"; py::setattr(PyTreeIterTypeObject.ptr(), Py_Get_ID(__module__), Py_Get_ID(optree)); diff --git a/src/registry.cpp b/src/registry.cpp index 5f50e88f..a9078607 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -34,43 +34,36 @@ namespace optree { template /*static*/ PyTreeTypeRegistry* PyTreeTypeRegistry::Singleton() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; - return &( - storage - .call_once_and_store_result([]() -> PyTreeTypeRegistry { - PyTreeTypeRegistry registry{}; - - const auto add_builtin_type = [®istry](const py::object& cls, - const PyTreeKind& kind) -> void { - auto registration = - std::make_shared>(); - registration->kind = kind; - registration->type = py::reinterpret_borrow(cls); - EXPECT_TRUE( - registry.m_registrations.emplace(cls, std::move(registration)).second, - "PyTree type " + PyRepr(cls) + - " is already registered in the global namespace."); - if (sm_builtins_types.emplace(cls).second) [[likely]] { - cls.inc_ref(); - } - }; - if constexpr (!NoneIsLeaf) { - add_builtin_type(py::type::of(py::none()), PyTreeKind::None); - } - add_builtin_type( - py::reinterpret_borrow(reinterpret_cast(&PyTuple_Type)), - PyTreeKind::Tuple); - add_builtin_type( - py::reinterpret_borrow(reinterpret_cast(&PyList_Type)), - PyTreeKind::List); - add_builtin_type( - py::reinterpret_borrow(reinterpret_cast(&PyDict_Type)), - PyTreeKind::Dict); - add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); - add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); - add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); - return registry; - }) - .get_stored()); + return &(storage + .call_once_and_store_result([]() -> PyTreeTypeRegistry { + PyTreeTypeRegistry registry{}; + + const auto add_builtin_type = [®istry](const py::object& cls, + const PyTreeKind& kind) -> void { + auto registration = + std::make_shared>(); + registration->kind = kind; + registration->type = py::reinterpret_borrow(cls); + EXPECT_TRUE( + registry.m_registrations.emplace(cls, std::move(registration)).second, + "PyTree type " + PyRepr(cls) + + " is already registered in the global namespace."); + if (sm_builtins_types.emplace(cls).second) [[likely]] { + cls.inc_ref(); + } + }; + if constexpr (!NoneIsLeaf) { + add_builtin_type(PyNoneTypeObject, PyTreeKind::None); + } + add_builtin_type(PyTupleTypeObject, PyTreeKind::Tuple); + add_builtin_type(PyListTypeObject, PyTreeKind::List); + add_builtin_type(PyDictTypeObject, PyTreeKind::Dict); + add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); + add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); + add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); + return registry; + }) + .get_stored()); } template PyTreeTypeRegistry* PyTreeTypeRegistry::Singleton(); @@ -87,7 +80,7 @@ template " is a built-in type and cannot be re-registered."); } - PyTreeTypeRegistry* registry = Singleton(); + PyTreeTypeRegistry* const registry = Singleton(); auto registration = std::make_shared>(); registration->kind = PyTreeKind::Custom; registration->type = py::reinterpret_borrow(cls); @@ -181,7 +174,7 @@ template " is a built-in type and cannot be unregistered."); } - PyTreeTypeRegistry* registry = Singleton(); + PyTreeTypeRegistry* const registry = Singleton(); if (registry_namespace.empty()) [[unlikely]] { const auto it = registry->m_registrations.find(cls); if (it == registry->m_registrations.end()) [[unlikely]] { @@ -247,7 +240,7 @@ template const std::string& registry_namespace) { const scoped_read_lock_guard lock{sm_mutex}; - PyTreeTypeRegistry* registry = Singleton(); + PyTreeTypeRegistry* const registry = Singleton(); if (!registry_namespace.empty()) [[unlikely]] { const auto named_it = registry->m_named_registrations.find(std::make_pair(registry_namespace, cls)); @@ -304,8 +297,8 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( /*static*/ void PyTreeTypeRegistry::Clear() { const scoped_write_lock_guard lock{sm_mutex}; - PyTreeTypeRegistry* registry1 = PyTreeTypeRegistry::Singleton(); - PyTreeTypeRegistry* registry2 = PyTreeTypeRegistry::Singleton(); + PyTreeTypeRegistry* const registry1 = PyTreeTypeRegistry::Singleton(); + PyTreeTypeRegistry* const registry2 = PyTreeTypeRegistry::Singleton(); EXPECT_LE(sm_builtins_types.size(), registry1->m_registrations.size()); EXPECT_EQ(registry1->m_registrations.size(), registry2->m_registrations.size() + 1); diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index ca455674..704b6ada 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -815,8 +815,8 @@ bool IsLeafImpl(const py::handle& handle, return true; } PyTreeTypeRegistry::RegistrationPtr custom{nullptr}; - return (PyTreeTypeRegistry::GetKind(handle, custom, registry_namespace) == - PyTreeKind::Leaf); + return PyTreeTypeRegistry::GetKind(handle, custom, registry_namespace) == + PyTreeKind::Leaf; } bool IsLeaf(const py::object& object, diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index d32e6743..e4f466d7 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -885,13 +885,13 @@ py::object PyTreeSpec::GetType(const std::optional& node) const { case PyTreeKind::Leaf: return py::none(); case PyTreeKind::None: - return py::type::of(py::none()); + return PyNoneTypeObject; case PyTreeKind::Tuple: - return py::reinterpret_borrow(reinterpret_cast(&PyTuple_Type)); + return PyTupleTypeObject; case PyTreeKind::List: - return py::reinterpret_borrow(reinterpret_cast(&PyList_Type)); + return PyListTypeObject; case PyTreeKind::Dict: - return py::reinterpret_borrow(reinterpret_cast(&PyDict_Type)); + return PyDictTypeObject; case PyTreeKind::NamedTuple: case PyTreeKind::StructSequence: return py::reinterpret_borrow(n.node_data); @@ -1052,7 +1052,7 @@ bool PyTreeSpec::IsPrefix(const PyTreeSpec& other, const bool& strict) const { } } EXPECT_EQ(b, other_traversal.rend(), "PyTreeSpec traversal did not yield a singleton."); - return (!strict || !all_leaves_match); + return !strict || !all_leaves_match; } bool PyTreeSpec::operator==(const PyTreeSpec& other) const { @@ -1549,7 +1549,7 @@ size_t PyTreeSpec::ThreadIndentTypeHash::operator()( #if PY_VERSION_HEX >= 0x03090000 // Python 3.9 Py_VISIT(Py_TYPE(self_base)); #endif - auto* instance = reinterpret_cast(self_base); + auto* const instance = reinterpret_cast(self_base); if (!instance->get_value_and_holder().holder_constructed()) [[unlikely]] { // The holder is not constructed yet. Skip the traversal to avoid segfault. return 0; @@ -1568,7 +1568,7 @@ size_t PyTreeSpec::ThreadIndentTypeHash::operator()( #if PY_VERSION_HEX >= 0x03090000 // Python 3.9 Py_VISIT(Py_TYPE(self_base)); #endif - auto* instance = reinterpret_cast(self_base); + auto* const instance = reinterpret_cast(self_base); if (!instance->get_value_and_holder().holder_constructed()) [[unlikely]] { // The holder is not constructed yet. Skip the traversal to avoid segfault. return 0;