From 826d31d542f56f6c0827e73d654e255ce8263c1a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 10:10:33 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/base.py | 2 +- test/test_tensordict.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/tensordict/base.py b/tensordict/base.py index 007189df7..3c3ebe320 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -8679,7 +8679,7 @@ def _clone_recurse(self) -> TensorDictBase: # noqa: D417 nested_keys=True, is_leaf=_NESTED_TENSORS_AS_LISTS, propagate_lock=False, - filter_empty=True, + filter_empty=False, default=None, ) if items: diff --git a/test/test_tensordict.py b/test/test_tensordict.py index edf9eb580..a815ee79c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -340,6 +340,17 @@ def test_cat_from_tensordict(self): assert (tensor[:, :4] == 0).all() assert (tensor[:, 4:] == 1).all() + @pytest.mark.parametrize("recurse", [True, False]) + def test_clone_empty(self, recurse): + td = TensorDict() + assert td.clone(recurse=recurse) is not None + td = TensorDict(device="cpu") + assert td.clone(recurse=recurse) is not None + td = TensorDict(batch_size=[2]) + assert td.clone(recurse=recurse) is not None + td = TensorDict(device="cpu", batch_size=[2]) + assert td.clone(recurse=recurse) is not None + @pytest.mark.filterwarnings("error") @pytest.mark.parametrize("device", [None, *get_available_devices()]) @pytest.mark.parametrize("num_threads", [0, 1, 2])