Skip to content

Commit

Permalink
[BugFix] auto-batch-size in dipatch
Browse files Browse the repository at this point in the history
ghstack-source-id: fe5f7f5b04d08d0eb150ee9bf0fd4698171d43d2
Pull Request resolved: #1109
  • Loading branch information
vmoens committed Nov 25, 2024
1 parent fa8a521 commit 5ef8aed
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
tensordict = make_tensordict(
tensordict_values,
batch_size=batch_size,
auto_batch_size=False,
auto_batch_size=self.auto_batch_size,
)
if _self is not None:
out = func(_self, tensordict, *args, **kwargs)
Expand Down
19 changes: 19 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,25 @@ def forward(self, tensordict):
with pytest.raises(RuntimeError, match="Duplicated argument"):
module(torch.zeros(1, 2), a_c=torch.ones(1, 2))

@pytest.mark.parametrize("auto_batch_size", [True, False])
def test_dispatch_auto_batch_size(self, auto_batch_size):
class MyModuleNest(nn.Module):
in_keys = [("a", "c"), "d"]
out_keys = ["b"]

@dispatch(auto_batch_size=auto_batch_size)
def forward(self, tensordict):
if auto_batch_size:
assert tensordict.shape == (2, 3)
else:
assert tensordict.shape == ()
tensordict["b"] = tensordict["a", "c"] + tensordict["d"]
return tensordict

module = MyModuleNest()
b = module(torch.zeros(2, 3), d=torch.ones(2, 3))
assert (b == 1).all()

def test_dispatch_nested_extra_args(self):
class MyModuleNest(nn.Module):
in_keys = [("a", "c"), "d"]
Expand Down

0 comments on commit 5ef8aed

Please sign in to comment.