From dc11197892a541f04b74be5465ec969b95a558b7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 11:32:29 +0000 Subject: [PATCH] [BugFix] Fix from_any tests ghstack-source-id: 8c3b3d825555c727c7c18c7e8a87311f718a94b6 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1110 --- test/test_tensordict.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 63ebc8935..f7fe9a9ff 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -959,7 +959,7 @@ class MyClass: a: int pytree = ( - [torch.randint(10, (3,)), torch.zeros(2)], + [[-1, 0, 1], [2, 3, 4]], { "tensor": torch.randn( 2, @@ -974,8 +974,7 @@ class MyClass: pytree = pytree + ({"h5py": TestTensorDictsBase.td_h5(device="cpu").file},) td = TensorDict.from_any(pytree) expected = { - ("0", "0"), - ("0", "1"), + "0", ("1", "td", "one"), ("1", "tensor"), ("1", "tuple", "0"), @@ -1001,8 +1000,8 @@ def test_from_any_list(self): t = torch.randn(3, 4, 5) t = t.tolist() assert isinstance(TensorDict.from_any(t), torch.Tensor) - t[0].extend([0, 2]) - assert isinstance(TensorDict.from_any(t), TensorDict) + t[0][1].extend([0, 2]) + assert isinstance(TensorDict.from_any(t), NonTensorStack) def test_from_any_userdict(self): class D(UserDict): ...