Skip to content

Commit

Permalink
feat(nn.common.TensorDictModule): support tuple values in in_keys for…
Browse files Browse the repository at this point in the history
… flexible input key dispatching
  • Loading branch information
bachdj-px committed Nov 25, 2024
1 parent e2444ed commit 198404e
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 6 deletions.
37 changes: 31 additions & 6 deletions tensordict/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,13 +767,15 @@ class TensorDictModule(TensorDictModuleBase):
will be used to populate the output tensordict (ie. the keys present
in ``out_keys`` should be present in the dictionary returned by the
``module`` forward method).
in_keys (iterable of NestedKeys, Dict[NestedStr, str]): keys to be read
in_keys (iterable of NestedKeys, Dict[NestedStr, str], Dict[NestedStr, List[str]]): keys to be read
from input tensordict and passed to the module. If it
contains more than one element, the values will be passed in the
order given by the in_keys iterable.
If ``in_keys`` is a dictionary, its keys must correspond to the key
to be read in the tensordict and its values must match the name of
the keyword argument in the function signature.
the keyword argument in the function signature. Multiple keyword arguments
can be passed as list to the values of the dictionary to dispatch
the same tensordict key to several arguments in the function signature.
out_keys (iterable of str): keys to be written to the input tensordict. The length of out_keys must match the
number of tensors returned by the embedded module. Using "_" as a key avoid writing tensor to output.
Expand Down Expand Up @@ -871,6 +873,19 @@ class TensorDictModule(TensorDictModuleBase):
>>> td['z']
tensor(3.)
Addionally, the dictionary can be passed to dispatch the same input key to several keyword arguments.
In the example below, `second` is passed to the `y` and `t` arguments in `module`.
Examples:
>>> module = TensorDictModule(lambda x, *, y, t : x+y+t,
... in_keys={'1': 'x', '2': ['y', 't']}, out_keys=['z'],
... )
>>> first = torch.ones(())
>>> second = torch.ones(())*2
>>> td = module(TensorDict({'1': first, '2': second}, []))
>>> td['z']
tensor(5.)
Functional calls to a tensordict module is easy:
Examples:
Expand Down Expand Up @@ -914,13 +929,18 @@ class TensorDictModule(TensorDictModuleBase):
"""

_IN_KEY_ERR = "in_keys must be of type list, str or tuples of str, or dict."
_IN_KEY_ERR = "in_keys must be of type list, str or tuples of str, or dict, or dict of str and list."
_OUT_KEY_ERR = "out_keys must be of type list, str or tuples of str."

def __init__(
self,
module: Callable,
in_keys: NestedKey | List[NestedKey] | Dict[NestedKey:str],
in_keys: (
NestedKey
| List[NestedKey]
| Dict[NestedKey:str]
| Dict[NestedKey:str, List[str]]
),
out_keys: NestedKey | List[NestedKey],
*,
inplace: bool | str = True,
Expand All @@ -932,8 +952,13 @@ def __init__(
_in_keys = []
self._kwargs = []
for key, value in in_keys.items():
self._kwargs.append(value)
_in_keys.append(key)
if isinstance(value, list):
for _value in value:
self._kwargs.append(_value)
_in_keys.append(key)
else:
self._kwargs.append(value)
_in_keys.append(key)
in_keys = _in_keys
else:
if isinstance(in_keys, (str, tuple)):
Expand Down
39 changes: 39 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,45 @@ def fn(a, b=None, *, c=None):
td = TensorDict({"1": torch.ones(1)}, [])
assert (module(td)["a"] == 2).all()

def test_list_input_keys_tensordict_module(self):
in_keys = {"1": "x", "2": ["y", "z"]}

def fn(x, y, z):
return x + y + z

module = TensorDictModule(fn, in_keys=in_keys, out_keys=["out"])

kword_in_keys = sorted(zip(module._kwargs, module.in_keys))
assert kword_in_keys == [("x", "1"), ("y", "2"), ("z", "2")]

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1)}, [])
out_td = module(td)
assert set(out_td.keys()) == {"1", "2", "out"}
assert out_td["out"] == 3 * torch.ones(1)

def test_list_input_keys_prob_tensordict_module(self):
def fn(x, y, z):
return torch.Tensor([x + y + z]).repeat(2, 2)

in_keys = {"1": "x", "2": ["y", "z"]}
dist_in_keys = ["loc", "scale"]
module = TensorDictModule(fn, in_keys=in_keys, out_keys=["out"])
normal_params = TensorDictModule(
NormalParamExtractor(), in_keys=["out"], out_keys=dist_in_keys
)

kwargs = {"distribution_class": Normal, "default_interaction_type": InteractionType.MODE}
prob_module = ProbabilisticTensorDictModule(
in_keys=dist_in_keys, out_keys=["action"], **kwargs
)

td_module = ProbabilisticTensorDictSequential(
module, normal_params, prob_module
)

td = TensorDict({"1": torch.ones(1), "2": torch.ones(1)}, [])
assert (td_module(td)["action"] == 3).all()

def test_reset(self):
torch.manual_seed(0)
net = nn.ModuleList([nn.Sequential(nn.Linear(1, 1), nn.ReLU())])
Expand Down

0 comments on commit 198404e

Please sign in to comment.