diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 76a097d3..16b22bf2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,7 +26,7 @@ repos: - id: debug-statements - id: double-quote-string-fixer - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v18.1.8 + rev: v19.1.0 hooks: - id: clang-format - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/optree/dataclasses.py b/optree/dataclasses.py index bdc347fc..7774f015 100644 --- a/optree/dataclasses.py +++ b/optree/dataclasses.py @@ -254,11 +254,6 @@ def decorator(cls: _TypeT) -> _TypeT: if not inspect.isclass(cls): raise TypeError(f'@{__name__}.dataclass() can only be used with classes, not {cls!r}.') - if inspect.isabstract(cls): - raise TypeError( - f'@{__name__}.dataclass() cannot register abstract class {cls!r}, ' - 'because it cannot be instantiated.', - ) if _FIELDS in cls.__dict__: raise TypeError( f'@{__name__}.dataclass() cannot be applied to {cls.__name__} more than once.', diff --git a/optree/registry.py b/optree/registry.py index 422e014b..94daf4b2 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -240,10 +240,6 @@ def register_pytree_node( """ if not inspect.isclass(cls): raise TypeError(f'Expected a class, got {cls!r}.') - if inspect.isabstract(cls): - raise TypeError( - f'Cannot register abstract class {cls!r}, because it cannot be instantiated.', - ) if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)): raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.') if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str): @@ -387,10 +383,6 @@ def tree_unflatten(cls, metadata, children): ) # type: ignore[return-value] if not inspect.isclass(cls): raise TypeError(f'Expected a class, got {cls!r}.') - if inspect.isabstract(cls): - raise TypeError( - f'Cannot register abstract class {cls!r}, because it cannot be instantiated.', - ) if path_entry_type is None: path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry) if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)): diff --git a/optree/version.py b/optree/version.py index 4f80a448..7bbc8fef 100644 --- a/optree/version.py +++ b/optree/version.py @@ -14,6 +14,8 @@ # ============================================================================== """OpTree: Optimized PyTree Utilities.""" +# pylint: disable=invalid-name + __version__ = '0.12.1' __license__ = 'Apache License, Version 2.0' __author__ = 'OpTree Contributors' diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 3e153987..21715c84 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -15,8 +15,6 @@ # pylint: disable=missing-function-docstring,invalid-name,wrong-import-order -import abc -import collections.abc import dataclasses import inspect import math @@ -465,41 +463,6 @@ class Foo: assert treespec.type is Foo -def test_dataclass_with_abstract_class(): - with pytest.raises( - TypeError, - match=( - r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, ' - r'because it cannot be instantiated\.' - ), - ): - - @optree.dataclasses.dataclass(namespace='error') - class Vec(abc.ABC): - x: float - y: float - - @abc.abstractmethod - def norm(self, p: float = 2): - raise NotImplementedError - - with pytest.raises( - TypeError, - match=( - r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, ' - r'because it cannot be instantiated\.' - ), - ): - - @optree.dataclasses.dataclass(namespace='error') - class Vec(collections.abc.Sequence): - x: float - y: float - - def __len__(self): - return 2 - - def test_make_dataclass_future_parameters(): with pytest.raises( TypeError, @@ -850,39 +813,3 @@ def test_make_dataclass_with_invalid_namespace(): assert treespec.namespace == '' assert treespec.kind == optree.PyTreeKind.CUSTOM assert treespec.type is Bar - - -def test_make_dataclass_with_abstract_class(): - with pytest.raises( - TypeError, - match=( - r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, ' - r'because it cannot be instantiated\.' - ), - ): - optree.dataclasses.make_dataclass( - 'Vec', - [('x', float), ('y', float)], - bases=(abc.ABC,), - ns={ - 'norm': abc.abstractmethod(lambda self, p=2: (self.x**p + self.y**p) ** (1 / p)), - }, - namespace='error', - ) - - with pytest.raises( - TypeError, - match=( - r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, ' - r'because it cannot be instantiated\.' - ), - ): - optree.dataclasses.make_dataclass( - 'Vec', - [('x', float), ('y', float)], - bases=(collections.abc.Sequence,), - ns={ - '__len__': lambda self: 2, - }, - namespace='error', - ) diff --git a/tests/test_registry.py b/tests/test_registry.py index aae9d12d..5477a408 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -15,8 +15,6 @@ # pylint: disable=missing-function-docstring,invalid-name -import abc -import collections.abc import re import weakref from collections import UserDict, UserList, namedtuple @@ -174,37 +172,6 @@ def tree_unflatten(cls, metadata, children): ) -def test_register_pytree_node_with_abstract_class(): - with pytest.raises( - TypeError, - match=r'Cannot register abstract class .*, because it cannot be instantiated\.', - ): - - @optree.register_pytree_node_class(namespace='error') - class MyList(UserList, abc.ABC): - def tree_flatten(self): - return self.data, None, None - - @classmethod - def tree_unflatten(cls, metadata, children): - return cls(children) - - @abc.abstractmethod - def copy(self): - return type(self)(self) - - with pytest.raises( - TypeError, - match=r'Cannot register abstract class .*, because it cannot be instantiated\.', - ): - optree.register_pytree_node( - collections.abc.Sequence, - lambda seq: (list(seq), type(seq), None), - lambda cls, seq: cls(seq), - namespace='error', - ) - - def test_register_pytree_node_with_invalid_path_entry_type(): with pytest.raises(TypeError, match=r'Expected a subclass of PyTreeEntry, got .*\.'):