Skip to content

Commit

Permalink
[BugFix] Fix from_any tests
Browse files Browse the repository at this point in the history
ghstack-source-id: 8c3b3d825555c727c7c18c7e8a87311f718a94b6
Pull Request resolved: #1110
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 2728dbf commit dc11197
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,7 +959,7 @@ class MyClass:
a: int

pytree = (
[torch.randint(10, (3,)), torch.zeros(2)],
[[-1, 0, 1], [2, 3, 4]],
{
"tensor": torch.randn(
2,
Expand All @@ -974,8 +974,7 @@ class MyClass:
pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},)
td = TensorDict.from_any(pytree)
expected = {
("0", "0"),
("0", "1"),
"0",
("1", "td", "one"),
("1", "tensor"),
("1", "tuple", "0"),
Expand All @@ -1001,8 +1000,8 @@ def test_from_any_list(self):
t = torch.randn(3, 4, 5)
t = t.tolist()
assert isinstance(TensorDict.from_any(t), torch.Tensor)
t[0].extend([0, 2])
assert isinstance(TensorDict.from_any(t), TensorDict)
t[0][1].extend([0, 2])
assert isinstance(TensorDict.from_any(t), NonTensorStack)

def test_from_any_userdict(self):
class D(UserDict): ...
Expand Down

0 comments on commit dc11197

Please sign in to comment.