Skip to content

Commit

Permalink
chore(registry): re-enable registering abstract classes
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Sep 24, 2024
1 parent b5753fd commit 1436812
Show file tree
Hide file tree
Showing 6 changed files with 3 additions and 120 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/pre-commit/mirrors-clang-format
rev: v18.1.8
rev: v19.1.0
hooks:
- id: clang-format
- repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down
5 changes: 0 additions & 5 deletions optree/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,6 @@ def decorator(cls: _TypeT) -> _TypeT:

if not inspect.isclass(cls):
raise TypeError(f'@{__name__}.dataclass() can only be used with classes, not {cls!r}.')
if inspect.isabstract(cls):
raise TypeError(
f'@{__name__}.dataclass() cannot register abstract class {cls!r}, '
'because it cannot be instantiated.',
)
if _FIELDS in cls.__dict__:
raise TypeError(
f'@{__name__}.dataclass() cannot be applied to {cls.__name__} more than once.',
Expand Down
8 changes: 0 additions & 8 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,6 @@ def register_pytree_node(
"""
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if inspect.isabstract(cls):
raise TypeError(
f'Cannot register abstract class {cls!r}, because it cannot be instantiated.',
)
if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
raise TypeError(f'Expected a subclass of PyTreeEntry, got {path_entry_type!r}.')
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
Expand Down Expand Up @@ -387,10 +383,6 @@ def tree_unflatten(cls, metadata, children):
) # type: ignore[return-value]
if not inspect.isclass(cls):
raise TypeError(f'Expected a class, got {cls!r}.')
if inspect.isabstract(cls):
raise TypeError(
f'Cannot register abstract class {cls!r}, because it cannot be instantiated.',
)
if path_entry_type is None:
path_entry_type = getattr(cls, 'TREE_PATH_ENTRY_TYPE', AutoEntry)
if not (inspect.isclass(path_entry_type) and issubclass(path_entry_type, PyTreeEntry)):
Expand Down
2 changes: 2 additions & 0 deletions optree/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""

# pylint: disable=invalid-name

__version__ = '0.12.1'
__license__ = 'Apache License, Version 2.0'
__author__ = 'OpTree Contributors'
Expand Down
73 changes: 0 additions & 73 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# pylint: disable=missing-function-docstring,invalid-name,wrong-import-order

import abc
import collections.abc
import dataclasses
import inspect
import math
Expand Down Expand Up @@ -465,41 +463,6 @@ class Foo:
assert treespec.type is Foo


def test_dataclass_with_abstract_class():
with pytest.raises(
TypeError,
match=(
r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, '
r'because it cannot be instantiated\.'
),
):

@optree.dataclasses.dataclass(namespace='error')
class Vec(abc.ABC):
x: float
y: float

@abc.abstractmethod
def norm(self, p: float = 2):
raise NotImplementedError

with pytest.raises(
TypeError,
match=(
r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, '
r'because it cannot be instantiated\.'
),
):

@optree.dataclasses.dataclass(namespace='error')
class Vec(collections.abc.Sequence):
x: float
y: float

def __len__(self):
return 2


def test_make_dataclass_future_parameters():
with pytest.raises(
TypeError,
Expand Down Expand Up @@ -850,39 +813,3 @@ def test_make_dataclass_with_invalid_namespace():
assert treespec.namespace == ''
assert treespec.kind == optree.PyTreeKind.CUSTOM
assert treespec.type is Bar


def test_make_dataclass_with_abstract_class():
with pytest.raises(
TypeError,
match=(
r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, '
r'because it cannot be instantiated\.'
),
):
optree.dataclasses.make_dataclass(
'Vec',
[('x', float), ('y', float)],
bases=(abc.ABC,),
ns={
'norm': abc.abstractmethod(lambda self, p=2: (self.x**p + self.y**p) ** (1 / p)),
},
namespace='error',
)

with pytest.raises(
TypeError,
match=(
r'@optree\.dataclasses\.dataclass\(\) cannot register abstract class .*, '
r'because it cannot be instantiated\.'
),
):
optree.dataclasses.make_dataclass(
'Vec',
[('x', float), ('y', float)],
bases=(collections.abc.Sequence,),
ns={
'__len__': lambda self: 2,
},
namespace='error',
)
33 changes: 0 additions & 33 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

# pylint: disable=missing-function-docstring,invalid-name

import abc
import collections.abc
import re
import weakref
from collections import UserDict, UserList, namedtuple
Expand Down Expand Up @@ -174,37 +172,6 @@ def tree_unflatten(cls, metadata, children):
)


def test_register_pytree_node_with_abstract_class():
with pytest.raises(
TypeError,
match=r'Cannot register abstract class .*, because it cannot be instantiated\.',
):

@optree.register_pytree_node_class(namespace='error')
class MyList(UserList, abc.ABC):
def tree_flatten(self):
return self.data, None, None

@classmethod
def tree_unflatten(cls, metadata, children):
return cls(children)

@abc.abstractmethod
def copy(self):
return type(self)(self)

with pytest.raises(
TypeError,
match=r'Cannot register abstract class .*, because it cannot be instantiated\.',
):
optree.register_pytree_node(
collections.abc.Sequence,
lambda seq: (list(seq), type(seq), None),
lambda cls, seq: cls(seq),
namespace='error',
)


def test_register_pytree_node_with_invalid_path_entry_type():
with pytest.raises(TypeError, match=r'Expected a subclass of PyTreeEntry, got .*\.'):

Expand Down

0 comments on commit 1436812

Please sign in to comment.