Skip to content

Commit

Permalink
[Feature] Better list casting in TensorDict.from_any
Browse files Browse the repository at this point in the history
ghstack-source-id: 427d19d5ef7c0d2779e064e64522fc0094a885af
Pull Request resolved: #1108
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 0b042cb commit fa8a521
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 2 deletions.
10 changes: 9 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
_DTYPE2STRDTYPE,
_GENERIC_NESTED_ERR,
_is_dataclass as is_dataclass,
_is_list_tensor_compatible,
_is_non_tensor,
_is_number,
_is_tensorclass,
Expand Down Expand Up @@ -9869,6 +9870,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
"""
if is_tensor_collection(obj):
if is_non_tensor(obj):
return cls.from_any(obj.data, auto_batch_size=auto_batch_size)
return obj
if isinstance(obj, dict):
return cls.from_dict(obj, auto_batch_size=auto_batch_size)
Expand All @@ -9881,7 +9884,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False):
return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size)
return cls.from_tuple(obj, auto_batch_size=auto_batch_size)
if isinstance(obj, list):
return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size)
if _is_list_tensor_compatible(obj)[0]:
return torch.tensor(obj)
else:
from tensordict.tensorclass import NonTensorStack

return NonTensorStack.from_list(obj)
if is_dataclass(obj):
return cls.from_dataclass(obj, auto_batch_size=auto_batch_size)
if _has_h5:
Expand Down
2 changes: 1 addition & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,7 @@ def _write_to_tensordict(
tensordict_out = tensordict
for _out_key, _tensor in zip(out_keys, tensors):
if _out_key != "_":
tensordict_out.set(_out_key, _tensor)
tensordict_out.set(_out_key, TensorDict.from_any(_tensor))
return tensordict_out

def _call_module(
Expand Down
33 changes: 33 additions & 0 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2824,3 +2824,36 @@ def _is_dataclass(obj):
else type(obj)
)
return hasattr(cls, _FIELDS)


def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]:
length_t = len(t)
dtypes = set()
sizes = set()
for i in t:
if isinstance(i, (float, int, torch.SymInt, Number)):
dtypes.add(type(i))
if len(dtypes) > 1:
return False, None, None
continue
elif isinstance(i, list):
is_compat, size_i, dtype = _is_list_tensor_compatible(i)
if not is_compat:
return False, None, None
if dtype is not None:
dtypes.add(dtype)
if len(dtypes) > 1:
return False, None, None
sizes.add(size_i)
if len(sizes) > 1:
return False, None, None
continue
return False, None
else:
if len(dtypes):
dtype = list(dtypes)[0]
else:
dtype = None
if len(sizes):
return True, (length_t, *list(sizes)[0]), dtype
return True, (length_t,), dtype
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,13 @@ class MyClass:
td.keys(True, True)
).symmetric_difference(expected)

def test_from_any_list(self):
t = torch.randn(3, 4, 5)
t = t.tolist()
assert isinstance(TensorDict.from_any(t), torch.Tensor)
t[0].extend([0, 2])
assert isinstance(TensorDict.from_any(t), TensorDict)

def test_from_any_userdict(self):
class D(UserDict): ...

Expand Down

0 comments on commit fa8a521

Please sign in to comment.