Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Nov 27, 2024
1 parent f3cbf49 commit 1c35bab
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2595,6 +2595,35 @@ def repeat_interleave(
Returns:
Repeated TensorDict which has the same shape as input, except along the given axis.
Examples:
>>> import torch
>>>
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
... {
... "a": torch.randn(3, 4, 5),
... "b": TensorDict({
... "c": torch.randn(3, 4, 10, 1),
... "a string": "a string!",
... }, batch_size=[3, 4, 10])
... }, batch_size=[3, 4],
... )
>>> print(td.repeat_interleave(2, dim=0))
TensorDict(
fields={
a: Tensor(shape=torch.Size([6, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
a string: NonTensorData(data=a string!, batch_size=torch.Size([6, 4, 10]), device=None),
c: Tensor(shape=torch.Size([6, 4, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([6, 4, 10]),
device=None,
is_shared=False)},
batch_size=torch.Size([6, 4]),
device=None,
is_shared=False)
"""
...

Expand All @@ -2613,6 +2642,35 @@ def repeat(self, *repeats: int) -> TensorDictBase:
repeat (torch.Size, int..., tuple of int or list of int): The number of times to repeat this tensor along
each dimension.
Examples:
>>> import torch
>>>
>>> from tensordict import TensorDict
>>>
>>> td = TensorDict(
... {
... "a": torch.randn(3, 4, 5),
... "b": TensorDict({
... "c": torch.randn(3, 4, 10, 1),
... "a string": "a string!",
... }, batch_size=[3, 4, 10])
... }, batch_size=[3, 4],
... )
>>> print(td.repeat(1, 2))
TensorDict(
fields={
a: Tensor(shape=torch.Size([3, 8, 5]), device=cpu, dtype=torch.float32, is_shared=False),
b: TensorDict(
fields={
a string: NonTensorData(data=a string!, batch_size=torch.Size([3, 8, 10]), device=None),
c: Tensor(shape=torch.Size([3, 8, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([3, 8, 10]),
device=None,
is_shared=False)},
batch_size=torch.Size([3, 8]),
device=None,
is_shared=False)
"""
if len(repeats) == 1 and not isinstance(repeats[0], int):
repeats = repeats[0]
Expand Down

0 comments on commit 1c35bab

Please sign in to comment.