diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index c38bd31de..e6d150528 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -301,7 +301,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: tensordict = make_tensordict( tensordict_values, batch_size=batch_size, - auto_batch_size=False, + auto_batch_size=self.auto_batch_size, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/test/test_nn.py b/test/test_nn.py index 8c78dd3b2..3134932e9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -559,6 +559,25 @@ def forward(self, tensordict): with pytest.raises(RuntimeError, match="Duplicated argument"): module(torch.zeros(1, 2), a_c=torch.ones(1, 2)) + @pytest.mark.parametrize("auto_batch_size", [True, False]) + def test_dispatch_auto_batch_size(self, auto_batch_size): + class MyModuleNest(nn.Module): + in_keys = [("a", "c"), "d"] + out_keys = ["b"] + + @dispatch(auto_batch_size=auto_batch_size) + def forward(self, tensordict): + if auto_batch_size: + assert tensordict.shape == (2, 3) + else: + assert tensordict.shape == () + tensordict["b"] = tensordict["a", "c"] + tensordict["d"] + return tensordict + + module = MyModuleNest() + b = module(torch.zeros(2, 3), d=torch.ones(2, 3)) + assert (b == 1).all() + def test_dispatch_nested_extra_args(self): class MyModuleNest(nn.Module): in_keys = [("a", "c"), "d"]