From 3267533d9e1e45039de233145cb12369f3620d4b Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sat, 3 Aug 2024 01:22:16 +0100 Subject: [PATCH] [Feature] Store MARL parameters in module (#2351) --- test/test_modules.py | 45 ++++++++++++++++++++++++++++ torchrl/modules/models/multiagent.py | 30 +++++++++++++++++-- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/test/test_modules.py b/test/test_modules.py index 11cf11f41e6..00e58678788 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -898,6 +898,51 @@ def one_outofplace(mod): mlp.from_stateful_net(snet) assert (mlp.params == 1).all() + @retry(AssertionError, 5) + @pytest.mark.parametrize("n_agents", [3]) + @pytest.mark.parametrize("share_params", [True]) + @pytest.mark.parametrize("centralized", [True]) + @pytest.mark.parametrize("n_agent_inputs", [6]) + @pytest.mark.parametrize("batch", [(4,)]) + @pytest.mark.parametrize("tdparams", [True, False]) + def test_multiagent_mlp_tdparams( + self, + n_agents, + centralized, + share_params, + batch, + n_agent_inputs, + tdparams, + n_agent_outputs=2, + ): + torch.manual_seed(1) + mlp = MultiAgentMLP( + n_agent_inputs=n_agent_inputs, + n_agent_outputs=n_agent_outputs, + n_agents=n_agents, + centralized=centralized, + share_params=share_params, + depth=2, + use_td_params=tdparams, + ) + if tdparams: + assert list(mlp._empty_net.parameters()) == [] + assert list(mlp.params.parameters()) == list(mlp.parameters()) + else: + assert list(mlp._empty_net.parameters()) == list(mlp.parameters()) + assert not hasattr(mlp.params, "parameters") + if torch.backends.mps.is_available(): + device = torch.device("mps") + elif torch.cuda.is_available(): + device = torch.device("cuda") + else: + return + mlp = nn.Sequential(mlp) + mlp_device = mlp.to(device) + param_set = set(mlp.parameters()) + for p in mlp[0].params.values(True, True): + assert p in param_set + def test_multiagent_mlp_lazy(self): mlp = MultiAgentMLP( n_agent_inputs=None, diff --git a/torchrl/modules/models/multiagent.py b/torchrl/modules/models/multiagent.py index 6ccc4721678..e352101ee55 100644 --- a/torchrl/modules/models/multiagent.py +++ b/torchrl/modules/models/multiagent.py @@ -40,6 +40,7 @@ def __init__( share_params: bool | None = None, agent_dim: int | None = None, vmap_randomness: str = "different", + use_td_params: bool = True, **kwargs, ): super().__init__() @@ -53,6 +54,7 @@ def __init__( if agent_dim is None: raise TypeError("agent_dim arg must be passed.") + self.use_td_params = use_td_params self.n_agents = n_agents self.share_params = share_params self.centralized = centralized @@ -70,6 +72,7 @@ def __init__( break self.initialized = initialized self._make_params(agent_networks) + # We make sure all params and buffers are on 'meta' device # To do this, we set the device keyword arg to 'meta', we also temporarily change # the default device. Finally, we convert all params to 'meta' tensors that are not params. @@ -87,6 +90,8 @@ def __init__( TensorDict.from_module(self._empty_net).data.to("meta").to_module( self._empty_net ) + if not self.use_td_params: + self.params.to_module(self._empty_net) @property def vmap_randomness(self): @@ -100,9 +105,13 @@ def vmap_randomness(self): def _make_params(self, agent_networks): if self.share_params: - self.params = TensorDict.from_module(agent_networks[0], as_module=True) + self.params = TensorDict.from_module( + agent_networks[0], as_module=self.use_td_params + ) else: - self.params = TensorDict.from_modules(*agent_networks, as_module=True) + self.params = TensorDict.from_modules( + *agent_networks, as_module=self.use_td_params + ) @abc.abstractmethod def _build_single_net(self, *, device, **kwargs): @@ -289,6 +298,8 @@ class MultiAgentMLP(MultiAgentNetBase): the number of inputs is lazily instantiated during the first call. n_agent_outputs (int): number of outputs for each agent. n_agents (int): number of agents. + + Keyword Args: centralized (bool): If `centralized` is True, each agent will use the inputs of all agents to compute its output (n_agent_inputs * n_agents will be the number of inputs for one agent). Otherwise, each agent will only use its data as input. @@ -307,6 +318,11 @@ class MultiAgentMLP(MultiAgentNetBase): default: 32. activation_class (Type[nn.Module]): activation class to be used. default: nn.Tanh. + use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params` which is a + :class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`). + If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches + should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with + ``use_td_params=True`` cannot be used when ``use_td_params=False``. **kwargs: for :class:`torchrl.modules.models.MLP` can be passed to customize the MLPs. .. note:: to initialize the MARL module parameters with the `torch.nn.init` @@ -399,12 +415,14 @@ def __init__( n_agent_inputs: int | None, n_agent_outputs: int, n_agents: int, + *, centralized: bool | None = None, share_params: bool | None = None, device: Optional[DEVICE_TYPING] = None, depth: Optional[int] = None, num_cells: Optional[Union[Sequence, int]] = None, activation_class: Optional[Type[nn.Module]] = nn.Tanh, + use_td_params: bool = True, **kwargs, ): self.n_agents = n_agents @@ -422,6 +440,7 @@ def __init__( share_params=share_params, device=device, agent_dim=-2, + use_td_params=use_td_params, **kwargs, ) @@ -483,6 +502,11 @@ class MultiAgentConvNet(MultiAgentNetBase): Defaults to ``2``. activation_class (Type[nn.Module]): activation class to be used. Default to :class:`torch.nn.ELU`. + use_td_params (bool, optional): if ``True``, the parameters can be found in `self.params` which is a + :class:`~tensordict.nn.TensorDictParams` object (which inherits both from `TensorDict` and `nn.Module`). + If ``False``, parameters are contained in `self._empty_net`. All things considered, these two approaches + should be roughly identical but not interchangeable: for instance, a ``state_dict`` created with + ``use_td_params=True`` cannot be used when ``use_td_params=False``. **kwargs: for :class:`~torchrl.modules.models.ConvNet` can be passed to customize the ConvNet. @@ -611,6 +635,7 @@ def __init__( strides: Union[Sequence, int] = 2, paddings: Union[Sequence, int] = 0, activation_class: Type[nn.Module] = nn.ELU, + use_td_params: bool = True, **kwargs, ): self.in_features = in_features @@ -625,6 +650,7 @@ def __init__( share_params=share_params, device=device, agent_dim=-4, + use_td_params=use_td_params, **kwargs, )