Skip to content

Commit

Permalink
feat(nn.common.TensorDictModule): support tuple values in in_keys for…
Browse files Browse the repository at this point in the history
… flexible input key dispatching
  • Loading branch information
bachdj-px committed Nov 28, 2024
1 parent 004f979 commit c9e5be5
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 8 deletions.
48 changes: 44 additions & 4 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,11 +773,19 @@ class TensorDictModule(TensorDictModuleBase):
order given by the in_keys iterable.
If ``in_keys`` is a dictionary, its keys must correspond to the key
to be read in the tensordict and its values must match the name of
the keyword argument in the function signature.
the keyword argument in the function signature. If `out_to_in_map` is True,
the mapping gets inverted so that the keys correspond to the keyword
arguments in the function signature.
out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output.
Keyword Args:
out_to_in_map (bool, optional): if ``True``, `in_keys` is read as if the keys are the arguments keys of
the :meth:`~.forward` method and the values are the keys in the input :class:`~tensordict.TensorDict`. If
`False` or `None` (default), keys are considered to be the input keys and values the method's arguments keys.
.. warning::
The default value of `out_to_in_map` will change from `False` to `True` in the v0.9 release.
inplace (bool or string, optional): if ``True`` (default), the output of the module are written in the tensordict
provided to the :meth:`~.forward` method. If ``False``, a new :class:`~tensordict.TensorDict` with and empty
batch-size and no device is created. if ``"empty"``, :meth:`~tensordict.TensorDict.empty` will be used to
Expand Down Expand Up @@ -865,12 +873,24 @@ class TensorDictModule(TensorDictModuleBase):
Examples:
>>> module = TensorDictModule(lambda x, *, y: x+y,
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'],
... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['z']
tensor(3.)
If `out_to_in_map` is set to `True`, then the `in_keys` mapping is reversed. This way,
one can use the same input key for different keyword arguments.
Examples:
>>> module = TensorDictModule(lambda x, *, y, z: x+y+z,
... in_keys={'x': '1', 'y': '2', z: '2'}, out_keys=['t'], out_to_in_map=True
... )
>>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, []))
>>> td['t']
tensor(5.)
Functional calls to a tensordict module is easy:
Examples:
Expand Down Expand Up @@ -923,17 +943,37 @@ def __init__(
in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str],
out_keys: NestedKey | List[NestedKey],
*,
out_to_in_map: bool | None = None,
inplace: bool | str = True,
) -> None:
super().__init__()

if out_to_in_map is not None and not isinstance(in_keys, dict):
warnings.warn(
"out_to_in_map is not None but is only used when in_key is a dictionary."
)

if isinstance(in_keys, dict):
if out_to_in_map is None:
warnings.warn(
"Using a dictionary in_keys without specifying out_to_in_map is deprecated."
"By default, out_to_in_map is False (`in_keys` keys as tensordict keys), but from "
"version>=0.9, default will be True (`in_keys` as func arg keys)."
"Please use explicit out_to_in_map to indicate the ordering of the input keys. ",
DeprecationWarning,
stacklevel=2,
)

# write the kwargs and create a list instead
_in_keys = []
self._kwargs = []
for key, value in in_keys.items():
self._kwargs.append(value)
_in_keys.append(key)
if out_to_in_map: # arg: td_key
self._kwargs.append(key)
_in_keys.append(value)
else: # td_key: arg
self._kwargs.append(value)
_in_keys.append(key)
in_keys = _in_keys
else:
if isinstance(in_keys, (str, tuple)):
Expand Down
88 changes: 84 additions & 4 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def fn(a, b=None, *, c=None):
return a + 1

if kwargs:
module = TensorDictModule(fn, in_keys=kwargs, out_keys=["a"])
module = TensorDictModule(
fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False
)
td = TensorDict(
{
"1": torch.ones(1),
Expand All @@ -171,6 +173,76 @@ def fn(a, b=None, *, c=None):
td = TensorDict({"1": torch.ones(1)}, [])
assert (module(td)["a"] == 2).all()

def test_unused_out_to_in_map(self):
def fn(x, y):
return x + y

with pytest.warns(
match="out_to_in_map is not None but is only used when in_key is a dictionary."
):
_ = TensorDictModule(fn, in_keys=["x"], out_keys=["a"], out_to_in_map=False)

def test_input_keys_dict_reversed(self):
in_keys = {"x": "1", "y": "2"}

def fn(x, y):
return x + y

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])
assert (module(td)["a"] == 4).all()

def test_input_keys_match_reversed(self):
in_keys = {"1": "x", "2": "y"}
reversed_in_keys = {v: k for k, v in in_keys.items()}

def fn(x, y):
return y - x

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=False
)
reversed_module = TensorDictModule(
fn, in_keys=reversed_in_keys, out_keys=["a"], out_to_in_map=True
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])

assert module(td)["a"] == reversed_module(td)["a"] == torch.Tensor([2])

@pytest.mark.parametrize("out_to_in_map", [True, False])
def test_input_keys_wrong_mapping(self, out_to_in_map):
in_keys = {"1": "x", "2": "y"}
if not out_to_in_map:
in_keys = {v: k for k, v in in_keys.items()}

def fn(x, y):
return x + y

module = TensorDictModule(
fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=out_to_in_map
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, [])

with pytest.raises(TypeError, match="got an unexpected keyword argument '1'"):
module(td)

def test_input_keys_dict_deprecated_warning(self):
in_keys = {"1": "x", "2": "y"}

def fn(x, y):
return x + y

with pytest.warns(
DeprecationWarning,
match="Using a dictionary in_keys without specifying out_to_in_map is deprecated.",
):
_ = TensorDictModule(fn, in_keys=in_keys, out_keys=["a"])

def test_reset(self):
torch.manual_seed(0)
net = nn.ModuleList([nn.Sequential(nn.Linear(1, 1), nn.ReLU())])
Expand Down Expand Up @@ -478,7 +550,10 @@ def test_functional_functorch(self):

def test_vmap_kwargs(self):
module = TensorDictModule(
lambda x, *, y: x + y, in_keys={"1": "x", "2": "y"}, out_keys=["z"]
lambda x, *, y: x + y,
in_keys={"1": "x", "2": "y"},
out_keys=["z"],
out_to_in_map=False,
)
td = TensorDict(
{"1": torch.ones((10,)), "2": torch.ones((10,)) * 2}, batch_size=[10]
Expand Down Expand Up @@ -723,7 +798,9 @@ def fn(a, b=None, *, c=None):
return a + 1

if kwargs:
module1 = TensorDictModule(fn, in_keys=kwargs, out_keys=["a"])
module1 = TensorDictModule(
fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False
)
td = TensorDict(
{
"input": torch.ones(1),
Expand Down Expand Up @@ -1176,7 +1253,10 @@ def mycallable():
module, in_keys=[("i", "i2")], out_keys=[(("o", "o2"), ("o3",))]
)
TensorDictModule(
module, in_keys={"i": "i1", (("i2",),): "i3"}, out_keys=[("o", "o2")]
module,
in_keys={"i": "i1", (("i2",),): "i3"},
out_keys=[("o", "o2")],
out_to_in_map=False,
)

# corner cases that should work
Expand Down

0 comments on commit c9e5be5

Please sign in to comment.