diff --git a/tensordict/_td.py b/tensordict/_td.py
index deec43b61..1730e9200 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -945,12 +945,18 @@ def _cast_reduction(
                         )
                         for val in agglomerate
                     ]
+                    cat_dim = -1
                     dim = -1
                     keepdim = False
-            agglomerate = torch.cat(agglomerate, dim=-1)
-            return getattr(torch, reduction_name)(
-                agglomerate, keepdim=keepdim, dim=dim
-            )
+            elif isinstance(dim, tuple):
+                cat_dim = dim[0]
+            else:
+                cat_dim = dim
+            agglomerate = torch.cat(agglomerate, dim=cat_dim)
+            kwargs = {}
+            if keepdim is not NO_DEFAULT:
+                kwargs["keepdim"] = keepdim
+            return getattr(torch, reduction_name)(agglomerate, dim=dim, **kwargs)
 
     # IMPORTANT: do not directly access batch_dims (or any other property)
     # via self.batch_dims otherwise a reference cycle is introduced
diff --git a/tensordict/base.py b/tensordict/base.py
index fec8f9a79..b8712abd1 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -628,6 +628,66 @@ def min(
             when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a
             tensorclass with entries ``"values"`` and ``"indices"`` with idendical
             structure within. Defaults to ``True``.
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.min(dim=0)
+            min(
+                indices=TensorDict(
+                    fields={
+                        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.int64, is_shared=False),
+                        b: TensorDict(
+                            fields={
+                                c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.int64, is_shared=False),
+                                d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([4]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([4]),
+                    device=None,
+                    is_shared=False),
+                vals=TensorDict(
+                    fields={
+                        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                        b: TensorDict(
+                            fields={
+                                c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                                d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                            batch_size=torch.Size([4]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([4]),
+                    device=None,
+                    is_shared=False),
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.min()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.min(reduce=True)
+            tensor(-2.9953)
+
         """
         result = self._cast_reduction(
             reduction_name="min",
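
The `_cast_reduction` hunk above concatenates all leaves along a single `cat_dim` (the first entry when `dim` is a tuple) and only forwards `keepdim` when it was explicitly set, which matters for reductions like `torch.cummax` that accept no `keepdim` argument. Below is a minimal standalone sketch of that logic in plain torch; `leaves`, `reduce_leaves` and the `NO_DEFAULT` sentinel are illustrative stand-ins for tensordict's internals, not its actual API:

    import torch

    NO_DEFAULT = object()  # stand-in for tensordict's NO_DEFAULT sentinel

    def reduce_leaves(leaves, reduction_name, dim=NO_DEFAULT, keepdim=NO_DEFAULT):
        if dim is NO_DEFAULT:
            # no dim: flatten everything and reduce over the last dim
            cat_dim, dim, keepdim = -1, -1, False
        elif isinstance(dim, tuple):
            cat_dim = dim[0]  # concatenate along the first reduced dim
        else:
            cat_dim = dim
        agglomerate = torch.cat(leaves, dim=cat_dim)
        kwargs = {}
        if keepdim is not NO_DEFAULT:
            # only forwarded when set: torch.cummax/cummin take no keepdim
            kwargs["keepdim"] = keepdim
        return getattr(torch, reduction_name)(agglomerate, dim=dim, **kwargs)

    # reduce_leaves([torch.ones(3, 4), torch.ones(3, 4)], "sum", dim=0).shape
    # -> torch.Size([4]): the two leaves are stacked to (6, 4) before reducing
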
@@ -702,6 +762,66 @@ def max(
             when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a
             tensorclass with entries ``"values"`` and ``"indices"`` with idendical
             structure within. Defaults to ``True``.
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.max(dim=0)
+            max(
+                indices=TensorDict(
+                    fields={
+                        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.int64, is_shared=False),
+                        b: TensorDict(
+                            fields={
+                                c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.int64, is_shared=False),
+                                d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([4]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([4]),
+                    device=None,
+                    is_shared=False),
+                vals=TensorDict(
+                    fields={
+                        a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                        b: TensorDict(
+                            fields={
+                                c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                                d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                            batch_size=torch.Size([4]),
+                            device=None,
+                            is_shared=False)},
+                    batch_size=torch.Size([4]),
+                    device=None,
+                    is_shared=False),
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.max()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.max(reduce=True)
+            tensor(3.2942)
+
         """
         result = self._cast_reduction(
             reduction_name="max",
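
For orientation, a hedged sketch of consuming the structured result documented above. The field names (`vals`, `indices`) follow the example reprs in these docstrings; treat them as an assumption if your tensordict version differs:

    import torch
    from tensordict import TensorDict

    td = TensorDict(a=torch.randn(3, 4), batch_size=(3, 4))
    res = td.max(dim=0)            # tensorclass wrapping two TensorDicts
    print(res.vals["a"].shape)     # torch.Size([4]): maxima along dim 0
    print(res.indices["a"].dtype)  # torch.int64: argmax positions
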
+ """ result = self._cast_reduction( reduction_name="cummin", @@ -759,6 +934,8 @@ def cummin( call_on_nested=False, batch_size=self.batch_size, ) + if isinstance(result, (torch.Tensor, torch.return_types.cummin)): + return result if dim is not NO_DEFAULT and return_indices: # Split the tensordict from .return_types import cummin @@ -796,6 +973,61 @@ def cummax( when the ``dim`` argument is passed. The ``TensorDict`` equivalent of this is to return a tensorclass with entries ``"values"`` and ``"indices"`` with idendical structure within. Defaults to ``True``. + Examples: + >>> from tensordict import TensorDict + >>> import torch + >>> td = TensorDict( + ... a=torch.randn(3, 4, 5), + ... b=TensorDict( + ... c=torch.randn(3, 4, 5, 6), + ... d=torch.randn(3, 4, 5), + ... batch_size=(3, 4, 5), + ... ), + ... batch_size=(3, 4) + ... ) + >>> td.cummax(dim=0) + cummax( + indices=TensorDict( + fields={ + a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.int64, is_shared=False), + b: TensorDict( + fields={ + c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.int64, is_shared=False), + d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False), + vals=TensorDict( + fields={ + a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False), + b: TensorDict( + fields={ + c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False), + d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False)}, + batch_size=torch.Size([4]), + device=None, + is_shared=False), + batch_size=torch.Size([4]), + device=None, + is_shared=False) + >>> td = TensorDict( + ... a=torch.randn(3, 4, 5), + ... b=TensorDict( + ... c=torch.randn(3, 4, 5), + ... d=torch.randn(3, 4, 5), + ... batch_size=(3, 4, 5), + ... ), + ... batch_size=(3, 4) + ... ) + >>> td.cummax(reduce=True, dim=0) + torch.return_types.cummax(...) + """ result = self._cast_reduction( reduction_name="cummax", @@ -806,6 +1038,8 @@ def cummax( call_on_nested=False, batch_size=self.batch_size, ) + if isinstance(result, (torch.Tensor, torch.return_types.cummin)): + return result if dim is not NO_DEFAULT and return_indices: # Split the tensordict from .return_types import cummax @@ -853,6 +1087,82 @@ def mean( and a single reduced tensor will be returned. Defaults to ``False``. + Examples: + >>> from tensordict import TensorDict + >>> import torch + >>> td = TensorDict( + ... a=torch.randn(3, 4, 5), + ... b=TensorDict( + ... c=torch.randn(3, 4, 5, 6), + ... d=torch.randn(3, 4, 5), + ... batch_size=(3, 4, 5), + ... ), + ... batch_size=(3, 4) + ... 
@@ -853,6 +1087,82 @@ def mean(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.mean(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.mean()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.mean(reduce=True)
+            tensor(-0.0547)
+            >>> td.mean(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.mean(reduce=True, dim="feature")
+            tensor([[1., 1., 1., 1.],
+                    [1., 1., 1., 1.],
+                    [1., 1., 1., 1.]])
+            >>> td.mean(reduce=True, dim=0)
+            tensor([[1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.]])
+
+
         """
         # if dim is NO_DEFAULT and not keepdim:
         #     dim = None
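
A hedged sketch of the `dim="feature"` convention used in the docstring above: each leaf is reduced over its feature dims (the dims beyond its own batch size), so every result keeps the batch shape of the tensordict that holds it. The shapes below mirror the documented example:

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        a=torch.ones(3, 4, 5),  # batch (3, 4) + one feature dim (5,)
        b=TensorDict(c=torch.ones(3, 4, 5, 6), d=torch.ones(3, 4, 5), batch_size=(3, 4, 5)),
        batch_size=(3, 4),
    )
    out = td.mean(dim="feature")
    assert out["a"].shape == (3, 4)          # the (5,) feature dim is averaged away
    assert out["b", "c"].shape == (3, 4, 5)  # the (6,) feature dim is averaged away
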
@@ -895,6 +1205,81 @@ def nanmean(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.nanmean(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.nanmean()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.nanmean(reduce=True)
+            tensor(-0.0547)
+            >>> td.nanmean(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.nanmean(reduce=True, dim="feature")
+            tensor([[1., 1., 1., 1.],
+                    [1., 1., 1., 1.],
+                    [1., 1., 1., 1.]])
+            >>> td.nanmean(reduce=True, dim=0)
+            tensor([[1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.]])
+
         """
         return self._cast_reduction(
             reduction_name="nanmean",
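
A short hedged sketch contrasting `nanmean` with `mean` on the same data, to make the NaN-skipping behavior documented above concrete:

    import torch
    from tensordict import TensorDict

    x = torch.tensor([[1.0, float("nan")], [3.0, 5.0]])
    td = TensorDict(a=x, batch_size=(2, 2))
    print(td.mean(reduce=True))     # tensor(nan): NaN propagates
    print(td.nanmean(reduce=True))  # tensor(3.): mean of [1., 3., 5.]
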
@@ -934,6 +1319,81 @@ def prod(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.prod(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.prod()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.prod(reduce=True)
+            tensor(-0.)
+            >>> td.prod(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.prod(reduce=True, dim="feature")
+            tensor([[1., 1., 1., 1.],
+                    [1., 1., 1., 1.],
+                    [1., 1., 1., 1.]])
+            >>> td.prod(reduce=True, dim=0)
+            tensor([[1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.],
+                    [1., 1., 1., 1., 1.]])
+
         """
         result = self._cast_reduction(
             reduction_name="prod",
@@ -982,6 +1442,81 @@ def sum(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.sum(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.sum()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.sum(reduce=True)
+            tensor(-0.)
+            >>> td.sum(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.sum(reduce=True, dim="feature")
+            tensor([[15., 15., 15., 15.],
+                    [15., 15., 15., 15.],
+                    [15., 15., 15., 15.]])
+            >>> td.sum(reduce=True, dim=0)
+            tensor([[9., 9., 9., 9., 9.],
+                    [9., 9., 9., 9., 9.],
+                    [9., 9., 9., 9., 9.],
+                    [9., 9., 9., 9., 9.]])
+
         """
         return self._cast_reduction(
             reduction_name="sum",
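
A hedged arithmetic check for the sum example above, in plain torch: with `reduce=True` the three all-ones leaves are concatenated before a single reduction runs, which is why `dim=0` yields 9 (3 leaves of 3 rows each) while `dim="feature"` yields 15 (3 leaves of 5 feature entries each):

    import torch

    leaves = [torch.ones(3, 4, 5)] * 3
    print(torch.cat(leaves, dim=0).sum(dim=0)[0, 0])  # tensor(9.): (9, 4, 5) summed over dim 0
    flat_features = torch.cat([leaf.reshape(3, 4, -1) for leaf in leaves], dim=-1)
    print(flat_features.sum(dim=-1)[0, 0])            # tensor(15.): (3, 4, 15) summed over features
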
@@ -1021,6 +1556,81 @@ def nansum(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.nansum(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.nansum()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.nansum(reduce=True)
+            tensor(-0.)
+            >>> td.nansum(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.nansum(reduce=True, dim="feature")
+            tensor([[15., 15., 15., 15.],
+                    [15., 15., 15., 15.],
+                    [15., 15., 15., 15.]])
+            >>> td.nansum(reduce=True, dim=0)
+            tensor([[9., 9., 9., 9., 9.],
+                    [9., 9., 9., 9., 9.],
+                    [9., 9., 9., 9., 9.],
+                    [9., 9., 9., 9., 9.]])
+
         """
         return self._cast_reduction(
             reduction_name="nansum",
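
One property of the pooled statistics documented below is worth spelling out: `reduce=True` computes a single statistic over all elements of all leaves, not an average of per-leaf statistics. A plain-torch sketch of the difference:

    import torch

    x, y = torch.randn(100), torch.randn(100) + 10.0
    pooled = torch.cat([x, y]).std()    # large: includes the offset between leaves
    per_leaf = (x.std() + y.std()) / 2  # ~1.0: ignores the offset
    print(pooled > per_leaf)            # tensor(True)
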
@@ -1059,6 +1669,81 @@ def std(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.std(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.std()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.std(reduce=True)
+            tensor(1.0006)
+            >>> td.std(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.std(reduce=True, dim="feature")
+            tensor([[0., 0., 0., 0.],
+                    [0., 0., 0., 0.],
+                    [0., 0., 0., 0.]])
+            >>> td.std(reduce=True, dim=0)
+            tensor([[0., 0., 0., 0., 0.],
+                    [0., 0., 0., 0., 0.],
+                    [0., 0., 0., 0., 0.],
+                    [0., 0., 0., 0., 0.]])
+
         """
         return self._cast_reduction(
             reduction_name="std",
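
Since std and var share the reduction path above, a hedged sanity check relating the two on the same data (both inherit torch's default Bessel correction, an assumption worth verifying on your version):

    import torch
    from tensordict import TensorDict

    td = TensorDict(a=torch.arange(6, dtype=torch.float32).reshape(2, 3), batch_size=(2, 3))
    s = td.std(reduce=True)  # pooled std over all elements
    v = td.var(reduce=True)  # pooled var over all elements
    torch.testing.assert_close(v, s * s)  # var is the square of std
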
@@ -1097,6 +1782,81 @@ def var(
             and a single reduced tensor will be returned.
             Defaults to ``False``.
 
+        Examples:
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> td = TensorDict(
+            ...     a=torch.randn(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.randn(3, 4, 5, 6),
+            ...         d=torch.randn(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.var(dim=0)
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([4, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([4]),
+                device=None,
+                is_shared=False)
+            >>> td.var()
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+            >>> td.var(reduce=True)
+            tensor(1.0006)
+            >>> td.var(dim="feature")
+            TensorDict(
+                fields={
+                    a: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False),
+                    b: TensorDict(
+                        fields={
+                            c: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
+                            d: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False)},
+                        batch_size=torch.Size([3, 4, 5]),
+                        device=None,
+                        is_shared=False)},
+                batch_size=torch.Size([3, 4]),
+                device=None,
+                is_shared=False)
+            >>> td = TensorDict(
+            ...     a=torch.ones(3, 4, 5),
+            ...     b=TensorDict(
+            ...         c=torch.ones(3, 4, 5),
+            ...         d=torch.ones(3, 4, 5),
+            ...         batch_size=(3, 4, 5),
+            ...     ),
+            ...     batch_size=(3, 4)
+            ... )
+            >>> td.var(reduce=True, dim="feature")
+            tensor([[0., 0., 0., 0.],
+                    [0., 0., 0., 0.],
+                    [0., 0., 0., 0.]])
+            >>> td.var(reduce=True, dim=0)
+            tensor([[0., 0., 0., 0., 0.],
+                    [0., 0., 0., 0., 0.],
+                    [0., 0., 0., 0., 0.],
+                    [0., 0., 0., 0., 0.]])
+
         """
         return self._cast_reduction(
             reduction_name="var",
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 6edeb3064..f613e6a13 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -2673,6 +2673,18 @@ def test_reduction_feature_full(self, reduction):
         reduced = getattr(td, reduction)(dim="feature", reduce=True)
         assert reduced.shape == (3, 4)
 
+        td = TensorDict(
+            a=torch.ones(3, 4, 5),
+            b=TensorDict(
+                c=torch.ones(3, 4, 5),
+                d=torch.ones(3, 4, 5),
+                batch_size=(3, 4, 5),
+            ),
+            batch_size=(3, 4),
+        )
+        assert getattr(td, reduction)(reduce=True, dim="feature").shape == (3, 4)
+        assert getattr(td, reduction)(reduce=True, dim=1).shape == (3, 5)
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_subtensordict_construction(self, device):
         torch.manual_seed(1)
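
A hedged walk-through of the new shape assertions in the test above: with `reduce=True`, the leaves `a` (3, 4, 5), `b.c` (3, 4, 5) and `b.d` (3, 4, 5) are concatenated along the reduction dim before a single torch reduction runs, so reducing over `dim=1` leaves shape (3, 5):

    import torch

    a, c, d = torch.ones(3, 4, 5), torch.ones(3, 4, 5), torch.ones(3, 4, 5)
    stacked = torch.cat([a, c, d], dim=1)  # (3, 12, 5): leaves joined along dim 1
    print(stacked.sum(dim=1).shape)        # torch.Size([3, 5]): matches the dim=1 assert
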