Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] super() calls within TensorClass subclasses #1133

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tensordict/nn/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tensordict.nn.utils import mappings
from torch import distributions as D, nn

# We need this to build the distribution maps
__all__ = [
"NormalParamExtractor",
"NormalParamWrapper",
Expand All @@ -23,7 +24,7 @@
]

# speeds up distribution construction
D.Distribution.set_default_validate_args(False)
# D.Distribution.set_default_validate_args(False)


class NormalParamWrapper(nn.Module):
Expand Down
1 change: 1 addition & 0 deletions tensordict/nn/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from torch import distributions as D

# We need this to build the distribution maps
__all__ = [
"OneHotCategorical",
]
Expand Down
50 changes: 41 additions & 9 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,20 +687,40 @@ def __torch_function__(
cls.__getstate__ = _getstate
cls.__setstate__ = _setstate
# cls.__getattribute__ = object.__getattribute__
# if "__getattr__" not in cls.__dict__:
cls.__getattr__ = _getattr
cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys)
# cls.__getattr__ = _getattr
cls.__getitem__ = _getitem
cls.__getitems__ = _getitem
cls.__setitem__ = _setitem
if "__setattr__" not in cls.__dict__:
cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys)
if "__getitem__" not in cls.__dict__:
cls.__getitem__ = _getitem
if "__getitems__" not in cls.__dict__:
cls.__getitems__ = _getitem
if "__setitem__" not in cls.__dict__:
cls.__setitem__ = _setitem
if not _is_non_tensor:
cls.__repr__ = _repr
cls.__len__ = _len
if "__len__" not in cls.__dict__:
cls.__len__ = _len
#
cls.__eq__ = _eq
cls.__ne__ = _ne
cls.__or__ = _or
cls.__xor__ = _xor
cls.__bool__ = _bool

# cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys)
# # cls.__getattr__ = _getattr
# cls.__getitem__ = _getitem
# cls.__getitems__ = _getitem
# cls.__setitem__ = _setitem
# if not _is_non_tensor:
# cls.__repr__ = _repr
# cls.__len__ = _len
# cls.__eq__ = _eq
# cls.__ne__ = _ne
# cls.__or__ = _or
# cls.__xor__ = _xor
# cls.__bool__ = _bool
if not hasattr(cls, "non_tensor_items"):
cls.non_tensor_items = _non_tensor_items
if not hasattr(cls, "set"):
Expand Down Expand Up @@ -2910,22 +2930,32 @@ def to_tensordict(self, *, retain_none: bool | None = None):
return self

@classmethod
def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False):
# checks have been performed previously, so we're sure the list is non-empty
first = list_of_non_tensor[0]

ids = set()
firstdata = NO_DEFAULT
return_stack = False
for data in list_of_non_tensor:
if not isinstance(data, NonTensorData):
return_stack = True
if raise_if_non_unique:
data = cls._stack_non_tensor(
data, raise_if_non_unique=raise_if_non_unique
)
else:
return_stack = True
break
if firstdata is NO_DEFAULT:
firstdata = data.data
ids.add(id(data.data))
if len(ids) > 1:
if _check_equal(data.data, firstdata):
continue
if raise_if_non_unique:
raise ValueError(
"More than one unique value has been found in the stack."
)
return_stack = True
break
else:
Expand Down Expand Up @@ -3454,7 +3484,9 @@ def data(self):
self.tensordicts, raise_if_non_unique=True
).data
except ValueError:
raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.")
raise AttributeError(
"Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead."
)


_register_tensor_class(NonTensorStack)
Expand Down
26 changes: 22 additions & 4 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ class X:
assert isinstance(x.y, torch.Tensor)
_ = {x: 0}
assert x.is_locked
with pytest.raises(RuntimeError, match="locked"):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
x.y = 0

@tensorclass(frozen=False, autocast=True)
Expand All @@ -643,7 +643,7 @@ class X:
assert isinstance(x.y, str)
_ = {x: 0}
assert x.is_locked
with pytest.raises(RuntimeError, match="locked"):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
x.y = 0

@tensorclass(frozen=False, autocast=False)
Expand Down Expand Up @@ -2585,7 +2585,7 @@ class SubClass(TensorClass, nocast=True, frozen=True):
assert issubclass(SubClass, TensorClass)
s = SubClass(1)
assert isinstance(s.a, int)
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
s.a = 2

class SubClass(TensorClass["nocast", "frozen"]):
Expand All @@ -2599,9 +2599,27 @@ class SubClass(TensorClass["nocast", "frozen"]):
assert issubclass(SubClass, TensorClass)
s = SubClass(1)
assert isinstance(s.a, int)
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
s.a = 2

def test_subclassing_super_call(self):
class SubClass(TensorClass, nocast=True):
a: int
b: int

def __setattr__(self, key, value):
if key == "b":
return super().__setattr__("b", value + 1)
return super().__setattr__("a", value - 1)

s = SubClass(a=torch.zeros(3), b=torch.zeros(3))
assert (s.a == -1).all()
assert (s.b == 1).all()
s.a = torch.ones(())
s.b = torch.ones(())
assert (s.a == 0).all()
assert (s.b == 2).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
1 change: 1 addition & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8372,6 +8372,7 @@ def test_consolidate(self, device, use_file, tmpdir):
assert hasattr(td_c, "_consolidated")
assert type(td_c) == type(td) # noqa
assert (td.to(td_c.device) == td_c).all()
assert td["d"] == [["a string!"] * 3]
assert td_c["d"] == [["a string!"] * 3]

storage = td_c._consolidated["storage"]
Expand Down
Loading