Skip to content

Commit

Permalink
[Feature] Extract primers from modules that contain RNNs (#2127)
Browse files Browse the repository at this point in the history
Co-authored-by: Vincent Moens <[email protected]>
Co-authored-by: Vincent Moens <[email protected]>
  • Loading branch information
3 people authored May 3, 2024
1 parent 6f1194b commit 7348af3
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ Utils
mappings
inv_softplus
biased_softplus
get_primers_from_module

.. currentmodule:: torchrl.modules

Expand Down
49 changes: 49 additions & 0 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
EnvCreator,
InitTracker,
SerialEnv,
TensorDictPrimer,
TransformedEnv,
)
from torchrl.envs.utils import set_exploration_type, step_mdp
Expand Down Expand Up @@ -52,6 +53,7 @@
SafeProbabilisticTensorDictSequential,
)
from torchrl.modules.tensordict_module.sequence import SafeSequential
from torchrl.modules.utils import get_primers_from_module
from torchrl.objectives import DDPGLoss

_has_functorch = False
Expand Down Expand Up @@ -1549,6 +1551,53 @@ def test_batched_actor_simple(self, time_steps):
).all()


def test_get_primers_from_module():

# No primers in the model
module = MLP(in_features=10, out_features=10, num_cells=[])
transform = get_primers_from_module(module)
assert transform is None

# 1 primer in the model
gru_module = GRUModule(
input_size=10,
hidden_size=10,
num_layers=1,
in_keys=["input", "gru_recurrent_state", "is_init"],
out_keys=["features", ("next", "gru_recurrent_state")],
)
transform = get_primers_from_module(gru_module)
assert isinstance(transform, TensorDictPrimer)
assert "gru_recurrent_state" in transform.primers

# 2 primers in the model
composed_model = TensorDictSequential(
gru_module,
LSTMModule(
input_size=10,
hidden_size=10,
num_layers=1,
in_keys=[
"input",
"lstm_recurrent_state_c",
"lstm_recurrent_state_h",
"is_init",
],
out_keys=[
"features",
("next", "lstm_recurrent_state_c"),
("next", "lstm_recurrent_state_h"),
],
),
)
transform = get_primers_from_module(composed_model)
assert isinstance(transform, Compose)
assert len(transform) == 2
assert "gru_recurrent_state" in transform[0].primers
assert "lstm_recurrent_state_c" in transform[1].primers
assert "lstm_recurrent_state_h" in transform[1].primers


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
18 changes: 13 additions & 5 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4511,6 +4511,11 @@ class TensorDictPrimer(Transform):
tensor([[1., 1., 1.],
[1., 1., 1.]])
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
To facilitate this process, the method :func:`~torchrl.models.utils.get_primers_from_module`
automatically checks for required primer transforms in a module and its submodules and
generates them.
"""

def __init__(
Expand Down Expand Up @@ -4696,15 +4701,18 @@ def _reset(
spec shape is assumed to match the tensordict's.
"""
shape = (
()
if (not self.parent or self.parent.batch_locked)
else tensordict.batch_size
)
_reset = _get_reset(self.reset_key, tensordict)
if _reset.any():
for key, spec in self.primers.items(True, True):
if spec.shape[: len(tensordict.batch_size)] != tensordict.batch_size:
expanded_spec = self._expand_shape(spec)
self.primers[key] = spec = expanded_spec
if self.random:
shape = (
()
if (not self.parent or self.parent.batch_locked)
else tensordict.batch_size
)
value = spec.rand(shape)
else:
value = self.default_value[key]
Expand Down
13 changes: 13 additions & 0 deletions torchrl/modules/tensordict_module/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ class LSTMModule(ModuleBase):
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.LSTMModule.make_tensordict_primer`.
If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
on the parent module to automatically generate the primer transforms required for all submodules, including this one.
Examples:
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
Expand Down Expand Up @@ -1059,6 +1066,12 @@ class GRUModule(ModuleBase):
set_recurrent_mode: controls whether the module should be executed in
recurrent mode.
.. note:: This module relies on specific ``recurrent_state`` keys being present in the input
TensorDicts. To generate a :class:`~torchrl.envs.transforms.TensorDictPrimer` transform that will automatically
add hidden states to the environment TensorDicts, use the method :func:`~torchrl.modules.rnn.GRUModule.make_tensordict_primer`.
If this class is a submodule in a larger module, the method :func:`~torchrl.models.utils.get_primers_from_module` can be called
on the parent module to automatically generate the primer transforms required for all submodules, including this one.
Examples:
>>> from torchrl.envs import TransformedEnv, InitTracker
>>> from torchrl.envs import GymEnv
Expand Down
1 change: 1 addition & 0 deletions torchrl/modules/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ def __instancecheck__(self, instance):


from .mappings import biased_softplus, inv_softplus, mappings
from .utils import get_primers_from_module
73 changes: 73 additions & 0 deletions torchrl/modules/utils/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings


def get_primers_from_module(module):
"""Get all tensordict primers from all submodules of a module.
This method is useful for retrieving primers from modules that are contained within a
parent module.
Args:
module (torch.nn.Module): The parent module.
Returns:
TensorDictPrimer: A TensorDictPrimer Transform.
Example:
>>> from torchrl.modules.utils import get_primers_from_module
>>> from torchrl.modules import GRUModule, MLP
>>> from tensordict.nn import TensorDictModule, TensorDictSequential
>>> # Define a GRU module
>>> gru_module = GRUModule(
... input_size=10,
... hidden_size=10,
... num_layers=1,
... in_keys=["input", "recurrent_state", "is_init"],
... out_keys=["features", ("next", "recurrent_state")],
... )
>>> # Define a head module
>>> head = TensorDictModule(
... MLP(
... in_features=10,
... out_features=10,
... num_cells=[],
... ),
... in_keys=["features"],
... out_keys=["output"],
... )
>>> # Create a sequential model
>>> model = TensorDictSequential(gru_module, head)
>>> # Retrieve primers from the model
>>> primers = get_primers_from_module(model)
>>> print(primers)
TensorDictPrimer(primers=CompositeSpec(
recurrent_state: UnboundedContinuousTensorSpec(
shape=torch.Size([1, 10]),
space=None,
device=cpu,
dtype=torch.float32,
domain=continuous), device=None, shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
"""
primers = []

def make_primers(submodule):
if hasattr(submodule, "make_tensordict_primer"):
primers.append(submodule.make_tensordict_primer())

module.apply(make_primers)
if not primers:
warnings.warn("No primers found in the module.")
return
elif len(primers) == 1:
return primers[0]
else:
from torchrl.envs.transforms import Compose

return Compose(*primers)

0 comments on commit 7348af3

Please sign in to comment.