From a9ed3227deb782bb2986d6e0d9764ee7397e4e14 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 25 Nov 2024 08:51:05 +0000 Subject: [PATCH] [BugFix] auto-batch-size in dipatch ghstack-source-id: ca5b36195c28da65a20d42699346fbc06083181c Pull Request resolved: https://github.com/pytorch/tensordict/pull/1109 --- tensordict/nn/common.py | 2 +- test/test_nn.py | 19 +++++++++++++++++++ test/test_tensordict.py | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index c38bd31de..e6d150528 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -301,7 +301,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: tensordict = make_tensordict( tensordict_values, batch_size=batch_size, - auto_batch_size=False, + auto_batch_size=self.auto_batch_size, ) if _self is not None: out = func(_self, tensordict, *args, **kwargs) diff --git a/test/test_nn.py b/test/test_nn.py index 8c78dd3b2..3134932e9 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -559,6 +559,25 @@ def forward(self, tensordict): with pytest.raises(RuntimeError, match="Duplicated argument"): module(torch.zeros(1, 2), a_c=torch.ones(1, 2)) + @pytest.mark.parametrize("auto_batch_size", [True, False]) + def test_dispatch_auto_batch_size(self, auto_batch_size): + class MyModuleNest(nn.Module): + in_keys = [("a", "c"), "d"] + out_keys = ["b"] + + @dispatch(auto_batch_size=auto_batch_size) + def forward(self, tensordict): + if auto_batch_size: + assert tensordict.shape == (2, 3) + else: + assert tensordict.shape == () + tensordict["b"] = tensordict["a", "c"] + tensordict["d"] + return tensordict + + module = MyModuleNest() + b = module(torch.zeros(2, 3), d=torch.ones(2, 3)) + assert (b == 1).all() + def test_dispatch_nested_extra_args(self): class MyModuleNest(nn.Module): in_keys = [("a", "c"), "d"] diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 27b53b686..63ebc8935 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -9865,7 +9865,7 @@ def check_weakref_count(weakref_list, expected): @pytest.mark.skipif( not torch.cuda.is_available(), - # and not torch.backends.mps.is_available(), + # and not torch.backends.mps.is_available(), reason="a device is required.", ) def test_cached_data_lock_device(self):