Skip to content

Commit

Permalink
[BugFix] auto-batch-size in dipatch
Browse files Browse the repository at this point in the history
ghstack-source-id: ca5b36195c28da65a20d42699346fbc06083181c
Pull Request resolved: #1109
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent 1ffc463 commit 2728dbf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
tensordict = make_tensordict(
tensordict_values,
batch_size=batch_size,
auto_batch_size=False,
auto_batch_size=self.auto_batch_size,
)
if _self is not None:
out = func(_self, tensordict, *args, **kwargs)
Expand Down
19 changes: 19 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,25 @@ def forward(self, tensordict):
with pytest.raises(RuntimeError, match="Duplicated argument"):
module(torch.zeros(1, 2), a_c=torch.ones(1, 2))

@pytest.mark.parametrize("auto_batch_size", [True, False])
def test_dispatch_auto_batch_size(self, auto_batch_size):
class MyModuleNest(nn.Module):
in_keys = [("a", "c"), "d"]
out_keys = ["b"]

@dispatch(auto_batch_size=auto_batch_size)
def forward(self, tensordict):
if auto_batch_size:
assert tensordict.shape == (2, 3)
else:
assert tensordict.shape == ()
tensordict["b"] = tensordict["a", "c"] + tensordict["d"]
return tensordict

module = MyModuleNest()
b = module(torch.zeros(2, 3), d=torch.ones(2, 3))
assert (b == 1).all()

def test_dispatch_nested_extra_args(self):
class MyModuleNest(nn.Module):
in_keys = [("a", "c"), "d"]
Expand Down
2 changes: 1 addition & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9865,7 +9865,7 @@ def check_weakref_count(weakref_list, expected):

@pytest.mark.skipif(
not torch.cuda.is_available(),
# and not torch.backends.mps.is_available(),
# and not torch.backends.mps.is_available(),
reason="a device is required.",
)
def test_cached_data_lock_device(self):
Expand Down

0 comments on commit 2728dbf

Please sign in to comment.