Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 25, 2024
2 parents d820a95 + 037f5a3 commit 24c1573
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 50 deletions.
8 changes: 6 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9870,8 +9870,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
"""
if is_tensor_collection(obj):
if is_non_tensor(obj):
return cls.from_any(obj.data, auto_batch_size=auto_batch_size)
# Conversions from non-tensor data must be done manually
# if is_non_tensor(obj):
# from tensordict.tensorclass import LazyStackedTensorDict
# if isinstance(obj, LazyStackedTensorDict):
# return obj
# return cls.from_any(obj.data, auto_batch_size=auto_batch_size)
return obj
if isinstance(obj, dict):
return cls.from_dict(obj, auto_batch_size=auto_batch_size)
Expand Down
4 changes: 2 additions & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,13 @@ def test_nontensor(self):
in_keys=[],
out_keys=["out"],
)
assert tdm(TensorDict({}))["out"] == [1, 2]
assert tdm(TensorDict())["out"] == [1, 2]
tdm = TensorDictModule(
lambda: "a string!",
in_keys=[],
out_keys=["out"],
)
assert tdm(TensorDict({}))["out"] == "a string!"
assert tdm(TensorDict())["out"] == "a string!"

@pytest.mark.parametrize(
"out_keys",
Expand Down
92 changes: 46 additions & 46 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,13 +814,13 @@ def test_expand_with_singleton(self, device):
@set_lazy_legacy(True)
def test_filling_empty_tensordict(self, device, td_type, update):
if td_type == "tensordict":
td = TensorDict({}, batch_size=[16], device=device)
td = TensorDict(batch_size=[16], device=device)
elif td_type == "view":
td = TensorDict({}, batch_size=[4, 4], device=device).view(-1)
td = TensorDict(batch_size=[4, 4], device=device).view(-1)
elif td_type == "unsqueeze":
td = TensorDict({}, batch_size=[16], device=device).unsqueeze(-1)
td = TensorDict(batch_size=[16], device=device).unsqueeze(-1)
elif td_type == "squeeze":
td = TensorDict({}, batch_size=[16, 1], device=device).squeeze(-1)
td = TensorDict(batch_size=[16, 1], device=device).squeeze(-1)
elif td_type == "stack":
td = LazyStackedTensorDict.lazy_stack(
[TensorDict({}, [], device=device) for _ in range(16)], 0
Expand Down Expand Up @@ -2591,7 +2591,7 @@ def test_record_stream(self):
@pytest.mark.parametrize("device", get_available_devices())
def test_subtensordict_construction(self, device):
torch.manual_seed(1)
td = TensorDict({}, batch_size=(4, 5))
td = TensorDict(batch_size=(4, 5))
val1 = torch.randn(4, 5, 1, device=device)
val2 = torch.randn(4, 5, 6, dtype=torch.double, device=device)
val1_copy = val1.clone()
Expand Down Expand Up @@ -2694,7 +2694,7 @@ def test_tensordict_error_messages(self, device):
@pytest.mark.parametrize("device", get_available_devices())
def test_tensordict_indexing(self, device):
torch.manual_seed(1)
td = TensorDict({}, batch_size=(4, 5))
td = TensorDict(batch_size=(4, 5))
td.set("key1", torch.randn(4, 5, 1, device=device))
td.set("key2", torch.randn(4, 5, 6, device=device, dtype=torch.double))

Expand Down Expand Up @@ -2736,7 +2736,7 @@ def test_tensordict_prealloc_nested(self):
N = 3
B = 5
T = 4
buffer = TensorDict({}, batch_size=[B, N])
buffer = TensorDict(batch_size=[B, N])

td_0 = TensorDict(
{
Expand Down Expand Up @@ -2777,7 +2777,7 @@ def test_tensordict_prealloc_nested(self):
@pytest.mark.parametrize("device", get_available_devices())
def test_tensordict_set(self, device):
torch.manual_seed(1)
td = TensorDict({}, batch_size=(4, 5), device=device)
td = TensorDict(batch_size=(4, 5), device=device)
td.set("key1", torch.randn(4, 5))
assert td.device == torch.device(device)
# by default inplace:
Expand Down Expand Up @@ -4235,7 +4235,7 @@ def test_flatten_unflatten_bis(self, td_name, device):
def test_from_empty(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
new_td = TensorDict({}, batch_size=td.batch_size, device=device)
new_td = TensorDict(batch_size=td.batch_size, device=device)
for key, item in td.items():
new_td.set(key, item)
assert_allclose_td(td, new_td)
Expand Down Expand Up @@ -4433,7 +4433,7 @@ def test_items_values_keys(self, td_name, device):
items = list(td.items())

# Test td.items()
constructed_td1 = TensorDict({}, batch_size=td.shape)
constructed_td1 = TensorDict(batch_size=td.shape)
for key, value in items:
constructed_td1.set(key, value)

Expand All @@ -4443,7 +4443,7 @@ def test_items_values_keys(self, td_name, device):
# items = [key, value] should be verified
assert len(values) == len(items)
assert len(keys) == len(items)
constructed_td2 = TensorDict({}, batch_size=td.shape)
constructed_td2 = TensorDict(batch_size=td.shape)
for key, value in list(zip(td.keys(), td.values())):
constructed_td2.set(key, value)

Expand All @@ -4464,7 +4464,7 @@ def test_items_values_keys(self, td_name, device):

# Test td.items()
# after adding the new element
constructed_td1 = TensorDict({}, batch_size=td.shape)
constructed_td1 = TensorDict(batch_size=td.shape)
for key, value in items:
constructed_td1.set(key, value)

Expand All @@ -4476,7 +4476,7 @@ def test_items_values_keys(self, td_name, device):
assert len(values) == len(items)
assert len(keys) == len(items)

constructed_td2 = TensorDict({}, batch_size=td.shape)
constructed_td2 = TensorDict(batch_size=td.shape)
for key, value in list(zip(td.keys(), td.values())):
constructed_td2.set(key, value)

Expand Down Expand Up @@ -9382,30 +9382,30 @@ def run_assertions():

class TestNamedDims(TestTensorDictsBase):
def test_all(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
tda = td.all(2)
assert tda.names == ["a", "b", "d"]
tda = td.any(2)
assert tda.names == ["a", "b", "d"]

def test_apply(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
tda = td.apply(lambda x: x + 1)
assert tda.names == ["a", "b", "c", "d"]
tda = td.apply(lambda x: x.squeeze(2), batch_size=[3, 4, 6])
# no way to tell what the names have become, in general
assert tda.names == [None] * 3

def test_cat(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
tdc = torch.cat([td, td], -1)
assert tdc.names == [None] * 4
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
tdc = torch.cat([td, td], -1)
assert tdc.names == ["a", "b", "c", "d"]

def test_change_batch_size(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"])
td.batch_size = [3, 4, 1, 6, 1]
assert td.names == ["a", "b", "c", "z", None]
td.batch_size = []
Expand All @@ -9417,22 +9417,22 @@ def test_change_batch_size(self):
assert td.names == ["a"]

def test_clone(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
td.names = ["a", "b", "c", "d"]
tdc = td.clone()
assert tdc.names == ["a", "b", "c", "d"]
tdc = td.clone(False)
assert tdc.names == ["a", "b", "c", "d"]

def test_detach(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td[""] = torch.zeros(td.shape, requires_grad=True)
tdd = td.detach()
assert tdd.names == ["a", "b", "c", "d"]

def test_error_similar(self):
with pytest.raises(ValueError):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "a"])
with pytest.raises(ValueError):
td = TensorDict(
{},
Expand All @@ -9446,16 +9446,16 @@ def test_error_similar(self):
)
td.refine_names("a", "a", ...)
with pytest.raises(ValueError):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "z"])
td.rename_(a="z")

def test_expand(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
tde = td.expand(2, 3, 4, 5, 6)
assert tde.names == [None, "a", "b", "c", "d"]

def test_flatten(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
tdf = td.flatten(1, 3)
assert tdf.names == ["a", None]
tdu = tdf.unflatten(1, (4, 1, 6))
Expand All @@ -9470,11 +9470,11 @@ def test_flatten(self):
assert tdu.names == [None, None, None, "d"]

def test_fullname(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
assert td.names == ["a", "b", "c", "d"]

def test_gather(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
idx = torch.randint(6, (3, 4, 1, 18))
tdg = td.gather(dim=-1, index=idx)
assert tdg.names == ["a", "b", "c", "d"]
Expand All @@ -9499,7 +9499,7 @@ def test_h5_td(self):
assert td.names == list("abgd")

def test_index(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
assert td[0].names == ["b", "c", "d"]
assert td[:, 0].names == ["a", "c", "d"]
assert td[0, :].names == ["b", "c", "d"]
Expand All @@ -9519,7 +9519,7 @@ def test_index(self):
assert tdbool.ndim == 3

def test_masked_fill(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
tdm = td.masked_fill(torch.zeros(3, 4, 1, dtype=torch.bool), 1.0)
assert tdm.names == ["a", "b", "c", "d"]

Expand All @@ -9543,16 +9543,16 @@ def test_memmap_td(self):
assert td.clone().names == list("abgd")

def test_nested(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td["a"] = TensorDict({}, batch_size=[3, 4, 1, 6])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td["a"] = TensorDict(batch_size=[3, 4, 1, 6])
assert td["a"].names == td.names
td["a"] = TensorDict({}, batch_size=[])
td["a"] = TensorDict()
assert td["a"].names == td.names
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=None)
td["a"] = TensorDict({}, batch_size=[3, 4, 1, 6])
td = TensorDict(batch_size=[3, 4, 1, 6], names=None)
td["a"] = TensorDict(batch_size=[3, 4, 1, 6])
td.names = ["a", "b", None, None]
assert td["a"].names == td.names
td.set_("a", TensorDict({}, batch_size=[3, 4, 1, 6]))
td.set_("a", TensorDict(batch_size=[3, 4, 1, 6]))
assert td["a"].names == td.names

def test_nested_indexing(self):
Expand Down Expand Up @@ -9602,15 +9602,15 @@ def test_nested_td(self):
assert nested_td.contiguous()["my_nested_td"].names == list("abgd")

def test_noname(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
assert td.names == [None] * 4

def test_partial_name(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", None, None, "d"])
td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", None, None, "d"])
assert td.names == ["a", None, None, "d"]

def test_partial_set(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
td.names = ["a", None, None, "d"]
assert td.names == ["a", None, None, "d"]
td.names = ["a", "b", "c", "d"]
Expand Down Expand Up @@ -9639,7 +9639,7 @@ def test_permute_td(self):
td.names = list("abcd")

def test_refine_names(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6])
td = TensorDict(batch_size=[3, 4, 5, 6])
tdr = td.refine_names(None, None, None, "d")
assert tdr.names == [None, None, None, "d"]
tdr = tdr.refine_names(None, None, "c", "d")
Expand All @@ -9654,7 +9654,7 @@ def test_refine_names(self):
assert tdr.names == ["a", None, "c", "d"]

def test_rename(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
td.names = ["a", None, None, "d"]
td.rename_(a="c")
assert td.names == ["c", None, None, "d"]
Expand All @@ -9670,7 +9670,7 @@ def test_rename(self):
assert td2.names == ["w", "x", "y", "z"]

def test_select(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
tds = td.select()
assert tds.names == ["a", "b", "c", "d"]
tde = td.exclude()
Expand Down Expand Up @@ -9707,11 +9707,11 @@ def test_split(self):
# assert tdu.is_locked

def test_squeeze(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
td.names = ["a", "b", "c", "d"]
tds = td.squeeze(0)
assert tds.names == ["a", "b", "c", "d"]
td = TensorDict({}, batch_size=[3, 1, 5, 6], names=None)
td = TensorDict(batch_size=[3, 1, 5, 6], names=None)
td.names = ["a", "b", "c", "d"]
tds = td.squeeze(1)
assert tds.names == ["a", "c", "d"]
Expand All @@ -9724,7 +9724,7 @@ def test_squeeze_td(self):
td.names = list("abcd")

def test_stack(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
tds = LazyStackedTensorDict.lazy_stack([td, td], 0)
assert tds.names == [None, "a", "b", "c", "d"]
tds = LazyStackedTensorDict.lazy_stack([td, td], -1)
Expand Down Expand Up @@ -9762,7 +9762,7 @@ def test_sub_td(self):
td.names = list("abcd")

def test_subtd(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 5, 6], names=["a", "b", "c", "d"])
assert td._get_sub_tensordict(0).names == ["b", "c", "d"]
assert td._get_sub_tensordict((slice(None), 0)).names == ["a", "c", "d"]
assert td._get_sub_tensordict((0, slice(None))).names == ["b", "c", "d"]
Expand Down Expand Up @@ -9826,14 +9826,14 @@ def test_to(self, device, non_blocking_pin, num_threads, inplace):
assert tdt is not td

def test_unbind(self):
td = TensorDict({}, batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
td = TensorDict(batch_size=[3, 4, 1, 6], names=["a", "b", "c", "d"])
*_, tdu = td.unbind(-1)
assert tdu.names == ["a", "b", "c"]
*_, tdu = td.unbind(-2)
assert tdu.names == ["a", "b", "d"]

def test_unsqueeze(self):
td = TensorDict({}, batch_size=[3, 4, 5, 6], names=None)
td = TensorDict(batch_size=[3, 4, 5, 6], names=None)
td.names = ["a", "b", "c", "d"]
tdu = td.unsqueeze(0)
assert tdu.names == [None, "a", "b", "c", "d"]
Expand Down

0 comments on commit 24c1573

Please sign in to comment.