diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index e6d150528..eee4bb8ad 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -773,11 +773,19 @@ class TensorDictModule(TensorDictModuleBase): order given by the in_keys iterable. If ``in_keys`` is a dictionary, its keys must correspond to the key to be read in the tensordict and its values must match the name of - the keyword argument in the function signature. + the keyword argument in the function signature. If `out_to_in_map` is True, + the mapping gets inverted so that the keys correspond to the keyword + arguments in the function signature. out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. Keyword Args: + out_to_in_map (bool, optional): if ``True``, `in_keys` is read as if the keys are the arguments keys of + the :meth:`~.forward` method and the values are the keys in the input :class:`~tensordict.TensorDict`. If + `False` or `None` (default), keys are considered to be the input keys and values the method's arguments keys. + + .. warning:: + The default value of `out_to_in_map` will change from `False` to `True` in the v0.9 release. inplace (bool or string, optional): if ``True`` (default), the output of the module are written in the tensordict provided to the :meth:`~.forward` method. If ``False``, a new :class:`~tensordict.TensorDict` with and empty batch-size and no device is created. if ``"empty"``, :meth:`~tensordict.TensorDict.empty` will be used to @@ -865,12 +873,24 @@ class TensorDictModule(TensorDictModuleBase): Examples: >>> module = TensorDictModule(lambda x, *, y: x+y, - ... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], + ... in_keys={'1': 'x', '2': 'y'}, out_keys=['z'], out_to_in_map=False ... ) >>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, [])) >>> td['z'] tensor(3.) + If `out_to_in_map` is set to `True`, then the `in_keys` mapping is reversed. This way, + one can use the same input key for different keyword arguments. + + Examples: + >>> module = TensorDictModule(lambda x, *, y, z: x+y+z, + ... in_keys={'x': '1', 'y': '2', z: '2'}, out_keys=['t'], out_to_in_map=True + ... ) + >>> td = module(TensorDict({'1': torch.ones(()), '2': torch.ones(())*2}, [])) + >>> td['t'] + tensor(5.) + + Functional calls to a tensordict module is easy: Examples: @@ -923,17 +943,37 @@ def __init__( in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str], out_keys: NestedKey | List[NestedKey], *, + out_to_in_map: bool | None = None, inplace: bool | str = True, ) -> None: super().__init__() + if out_to_in_map is not None and not isinstance(in_keys, dict): + warnings.warn( + "out_to_in_map is not None but is only used when in_key is a dictionary." + ) + if isinstance(in_keys, dict): + if out_to_in_map is None: + warnings.warn( + "Using a dictionary in_keys without specifying out_to_in_map is deprecated." + "By default, out_to_in_map is False (`in_keys` keys as tensordict keys), but from " + "version>=0.9, default will be True (`in_keys` as func arg keys)." + "Please use explicit out_to_in_map to indicate the ordering of the input keys. ", + DeprecationWarning, + stacklevel=2, + ) + # write the kwargs and create a list instead _in_keys = [] self._kwargs = [] for key, value in in_keys.items(): - self._kwargs.append(value) - _in_keys.append(key) + if out_to_in_map: # arg: td_key + self._kwargs.append(key) + _in_keys.append(value) + else: # td_key: arg + self._kwargs.append(value) + _in_keys.append(key) in_keys = _in_keys else: if isinstance(in_keys, (str, tuple)): diff --git a/test/test_nn.py b/test/test_nn.py index cdbab76e9..510441207 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -157,7 +157,9 @@ def fn(a, b=None, *, c=None): return a + 1 if kwargs: - module = TensorDictModule(fn, in_keys=kwargs, out_keys=["a"]) + module = TensorDictModule( + fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False + ) td = TensorDict( { "1": torch.ones(1), @@ -171,6 +173,76 @@ def fn(a, b=None, *, c=None): td = TensorDict({"1": torch.ones(1)}, []) assert (module(td)["a"] == 2).all() + def test_unused_out_to_in_map(self): + def fn(x, y): + return x + y + + with pytest.warns( + match="out_to_in_map is not None but is only used when in_key is a dictionary." + ): + _ = TensorDictModule(fn, in_keys=["x"], out_keys=["a"], out_to_in_map=False) + + def test_input_keys_dict_reversed(self): + in_keys = {"x": "1", "y": "2"} + + def fn(x, y): + return x + y + + module = TensorDictModule( + fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=True + ) + + td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, []) + assert (module(td)["a"] == 4).all() + + def test_input_keys_match_reversed(self): + in_keys = {"1": "x", "2": "y"} + reversed_in_keys = {v: k for k, v in in_keys.items()} + + def fn(x, y): + return y - x + + module = TensorDictModule( + fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=False + ) + reversed_module = TensorDictModule( + fn, in_keys=reversed_in_keys, out_keys=["a"], out_to_in_map=True + ) + + td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, []) + + assert module(td)["a"] == reversed_module(td)["a"] == torch.Tensor([2]) + + @pytest.mark.parametrize("out_to_in_map", [True, False]) + def test_input_keys_wrong_mapping(self, out_to_in_map): + in_keys = {"1": "x", "2": "y"} + if not out_to_in_map: + in_keys = {v: k for k, v in in_keys.items()} + + def fn(x, y): + return x + y + + module = TensorDictModule( + fn, in_keys=in_keys, out_keys=["a"], out_to_in_map=out_to_in_map + ) + + td = TensorDict({"1": torch.ones(1), "2": torch.ones(1) * 3}, []) + + with pytest.raises(TypeError, match="got an unexpected keyword argument '1'"): + module(td) + + def test_input_keys_dict_deprecated_warning(self): + in_keys = {"1": "x", "2": "y"} + + def fn(x, y): + return x + y + + with pytest.warns( + DeprecationWarning, + match="Using a dictionary in_keys without specifying out_to_in_map is deprecated.", + ): + _ = TensorDictModule(fn, in_keys=in_keys, out_keys=["a"]) + def test_reset(self): torch.manual_seed(0) net = nn.ModuleList([nn.Sequential(nn.Linear(1, 1), nn.ReLU())]) @@ -478,7 +550,10 @@ def test_functional_functorch(self): def test_vmap_kwargs(self): module = TensorDictModule( - lambda x, *, y: x + y, in_keys={"1": "x", "2": "y"}, out_keys=["z"] + lambda x, *, y: x + y, + in_keys={"1": "x", "2": "y"}, + out_keys=["z"], + out_to_in_map=False, ) td = TensorDict( {"1": torch.ones((10,)), "2": torch.ones((10,)) * 2}, batch_size=[10] @@ -723,7 +798,9 @@ def fn(a, b=None, *, c=None): return a + 1 if kwargs: - module1 = TensorDictModule(fn, in_keys=kwargs, out_keys=["a"]) + module1 = TensorDictModule( + fn, in_keys=kwargs, out_keys=["a"], out_to_in_map=False + ) td = TensorDict( { "input": torch.ones(1), @@ -1176,7 +1253,10 @@ def mycallable(): module, in_keys=[("i", "i2")], out_keys=[(("o", "o2"), ("o3",))] ) TensorDictModule( - module, in_keys={"i": "i1", (("i2",),): "i3"}, out_keys=[("o", "o2")] + module, + in_keys={"i": "i1", (("i2",),): "i3"}, + out_keys=[("o", "o2")], + out_to_in_map=False, ) # corner cases that should work