diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index e6d150528..b4f4962ec 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -767,13 +767,15 @@ class TensorDictModule(TensorDictModuleBase): will be used to populate the output tensordict (ie. the keys present in ``out_keys`` should be present in the dictionary returned by the ``module`` forward method). - in_keys (iterable of NestedKeys, Dict[NestedStr, str]): keys to be read + in_keys (iterable of NestedKeys, Dict[NestedStr, str], Dict[NestedStr, List[str]]): keys to be read from input tensordict and passed to the module. If it contains more than one element, the values will be passed in the order given by the in_keys iterable. If ``in_keys`` is a dictionary, its keys must correspond to the key to be read in the tensordict and its values must match the name of - the keyword argument in the function signature. + the keyword argument in the function signature. Multiple keyword arguments + can be passed as list to the values of the dictionary to dispatch + the same tensordict key to several arguments in the function signature. out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output. @@ -871,6 +873,19 @@ class TensorDictModule(TensorDictModuleBase): >>> td['z'] tensor(3.) + Addionally, the dictionary can be passed to dispatch the same input key to several keyword arguments. + In the example below, `second` is passed to the `y` and `t` arguments in `module`. + + Examples: + >>> module = TensorDictModule(lambda x, *, y, t : x+y+t, + ... in_keys={'1': 'x', '2': ['y', 't']}, out_keys=['z'], + ... ) + >>> first = torch.ones(()) + >>> second = torch.ones(())*2 + >>> td = module(TensorDict({'1': first, '2': second}, [])) + >>> td['z'] + tensor(5.) + Functional calls to a tensordict module is easy: Examples: @@ -914,13 +929,18 @@ class TensorDictModule(TensorDictModuleBase): """ - _IN_KEY_ERR = "in_keys must be of type list, str or tuples of str, or dict." + _IN_KEY_ERR = "in_keys must be of type list, str or tuples of str, or dict, or dict of str and list." _OUT_KEY_ERR = "out_keys must be of type list, str or tuples of str." def __init__( self, module: Callable, - in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str], + in_keys: ( + NestedKey + | List[NestedKey] + | Dict[NestedKey:str] + | Dict[NestedKey:str, List[str]] + ), out_keys: NestedKey | List[NestedKey], *, inplace: bool | str = True, @@ -932,8 +952,13 @@ def __init__( _in_keys = [] self._kwargs = [] for key, value in in_keys.items(): - self._kwargs.append(value) - _in_keys.append(key) + if isinstance(value, list): + for _value in value: + self._kwargs.append(_value) + _in_keys.append(key) + else: + self._kwargs.append(value) + _in_keys.append(key) in_keys = _in_keys else: if isinstance(in_keys, (str, tuple)): diff --git a/test/test_nn.py b/test/test_nn.py index 3134932e9..b89f11435 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -161,6 +161,45 @@ def fn(a, b=None, *, c=None): td = TensorDict({"1": torch.ones(1)}, []) assert (module(td)["a"] == 2).all() + def test_list_input_keys_tensordict_module(self): + in_keys = {"1": "x", "2": ["y", "z"]} + + def fn(x, y, z): + return x + y + z + + module = TensorDictModule(fn, in_keys=in_keys, out_keys=["out"]) + + kword_in_keys = sorted(zip(module._kwargs, module.in_keys)) + assert kword_in_keys == [("x", "1"), ("y", "2"), ("z", "2")] + + td = TensorDict({"1": torch.ones(1), "2": torch.ones(1)}, []) + out_td = module(td) + assert set(out_td.keys()) == {"1", "2", "out"} + assert out_td["out"] == 3 * torch.ones(1) + + def test_list_input_keys_prob_tensordict_module(self): + def fn(x, y, z): + return torch.Tensor([x + y + z]).repeat(2, 2) + + in_keys = {"1": "x", "2": ["y", "z"]} + dist_in_keys = ["loc", "scale"] + module = TensorDictModule(fn, in_keys=in_keys, out_keys=["out"]) + normal_params = TensorDictModule( + NormalParamExtractor(), in_keys=["out"], out_keys=dist_in_keys + ) + + kwargs = {"distribution_class": Normal, "default_interaction_type": InteractionType.MODE} + prob_module = ProbabilisticTensorDictModule( + in_keys=dist_in_keys, out_keys=["action"], **kwargs + ) + + td_module = ProbabilisticTensorDictSequential( + module, normal_params, prob_module + ) + + td = TensorDict({"1": torch.ones(1), "2": torch.ones(1)}, []) + assert (td_module(td)["action"] == 3).all() + def test_reset(self): torch.manual_seed(0) net = nn.ModuleList([nn.Sequential(nn.Linear(1, 1), nn.ReLU())])