Skip to content

Commit

Permalink
[Feature] super() calls within TensorClass subclasses
Browse files Browse the repository at this point in the history
ghstack-source-id: 060a89982413869c54e1fb4aa74f90e2b9cdaac4
Pull Request resolved: #1133
  • Loading branch information
vmoens committed Dec 9, 2024
1 parent fbcbcb2 commit dbcadab
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 14 deletions.
3 changes: 2 additions & 1 deletion tensordict/nn/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from tensordict.nn.utils import mappings
from torch import distributions as D, nn

# We need this to build the distribution maps
__all__ = [
"NormalParamExtractor",
"NormalParamWrapper",
Expand All @@ -23,7 +24,7 @@
]

# speeds up distribution construction
D.Distribution.set_default_validate_args(False)
# D.Distribution.set_default_validate_args(False)


class NormalParamWrapper(nn.Module):
Expand Down
1 change: 1 addition & 0 deletions tensordict/nn/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
from torch import distributions as D

# We need this to build the distribution maps
__all__ = [
"OneHotCategorical",
]
Expand Down
50 changes: 41 additions & 9 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,20 +687,40 @@ def __torch_function__(
cls.__getstate__ = _getstate
cls.__setstate__ = _setstate
# cls.__getattribute__ = object.__getattribute__
# if "__getattr__" not in cls.__dict__:
cls.__getattr__ = _getattr
cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys)
# cls.__getattr__ = _getattr
cls.__getitem__ = _getitem
cls.__getitems__ = _getitem
cls.__setitem__ = _setitem
if "__setattr__" not in cls.__dict__:
cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys)
if "__getitem__" not in cls.__dict__:
cls.__getitem__ = _getitem
if "__getitems__" not in cls.__dict__:
cls.__getitems__ = _getitem
if "__setitem__" not in cls.__dict__:
cls.__setitem__ = _setitem
if not _is_non_tensor:
cls.__repr__ = _repr
cls.__len__ = _len
if "__len__" not in cls.__dict__:
cls.__len__ = _len
#
cls.__eq__ = _eq
cls.__ne__ = _ne
cls.__or__ = _or
cls.__xor__ = _xor
cls.__bool__ = _bool

# cls.__setattr__ = _setattr_wrapper(cls.__setattr__, expected_keys)
# # cls.__getattr__ = _getattr
# cls.__getitem__ = _getitem
# cls.__getitems__ = _getitem
# cls.__setitem__ = _setitem
# if not _is_non_tensor:
# cls.__repr__ = _repr
# cls.__len__ = _len
# cls.__eq__ = _eq
# cls.__ne__ = _ne
# cls.__or__ = _or
# cls.__xor__ = _xor
# cls.__bool__ = _bool
if not hasattr(cls, "non_tensor_items"):
cls.non_tensor_items = _non_tensor_items
if not hasattr(cls, "set"):
Expand Down Expand Up @@ -2910,22 +2930,32 @@ def to_tensordict(self, *, retain_none: bool | None = None):
return self

@classmethod
def _stack_non_tensor(cls, list_of_non_tensor, dim=0):
def _stack_non_tensor(cls, list_of_non_tensor, dim=0, raise_if_non_unique=False):
# checks have been performed previously, so we're sure the list is non-empty
first = list_of_non_tensor[0]

ids = set()
firstdata = NO_DEFAULT
return_stack = False
for data in list_of_non_tensor:
if not isinstance(data, NonTensorData):
return_stack = True
if raise_if_non_unique:
data = cls._stack_non_tensor(
data, raise_if_non_unique=raise_if_non_unique
)
else:
return_stack = True
break
if firstdata is NO_DEFAULT:
firstdata = data.data
ids.add(id(data.data))
if len(ids) > 1:
if _check_equal(data.data, firstdata):
continue
if raise_if_non_unique:
raise ValueError(
"More than one unique value has been found in the stack."
)
return_stack = True
break
else:
Expand Down Expand Up @@ -3454,7 +3484,9 @@ def data(self):
self.tensordicts, raise_if_non_unique=True
).data
except ValueError:
raise AttributeError("Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead.")
raise AttributeError(
"Cannot get the non-unique data of a NonTensorStack. Use .tolist() instead."
)


_register_tensor_class(NonTensorStack)
Expand Down
26 changes: 22 additions & 4 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ class X:
assert isinstance(x.y, torch.Tensor)
_ = {x: 0}
assert x.is_locked
with pytest.raises(RuntimeError, match="locked"):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
x.y = 0

@tensorclass(frozen=False, autocast=True)
Expand All @@ -643,7 +643,7 @@ class X:
assert isinstance(x.y, str)
_ = {x: 0}
assert x.is_locked
with pytest.raises(RuntimeError, match="locked"):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
x.y = 0

@tensorclass(frozen=False, autocast=False)
Expand Down Expand Up @@ -2585,7 +2585,7 @@ class SubClass(TensorClass, nocast=True, frozen=True):
assert issubclass(SubClass, TensorClass)
s = SubClass(1)
assert isinstance(s.a, int)
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
s.a = 2

class SubClass(TensorClass["nocast", "frozen"]):
Expand All @@ -2599,9 +2599,27 @@ class SubClass(TensorClass["nocast", "frozen"]):
assert issubclass(SubClass, TensorClass)
s = SubClass(1)
assert isinstance(s.a, int)
with pytest.raises(RuntimeError):
with pytest.raises((RuntimeError, dataclasses.FrozenInstanceError)):
s.a = 2

def test_subclassing_super_call(self):
class SubClass(TensorClass, nocast=True):
a: int
b: int

def __setattr__(self, key, value):
if key == "b":
return super().__setattr__("b", value + 1)
return super().__setattr__("a", value - 1)

s = SubClass(a=torch.zeros(3), b=torch.zeros(3))
assert (s.a == -1).all()
assert (s.b == 1).all()
s.a = torch.ones(())
s.b = torch.ones(())
assert (s.a == 0).all()
assert (s.b == 2).all()


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
Expand Down
1 change: 1 addition & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8372,6 +8372,7 @@ def test_consolidate(self, device, use_file, tmpdir):
assert hasattr(td_c, "_consolidated")
assert type(td_c) == type(td) # noqa
assert (td.to(td_c.device) == td_c).all()
assert td["d"] == [["a string!"] * 3]
assert td_c["d"] == [["a string!"] * 3]

storage = td_c._consolidated["storage"]
Expand Down

0 comments on commit dbcadab

Please sign in to comment.