diff --git a/tensordict/base.py b/tensordict/base.py index f3034fd96..159777e28 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -57,6 +57,7 @@ _DTYPE2STRDTYPE, _GENERIC_NESTED_ERR, _is_dataclass as is_dataclass, + _is_list_tensor_compatible, _is_non_tensor, _is_number, _is_tensorclass, @@ -9869,6 +9870,8 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): """ if is_tensor_collection(obj): + if is_non_tensor(obj): + return cls.from_any(obj.data, auto_batch_size=auto_batch_size) return obj if isinstance(obj, dict): return cls.from_dict(obj, auto_batch_size=auto_batch_size) @@ -9881,7 +9884,12 @@ def from_any(cls, obj, *, auto_batch_size: bool = False): return cls.from_namedtuple(obj, auto_batch_size=auto_batch_size) return cls.from_tuple(obj, auto_batch_size=auto_batch_size) if isinstance(obj, list): - return cls.from_tuple(tuple(obj), auto_batch_size=auto_batch_size) + if _is_list_tensor_compatible(obj)[0]: + return torch.tensor(obj) + else: + from tensordict.tensorclass import NonTensorStack + + return NonTensorStack.from_list(obj) if is_dataclass(obj): return cls.from_dataclass(obj, auto_batch_size=auto_batch_size) if _has_h5: diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index ffedba9ad..c38bd31de 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -1003,7 +1003,7 @@ def _write_to_tensordict( tensordict_out = tensordict for _out_key, _tensor in zip(out_keys, tensors): if _out_key != "_": - tensordict_out.set(_out_key, _tensor) + tensordict_out.set(_out_key, TensorDict.from_any(_tensor)) return tensordict_out def _call_module( diff --git a/tensordict/utils.py b/tensordict/utils.py index 81ab2fa0c..2344da517 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2824,3 +2824,36 @@ def _is_dataclass(obj): else type(obj) ) return hasattr(cls, _FIELDS) + + +def _is_list_tensor_compatible(t) -> Tuple[bool, tuple | None, type | None]: + length_t = len(t) + dtypes = set() + sizes = set() + for i in t: + if isinstance(i, (float, int, torch.SymInt, Number)): + dtypes.add(type(i)) + if len(dtypes) > 1: + return False, None, None + continue + elif isinstance(i, list): + is_compat, size_i, dtype = _is_list_tensor_compatible(i) + if not is_compat: + return False, None, None + if dtype is not None: + dtypes.add(dtype) + if len(dtypes) > 1: + return False, None, None + sizes.add(size_i) + if len(sizes) > 1: + return False, None, None + continue + return False, None + else: + if len(dtypes): + dtype = list(dtypes)[0] + else: + dtype = None + if len(sizes): + return True, (length_t, *list(sizes)[0]), dtype + return True, (length_t,), dtype diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 737ff4f24..27b53b686 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -997,6 +997,13 @@ class MyClass: td.keys(True, True) ).symmetric_difference(expected) + def test_from_any_list(self): + t = torch.randn(3, 4, 5) + t = t.tolist() + assert isinstance(TensorDict.from_any(t), torch.Tensor) + t[0].extend([0, 2]) + assert isinstance(TensorDict.from_any(t), TensorDict) + def test_from_any_userdict(self): class D(UserDict): ...