From 1c35bab61363384894bf6d9eba3215f5378209c7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 27 Nov 2024 14:13:03 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/base.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index f96ef2a08..8b86cb250 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -2595,6 +2595,35 @@ def repeat_interleave( Returns: Repeated TensorDict which has the same shape as input, except along the given axis. + Examples: + >>> import torch + >>> + >>> from tensordict import TensorDict + >>> + >>> td = TensorDict( + ... { + ... "a": torch.randn(3, 4, 5), + ... "b": TensorDict({ + ... "c": torch.randn(3, 4, 10, 1), + ... "a string": "a string!", + ... }, batch_size=[3, 4, 10]) + ... }, batch_size=[3, 4], + ... ) + >>> print(td.repeat_interleave(2, dim=0)) + TensorDict( + fields={ + a: Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False), + b: TensorDict( + fields={ + a string: NonTensorData(data=a string!, batch_size=torch.Size([6, 4, 10]), device=None), + c: Tensor(shape=torch.Size([6, 4, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([6, 4, 10]), + device=None, + is_shared=False)}, + batch_size=torch.Size([6, 4]), + device=None, + is_shared=False) + """ ... @@ -2613,6 +2642,35 @@ def repeat(self, *repeats: int) -> TensorDictBase: repeat (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensor along each dimension. + Examples: + >>> import torch + >>> + >>> from tensordict import TensorDict + >>> + >>> td = TensorDict( + ... { + ... "a": torch.randn(3, 4, 5), + ... "b": TensorDict({ + ... "c": torch.randn(3, 4, 10, 1), + ... "a string": "a string!", + ... }, batch_size=[3, 4, 10]) + ... }, batch_size=[3, 4], + ... ) + >>> print(td.repeat(1, 2)) + TensorDict( + fields={ + a: Tensor(shape=torch.Size([3, 8, 5]), device=cpu, dtype=torch.float32, is_shared=False), + b: TensorDict( + fields={ + a string: NonTensorData(data=a string!, batch_size=torch.Size([3, 8, 10]), device=None), + c: Tensor(shape=torch.Size([3, 8, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([3, 8, 10]), + device=None, + is_shared=False)}, + batch_size=torch.Size([3, 8]), + device=None, + is_shared=False) + """ if len(repeats) == 1 and not isinstance(repeats[0], int): repeats = repeats[0]