From b3ef4ee74adee32f1866f43484a83a652ee35177 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 11:12:16 +0200 Subject: [PATCH 01/59] meltingpot --- torchrl/envs/libs/dm_control.py | 25 ++- torchrl/envs/libs/meltingpot.py | 383 ++++++++++++++++++++++++++++++++ torchrl/envs/libs/vmas.py | 2 +- 3 files changed, 406 insertions(+), 4 deletions(-) create mode 100644 torchrl/envs/libs/meltingpot.py diff --git a/torchrl/envs/libs/dm_control.py b/torchrl/envs/libs/dm_control.py index 9293dd195a0..3e1aac917e0 100644 --- a/torchrl/envs/libs/dm_control.py +++ b/torchrl/envs/libs/dm_control.py @@ -19,6 +19,7 @@ BoundedTensorSpec, CompositeSpec, DiscreteTensorSpec, + OneHotDiscreteTensorSpec, TensorSpec, UnboundedContinuousTensorSpec, UnboundedDiscreteTensorSpec, @@ -43,15 +44,34 @@ def _dmcontrol_to_torchrl_spec_transform( spec, dtype: Optional[torch.dtype] = None, device: DEVICE_TYPING = None, + categorical_discrete_encoding: bool = False, ) -> TensorSpec: import dm_env - if isinstance(spec, collections.OrderedDict): + if isinstance(spec, collections.OrderedDict) or isinstance(spec, Dict): spec = { - k: _dmcontrol_to_torchrl_spec_transform(item, device=device) + k: _dmcontrol_to_torchrl_spec_transform( + item, + device=device, + categorical_discrete_encoding=categorical_discrete_encoding, + ) for k, item in spec.items() } return CompositeSpec(**spec) + elif isinstance(spec, dm_env.specs.DiscreteArray): + # DiscreteArray is a type of BoundedArray so this block needs to go first + action_space_cls = ( + DiscreteTensorSpec + if categorical_discrete_encoding + else OneHotDiscreteTensorSpec + ) + if dtype is None: + dtype = ( + numpy_to_torch_dtype_dict[spec.dtype] + if categorical_discrete_encoding + else torch.long + ) + return action_space_cls(spec.num_values, device=device, dtype=dtype) elif isinstance(spec, dm_env.specs.BoundedArray): if dtype is None: dtype = numpy_to_torch_dtype_dict[spec.dtype] @@ -77,7 +97,6 @@ def _dmcontrol_to_torchrl_spec_transform( ) 
else: return UnboundedDiscreteTensorSpec(shape=shape, dtype=dtype, device=device) - else: raise NotImplementedError(type(spec)) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py new file mode 100644 index 00000000000..db949f17409 --- /dev/null +++ b/torchrl/envs/libs/meltingpot.py @@ -0,0 +1,383 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import importlib + +from typing import Any, Dict, List, Mapping, Optional, Sequence + +import dm_env + +import torch + +from tensordict import TensorDictBase + +from torchrl.data import CompositeSpec, DiscreteTensorSpec, TensorSpec +from torchrl.envs.common import _EnvWrapper +from torchrl.envs.libs.dm_control import _dmcontrol_to_torchrl_spec_transform +from torchrl.envs.utils import _classproperty, check_marl_grouping, MarlGroupMapType + +_has_meltingpot = importlib.util.find_spec("meltingpot") is not None + +PLAYER_STR_FORMAT = "player_{index}" +_WORLD_PREFIX = "WORLD." + + +class MeltingpotWrapper(_EnvWrapper): + """Meltingpot environment wrapper. + + GitHub: https://github.com/google-deepmind/meltingpot + + Paper: https://arxiv.org/abs/2211.13746 + + Args: + env (``meltingpot.utils.substrates.substrate.Substrate``): the meltingpot substrate to wrap. + + Keyword Args: + max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). + Each Meltingpot substrate can + be terminating or not. If ``max_steps`` is specified, + the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this horizon is reached. + Unlike gym's ``TimeLimit`` transform or torchrl's :class:`~torchrl.envs.transforms.StepCounter`, + this argument will not set the ``"truncated"`` entry in the tensordict. 
+ categorical_actions (bool, optional): if the environment actions are discrete, whether to transform + them to categorical or one-hot. Defaults to ``True``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for + input/output. By default, they will be all put + in one group named ``"agents"``. + Otherwise, a group map can be specified or selected from some premade options. + See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. + + Attributes: + group_map (Dict[str, List[str]]): how to group agents in tensordicts for + input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. + agent_names (list of str): names of the agent in the environment + agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment + available_envs (List[str]): the list of the scenarios available to build. + + .. warning:: + Meltingpot returns a single ``done`` flag which does not distinguish between + when the env reached ``max_steps`` and termination. + If you deem the ``truncation`` signal necessary, set ``max_steps`` to + ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform. 
+ + """ + + git_url = "https://github.com/google-deepmind/meltingpot" + libname = "melitingpot" + + @property + def lib(self): + import meltingpot + + return meltingpot + + @_classproperty + def available_envs(cls): + if not _has_meltingpot: + return [] + from meltingpot.substrate import SUBSTRATES + + return list(SUBSTRATES) + + def __init__( + self, + env: "meltingpot.utils.substrates.substrate.Substrate" = None, # noqa + categorical_actions: bool = True, + group_map: MarlGroupMapType + | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + max_steps: int = None, + **kwargs, + ): + if env is not None: + kwargs["env"] = env + self.group_map = group_map + self.categorical_actions = categorical_actions + self.max_steps = max_steps + super().__init__(**kwargs) + + def _build_env( + self, + env: "meltingpot.utils.substrates.substrate.Substrate", # noqa + ): + return env + + def _make_group_map(self): + if isinstance(self.group_map, MarlGroupMapType): + self.group_map = self.group_map.get_group_map(self.agent_names) + check_marl_grouping(self.group_map, self.agent_names) + + def _make_specs( + self, env: "meltingpot.utils.substrates.substrate.Substrate" # noqa + ) -> None: + mp_obs_spec = self._env.observation_spec() # List of dict of arrays + mp_obs_spec_no_world = _remove_world_observations_from_obs_spec( + mp_obs_spec + ) # List of dict of arrays + mp_global_state_spec = _global_state_spec_from_obs_spec( + mp_obs_spec + ) # Dict of arrays + mp_act_spec = self._env.action_spec() # List of discrete arrays + mp_rew_spec = self._env.reward_spec() # List of arrays + + torchrl_agent_obs_specs = [ + _dmcontrol_to_torchrl_spec_transform(agent_obs_spec) + for agent_obs_spec in mp_obs_spec_no_world + ] + torchrl_agent_act_specs = [ + _dmcontrol_to_torchrl_spec_transform( + agent_act_spec, categorical_discrete_encoding=self.categorical_actions + ) + for agent_act_spec in mp_act_spec + ] + torchrl_state_spec = _dmcontrol_to_torchrl_spec_transform(mp_global_state_spec) + 
torchrl_rew_spec = [ + _dmcontrol_to_torchrl_spec_transform(agent_rew_spec) + for agent_rew_spec in mp_rew_spec + ] + + # Create and check group map + _num_players = len(torchrl_rew_spec) + self.agent_names = [ + PLAYER_STR_FORMAT.format(index=index) for index in range(_num_players) + ] + self.agent_names_to_indices_map = { + agent_name: i for i, agent_name in enumerate(self.agent_names) + } + self._make_group_map() + + action_spec = CompositeSpec() + observation_spec = CompositeSpec() + reward_spec = CompositeSpec() + + for group in self.group_map.keys(): + ( + group_observation_spec, + group_action_spec, + group_reward_spec, + ) = self._make_group_specs( + group, + torchrl_agent_obs_specs, + torchrl_agent_act_specs, + torchrl_rew_spec, + ) + action_spec[group] = group_action_spec + observation_spec[group] = group_observation_spec + reward_spec[group] = group_reward_spec + + observation_spec.update(torchrl_state_spec) + self.done_spec = CompositeSpec( + { + "done": DiscreteTensorSpec( + n=2, shape=torch.Size((1,)), dtype=torch.bool + ), + }, + ) + self.action_spec = action_spec + self.observation_spec = observation_spec + self.reward_spec = reward_spec + + def _make_group_specs( + self, + group: str, + torchrl_agent_obs_specs: List[TensorSpec], + torchrl_agent_act_specs: List[TensorSpec], + torchrl_rew_spec: List[TensorSpec], + ): + # Agent specs + action_specs = [] + observation_specs = [] + reward_specs = [] + + for agent_name in self.group_map[group]: + agent_index = self.agent_names_to_indices_map[agent_name] + action_specs.append( + CompositeSpec( + { + "action": torchrl_agent_act_specs[ + agent_index + ] # shape = (n_actions_per_agent,) + }, + ) + ) + observation_specs.append( + CompositeSpec( + { + "observation": torchrl_agent_obs_specs[ + agent_index + ] # shape = (n_obs_per_agent,) + }, + ) + ) + reward_specs.append( + CompositeSpec({"reward": torchrl_rew_spec[agent_index]}) # shape = (1,) + ) + + # Create multi-agent specs + group_action_spec = 
torch.stack( + action_specs, dim=0 + ) # shape = (n_agents_in_group, n_actions_per_agent) + group_observation_spec = torch.stack( + observation_specs, dim=0 + ) # shape = (n_agents_in_group, n_obs_per_agent) + group_reward_spec = torch.stack( + reward_specs, dim=0 + ) # shape = (n_agents_in_group, 1) + return ( + group_observation_spec, + group_action_spec, + group_reward_spec, + ) + + def _check_kwargs(self, kwargs: Dict): + meltingpot = self.lib + + if "env" not in kwargs: + raise TypeError("Could not find environment key 'env' in kwargs.") + env = kwargs["env"] + if not isinstance(env, meltingpot.utils.substrates.substrate.Substrate): + raise TypeError( + "env is not of type 'meltingpot.utils.substrates.substrate.Substrate'." + ) + + def _init_env(self) -> Optional[int]: + pass + + def _set_seed(self, seed: Optional[int]): + raise NotImplementedError + + def close(self) -> None: + self._env.close() + + def _reset( + self, tensordict: Optional[TensorDictBase] = None, **kwargs + ) -> TensorDictBase: + return tensordict + + def _step( + self, + tensordict: TensorDictBase, + ) -> TensorDictBase: + return tensordict + + +class MeltingpotEnv(MeltingpotWrapper): + """Meltingpot environment wrapper. + + GitHub: https://github.com/google-deepmind/meltingpot + + Paper: https://arxiv.org/abs/2211.13746 + + Args: + substrate(str or ml_collections.config_dict.ConfigDict): the meltingpot substrate to build. + Can be a string from :attr:`~.available_envs` or a ConfigDict for the substrate + + Keyword Args: + max_steps (int, optional): Horizon of the task. Defaults to ``None`` (infinite horizon). + Each Meltingpot substrate can + be terminating or not. If ``max_steps`` is specified, + the scenario is also terminated (and the ``"terminated"`` flag is set) whenever this horizon is reached. + Unlike gym's ``TimeLimit`` transform or torchrl's :class:`~torchrl.envs.transforms.StepCounter`, + this argument will not set the ``"truncated"`` entry in the tensordict. 
+ categorical_actions (bool, optional): if the environment actions are discrete, whether to transform + them to categorical or one-hot. Defaults to ``True``. + group_map (MarlGroupMapType or Dict[str, List[str]], optional): how to group agents in tensordicts for + input/output. By default, they will be all put + in one group named ``"agents"``. + Otherwise, a group map can be specified or selected from some premade options. + See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. + + + Attributes: + group_map (Dict[str, List[str]]): how to group agents in tensordicts for + input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. + agent_names (list of str): names of the agent in the environment + agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment + available_envs (List[str]): the list of the scenarios available to build. + + .. warning:: + Meltingpot returns a single ``done`` flag which does not distinguish between + when the env reached ``max_steps`` and termination. + If you deem the ``truncation`` signal necessary, set ``max_steps`` to + ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform. + + + """ + + def __init__( + self, + substrate: str | "ml_collections.config_dict.ConfigDict", # noqa + *, + max_steps: Optional[int] = None, + categorical_actions: bool = True, + group_map: MarlGroupMapType + | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, + **kwargs, + ): + if not _has_meltingpot: + raise ImportError( + f"meltingpot python package was not found. Please install this dependency. " + f"More info: {self.git_url}." 
+ ) + super().__init__( + substrate=substrate, + max_steps=max_steps, + categorical_actions=categorical_actions, + group_map=group_map, + **kwargs, + ) + + def _check_kwargs(self, kwargs: Dict): + if "substrate" not in kwargs: + raise TypeError("Could not find environment key 'substrate' in kwargs.") + + def _build_env( + self, + substrate: str | "ml_collections.config_dict.ConfigDict", # noqa + ) -> "meltingpot.utils.substrates.substrate.Substrate": # noqa + from meltingpot import substrate as mp_substrate + + if isinstance(substrate, str): + substrate_config = mp_substrate.get_config(substrate) + else: + substrate_config = substrate + + return super()._build_env( + env=mp_substrate.build_from_config( + substrate_config, roles=substrate_config.default_player_roles + ) + ) + + +def _timestep_to_observations(timestep: dm_env.TimeStep) -> Mapping[str, Any]: + gym_observations = {} + for index, observation in enumerate(timestep.observation): + gym_observations[PLAYER_STR_FORMAT.format(index=index)] = { + key: value for key, value in observation.items() if _WORLD_PREFIX not in key + } + return gym_observations + + +def _remove_world_observations_from_obs_spec( + observation_spec: Sequence[Mapping[str, dm_env.specs.Array]] +) -> Sequence[Mapping[str, dm_env.specs.Array]]: + return [ + {key: value for key, value in agent_obs.items() if _WORLD_PREFIX not in key} + for agent_obs in observation_spec + ] + + +def _global_state_spec_from_obs_spec( + observation_spec: Sequence[Mapping[str, dm_env.specs.Array]] +) -> Mapping[str, dm_env.specs.Array]: + # We only look at agent 0 since world entries are the same for all agents + world_entries = { + key: value for key, value in observation_spec[0].items() if _WORLD_PREFIX in key + } + if len(world_entries) != 1 and "WORLD.RGB" not in world_entries: + raise ValueError( + f"Expected only one world entry named WORLD.RGB in observation_spec, but got {world_entries}" + ) + return world_entries diff --git a/torchrl/envs/libs/vmas.py 
b/torchrl/envs/libs/vmas.py index 51d3970fded..cb1d6294a2d 100644 --- a/torchrl/envs/libs/vmas.py +++ b/torchrl/envs/libs/vmas.py @@ -141,7 +141,7 @@ class VmasWrapper(_EnvWrapper): group_map (Dict[str, List[str]]): how to group agents in tensordicts for input/output. See :class:`~torchrl.envs.utils.MarlGroupMapType` for more info. agent_names (list of str): names of the agent in the environment - agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the enviornment + agent_names_to_indices_map (Dict[str, int]): dictionary mapping agent names to their index in the environment unbatched_action_spec (TensorSpec): version of the spec without the vectorized dimension unbatched_observation_spec (TensorSpec): version of the spec without the vectorized dimension unbatched_reward_spec (TensorSpec): version of the spec without the vectorized dimension From 98e0abdc215eb91004e5e9c146e039cc7a0cf7cf Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 11:19:56 +0200 Subject: [PATCH 02/59] remove import --- torchrl/envs/libs/meltingpot.py | 68 ++++++++++++++++----------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index db949f17409..1d7ca6f8c3c 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -8,8 +8,6 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence -import dm_env - import torch from tensordict import TensorDictBase @@ -25,6 +23,40 @@ _WORLD_PREFIX = "WORLD." 
+def _timestep_to_observations( + timestep: "dm_env.TimeStep", # noqa +) -> Mapping[str, Any]: + gym_observations = {} + for index, observation in enumerate(timestep.observation): + gym_observations[PLAYER_STR_FORMAT.format(index=index)] = { + key: value for key, value in observation.items() if _WORLD_PREFIX not in key + } + return gym_observations + + +def _remove_world_observations_from_obs_spec( + observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]], # noqa +) -> Sequence[Mapping[str, "dm_env.specs.Array"]]: # noqa + return [ + {key: value for key, value in agent_obs.items() if _WORLD_PREFIX not in key} + for agent_obs in observation_spec + ] + + +def _global_state_spec_from_obs_spec( + observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]] # noqa +) -> Mapping[str, "dm_env.specs.Array"]: # noqa + # We only look at agent 0 since world entries are the same for all agents + world_entries = { + key: value for key, value in observation_spec[0].items() if _WORLD_PREFIX in key + } + if len(world_entries) != 1 and "WORLD.RGB" not in world_entries: + raise ValueError( + f"Expected only one world entry named WORLD.RGB in observation_spec, but got {world_entries}" + ) + return world_entries + + class MeltingpotWrapper(_EnvWrapper): """Meltingpot environment wrapper. 
@@ -349,35 +381,3 @@ def _build_env( substrate_config, roles=substrate_config.default_player_roles ) ) - - -def _timestep_to_observations(timestep: dm_env.TimeStep) -> Mapping[str, Any]: - gym_observations = {} - for index, observation in enumerate(timestep.observation): - gym_observations[PLAYER_STR_FORMAT.format(index=index)] = { - key: value for key, value in observation.items() if _WORLD_PREFIX not in key - } - return gym_observations - - -def _remove_world_observations_from_obs_spec( - observation_spec: Sequence[Mapping[str, dm_env.specs.Array]] -) -> Sequence[Mapping[str, dm_env.specs.Array]]: - return [ - {key: value for key, value in agent_obs.items() if _WORLD_PREFIX not in key} - for agent_obs in observation_spec - ] - - -def _global_state_spec_from_obs_spec( - observation_spec: Sequence[Mapping[str, dm_env.specs.Array]] -) -> Mapping[str, dm_env.specs.Array]: - # We only look at agent 0 since world entries are the same for all agents - world_entries = { - key: value for key, value in observation_spec[0].items() if _WORLD_PREFIX in key - } - if len(world_entries) != 1 and "WORLD.RGB" not in world_entries: - raise ValueError( - f"Expected only one world entry named WORLD.RGB in observation_spec, but got {world_entries}" - ) - return world_entries From f73bbf0eb956c2551325d9ef11ac3acd3169d71d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 11:50:57 +0200 Subject: [PATCH 03/59] reset --- torchrl/envs/libs/meltingpot.py | 70 ++++++++++++++++++++++++++------- 1 file changed, 56 insertions(+), 14 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 1d7ca6f8c3c..87dfdd31994 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -6,11 +6,12 @@ import importlib -from typing import Any, Dict, List, Mapping, Optional, Sequence +from typing import Dict, List, Mapping, Optional, Sequence, Union +import numpy as np import torch -from tensordict import TensorDictBase +from 
tensordict import TensorDict, TensorDictBase from torchrl.data import CompositeSpec, DiscreteTensorSpec, TensorSpec from torchrl.envs.common import _EnvWrapper @@ -23,17 +24,6 @@ _WORLD_PREFIX = "WORLD." -def _timestep_to_observations( - timestep: "dm_env.TimeStep", # noqa -) -> Mapping[str, Any]: - gym_observations = {} - for index, observation in enumerate(timestep.observation): - gym_observations[PLAYER_STR_FORMAT.format(index=index)] = { - key: value for key, value in observation.items() if _WORLD_PREFIX not in key - } - return gym_observations - - def _remove_world_observations_from_obs_spec( observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]], # noqa ) -> Sequence[Mapping[str, "dm_env.specs.Array"]]: # noqa @@ -286,7 +276,38 @@ def close(self) -> None: def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: - return tensordict + timestep = self._env.reset() + obs = timestep.observation + td = self.full_done_spec.zero() + + self.num_cycles = 0 + + for group, agent_names in self.group_map.items(): + agent_tds = [] + for agent_name in agent_names: + i = self.agent_names_to_indices_map[agent_name] + agent_obs = self._read_obs(obs[i], world=False) + agent_td = TensorDict( + source={ + "observation": agent_obs, + }, + batch_size=self.batch_size, + device=self.device, + ) + + agent_tds.append(agent_td) + agent_tds = torch.stack(agent_tds, dim=0) + td.update({group: agent_tds}) + + # Global state + td.update(self._read_obs(obs[0], world=True)) + + tensordict_out = TensorDict( + source=td, + batch_size=self.batch_size, + device=self.device, + ) + return tensordict_out def _step( self, @@ -294,6 +315,27 @@ def _step( ) -> TensorDictBase: return tensordict + def _read_obs( + self, observation: Union[Dict[str, np.ndarray], np.ndarray], world: bool + ) -> Union[TensorDictBase, torch.Tensor]: + if isinstance(observation, np.ndarray): + return torch.from_numpy(observation) + elif isinstance(observation, Dict): + return 
TensorDict( + source={ + key: self._read_obs(value, world=world) + for key, value in observation.items() + if ( + (_WORLD_PREFIX not in key) + if not world + else (_WORLD_PREFIX in key) + ) + }, + batch_size=self.batch_size, + ) + else: + return torch.tensor(observation) + class MeltingpotEnv(MeltingpotWrapper): """Meltingpot environment wrapper. From e7e1b370e3ede39d3fb306ca16a1e62ae7715beb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 11:54:14 +0200 Subject: [PATCH 04/59] reset --- torchrl/envs/libs/meltingpot.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 87dfdd31994..6712e165adc 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -264,8 +264,9 @@ def _check_kwargs(self, kwargs: Dict): "env is not of type 'meltingpot.utils.substrates.substrate.Substrate'." ) - def _init_env(self) -> Optional[int]: - pass + def _init_env(self): + # Caching + self.cached_full_done_spec_zero = self.full_done_spec.zero() def _set_seed(self, seed: Optional[int]): raise NotImplementedError @@ -276,11 +277,11 @@ def close(self) -> None: def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: + self.num_cycles = 0 timestep = self._env.reset() obs = timestep.observation - td = self.full_done_spec.zero() - self.num_cycles = 0 + td = self.cached_full_done_spec_zero.clone() for group, agent_names in self.group_map.items(): agent_tds = [] From 0d1d65c1a689804c7a7ccfa21995bf94b497e08d Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 12:10:02 +0200 Subject: [PATCH 05/59] simplify reset --- torchrl/envs/libs/meltingpot.py | 48 ++++++++++++--------------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 6712e165adc..944ec004645 100644 --- a/torchrl/envs/libs/meltingpot.py +++ 
b/torchrl/envs/libs/meltingpot.py @@ -6,9 +6,8 @@ import importlib -from typing import Dict, List, Mapping, Optional, Sequence, Union +from typing import Dict, List, Mapping, Optional, Sequence -import numpy as np import torch from tensordict import TensorDict, TensorDictBase @@ -24,11 +23,19 @@ _WORLD_PREFIX = "WORLD." +def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict: # noqa + return { + key: value + for key, value in obs_dict.items() + if ((_WORLD_PREFIX not in key) if not world else (_WORLD_PREFIX in key)) + } + + def _remove_world_observations_from_obs_spec( observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]], # noqa ) -> Sequence[Mapping[str, "dm_env.specs.Array"]]: # noqa return [ - {key: value for key, value in agent_obs.items() if _WORLD_PREFIX not in key} + _filter_global_state_from_dict(agent_obs, world=False) for agent_obs in observation_spec ] @@ -37,9 +44,7 @@ def _global_state_spec_from_obs_spec( observation_spec: Sequence[Mapping[str, "dm_env.specs.Array"]] # noqa ) -> Mapping[str, "dm_env.specs.Array"]: # noqa # We only look at agent 0 since world entries are the same for all agents - world_entries = { - key: value for key, value in observation_spec[0].items() if _WORLD_PREFIX in key - } + world_entries = _filter_global_state_from_dict(observation_spec[0], world=True) if len(world_entries) != 1 and "WORLD.RGB" not in world_entries: raise ValueError( f"Expected only one world entry named WORLD.RGB in observation_spec, but got {world_entries}" @@ -285,9 +290,11 @@ def _reset( for group, agent_names in self.group_map.items(): agent_tds = [] - for agent_name in agent_names: - i = self.agent_names_to_indices_map[agent_name] - agent_obs = self._read_obs(obs[i], world=False) + for index_in_group, agent_name in enumerate(agent_names): + global_index = self.agent_names_to_indices_map[agent_name] + agent_obs = self.observation_spec[group, "observation"][ + index_in_group + 
].encode(_filter_global_state_from_dict(obs[global_index], world=False)) agent_td = TensorDict( source={ "observation": agent_obs, @@ -301,7 +308,7 @@ def _reset( td.update({group: agent_tds}) # Global state - td.update(self._read_obs(obs[0], world=True)) + td.update(_filter_global_state_from_dict(obs[0], world=True)) tensordict_out = TensorDict( source=td, @@ -316,27 +323,6 @@ def _step( ) -> TensorDictBase: return tensordict - def _read_obs( - self, observation: Union[Dict[str, np.ndarray], np.ndarray], world: bool - ) -> Union[TensorDictBase, torch.Tensor]: - if isinstance(observation, np.ndarray): - return torch.from_numpy(observation) - elif isinstance(observation, Dict): - return TensorDict( - source={ - key: self._read_obs(value, world=world) - for key, value in observation.items() - if ( - (_WORLD_PREFIX not in key) - if not world - else (_WORLD_PREFIX in key) - ) - }, - batch_size=self.batch_size, - ) - else: - return torch.tensor(observation) - class MeltingpotEnv(MeltingpotWrapper): """Meltingpot environment wrapper. 
From b56beb16123fa69a05ce8c21de4af7d13844aba0 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 14:34:43 +0200 Subject: [PATCH 06/59] step --- torchrl/envs/libs/meltingpot.py | 55 ++++++++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 944ec004645..03662e480ad 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -123,6 +123,7 @@ def __init__( self.group_map = group_map self.categorical_actions = categorical_actions self.max_steps = max_steps + self.num_cycles = 0 super().__init__(**kwargs) def _build_env( @@ -321,7 +322,59 @@ def _step( self, tensordict: TensorDictBase, ) -> TensorDictBase: - return tensordict + action_dict = {} + for group, agents in self.group_map.items(): + group_action = tensordict.get((group, "action")) + group_action_np = self.full_action_spec[group, "action"].to_numpy( + group_action + ) + for index, agent in enumerate(agents): + action_dict[agent] = group_action_np[index] + + actions = [action_dict[agent] for agent in self.agent_names] + timestep = self._env.step(actions) + self.num_cycles += 1 + + rewards = timestep.reward + done = timestep.last() or ( + (self.num_cycles >= self.max_steps) if self.max_steps is not None else False + ) + obs = timestep.observation + + td = TensorDict( + { + "done": self.full_done_spec["done"].encode(done), + "terminated": self.full_done_spec["terminated"].encode(done), + }, + batch_size=self.batch_size, + ) + # Global state + td.update(_filter_global_state_from_dict(obs[0], world=True)) + + for group, agent_names in self.group_map.items(): + agent_tds = [] + for index_in_group, agent_name in enumerate(agent_names): + global_index = self.agent_names_to_indices_map[agent_name] + agent_obs = self.observation_spec[group, "observation"][ + index_in_group + ].encode(_filter_global_state_from_dict(obs[global_index], world=False)) + agent_reward = 
self.full_reward_spec[group, "reward"][ + index_in_group + ].encode(rewards[global_index]) + agent_td = TensorDict( + source={ + "observation": agent_obs, + "reward": agent_reward, + }, + batch_size=self.batch_size, + device=self.device, + ) + + agent_tds.append(agent_td) + agent_tds = torch.stack(agent_tds, dim=0) + td.update({group: agent_tds}) + + return td class MeltingpotEnv(MeltingpotWrapper): From 8166bf71e02141f5b67c8b1dea131e521abd5bbc Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 14:50:25 +0200 Subject: [PATCH 07/59] tests --- .../scripts_meltingpot/environment.yml | 20 + .../linux_libs/scripts_meltingpot/install.sh | 58 +++ .../scripts_meltingpot/post_process.sh | 6 + .../scripts_meltingpot/run-clang-format.py | 356 ++++++++++++++++++ .../linux_libs/scripts_meltingpot/run_test.sh | 31 ++ .../scripts_meltingpot/setup_env.sh | 47 +++ .github/workflows/test-linux-libs.yml | 32 ++ test/test_libs.py | 17 +- torchrl/prova.py | 13 + 9 files changed, 575 insertions(+), 5 deletions(-) create mode 100644 .github/unittest/linux_libs/scripts_meltingpot/environment.yml create mode 100755 .github/unittest/linux_libs/scripts_meltingpot/install.sh create mode 100755 .github/unittest/linux_libs/scripts_meltingpot/post_process.sh create mode 100755 .github/unittest/linux_libs/scripts_meltingpot/run-clang-format.py create mode 100755 .github/unittest/linux_libs/scripts_meltingpot/run_test.sh create mode 100755 .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh create mode 100644 torchrl/prova.py diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml new file mode 100644 index 00000000000..8eb8faf8e64 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -0,0 +1,20 @@ +channels: + - pytorch + - defaults +dependencies: + - pip + - pip: + - cloudpickle + - importlib-metadata + - numpy + - torch + - zipp + - pytest + - pytest-cov 
+ - pytest-mock + - pytest-instafail + - pytest-rerunfailures + - pytest-error-for-skips + - expecttest + - pyyaml + - meltingpot diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh new file mode 100755 index 00000000000..5ac346c95c5 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bash + +unset PYTORCH_VERSION +# For unittest, nightly PyTorch is used as the following section, +# so no need to set PYTORCH_VERSION. +# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config. + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env + +if [ "${CU_VERSION:-}" == cpu ] ; then + version="cpu" +else + if [[ ${#CU_VERSION} -eq 4 ]]; then + CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}" + elif [[ ${#CU_VERSION} -eq 5 ]]; then + CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}" + fi + echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)" + version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")" +fi + +# submodules +git submodule sync && git submodule update --init --recursive + +printf "Installing PyTorch with cu121" +if [[ "$TORCH_VERSION" == "nightly" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu -U + else + pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U + fi +elif [[ "$TORCH_VERSION" == "stable" ]]; then + if [ "${CU_VERSION:-}" == cpu ] ; then + pip3 install torch --index-url https://download.pytorch.org/whl/cpu + else + pip3 install torch --index-url https://download.pytorch.org/whl/cu121 + fi +else + printf "Failed to install pytorch" + exit 1 +fi + +# install tensordict +if [[ "$RELEASE" == 0 ]]; then + pip3 install git+https://github.com/pytorch/tensordict.git +else + pip3 install tensordict +fi + +# smoke test 
+python -c "import tensordict" + +printf "* Installing torchrl\n" +python setup.py develop +python -c "import torchrl" diff --git a/.github/unittest/linux_libs/scripts_meltingpot/post_process.sh b/.github/unittest/linux_libs/scripts_meltingpot/post_process.sh new file mode 100755 index 00000000000..e97bf2a7b1b --- /dev/null +++ b/.github/unittest/linux_libs/scripts_meltingpot/post_process.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env diff --git a/.github/unittest/linux_libs/scripts_meltingpot/run-clang-format.py b/.github/unittest/linux_libs/scripts_meltingpot/run-clang-format.py new file mode 100755 index 00000000000..5783a885d86 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_meltingpot/run-clang-format.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python +""" +MIT License + +Copyright (c) 2017 Guillaume Papin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ +A wrapper script around clang-format, suitable for linting multiple files +and to use for continuous integration. + +This is an alternative API for the clang-format command line. +It runs over multiple files and directories in parallel. +A diff output is produced and a sensible exit code is returned. + +""" + +import argparse +import difflib +import fnmatch +import multiprocessing +import os +import signal +import subprocess +import sys +import traceback +from functools import partial + +try: + from subprocess import DEVNULL # py3k +except ImportError: + DEVNULL = open(os.devnull, "wb") + + +DEFAULT_EXTENSIONS = "c,h,C,H,cpp,hpp,cc,hh,c++,h++,cxx,hxx,cu" + + +class ExitStatus: + SUCCESS = 0 + DIFF = 1 + TROUBLE = 2 + + +def list_files(files, recursive=False, extensions=None, exclude=None): + if extensions is None: + extensions = [] + if exclude is None: + exclude = [] + + out = [] + for file in files: + if recursive and os.path.isdir(file): + for dirpath, dnames, fnames in os.walk(file): + fpaths = [os.path.join(dirpath, fname) for fname in fnames] + for pattern in exclude: + # os.walk() supports trimming down the dnames list + # by modifying it in-place, + # to avoid unnecessary directory listings. 
+ dnames[:] = [ + x + for x in dnames + if not fnmatch.fnmatch(os.path.join(dirpath, x), pattern) + ] + fpaths = [x for x in fpaths if not fnmatch.fnmatch(x, pattern)] + for f in fpaths: + ext = os.path.splitext(f)[1][1:] + if ext in extensions: + out.append(f) + else: + out.append(file) + return out + + +def make_diff(file, original, reformatted): + return list( + difflib.unified_diff( + original, + reformatted, + fromfile=f"{file}\t(original)", + tofile=f"{file}\t(reformatted)", + n=3, + ) + ) + + +class DiffError(Exception): + def __init__(self, message, errs=None): + super().__init__(message) + self.errs = errs or [] + + +class UnexpectedError(Exception): + def __init__(self, message, exc=None): + super().__init__(message) + self.formatted_traceback = traceback.format_exc() + self.exc = exc + + +def run_clang_format_diff_wrapper(args, file): + try: + ret = run_clang_format_diff(args, file) + return ret + except DiffError: + raise + except Exception as e: + raise UnexpectedError(f"{file}: {e.__class__.__name__}: {e}", e) + + +def run_clang_format_diff(args, file): + try: + with open(file, encoding="utf-8") as f: + original = f.readlines() + except OSError as exc: + raise DiffError(str(exc)) + invocation = [args.clang_format_executable, file] + + # Use of utf-8 to decode the process output. + # + # Hopefully, this is the correct thing to do. + # + # It's done due to the following assumptions (which may be incorrect): + # - clang-format will returns the bytes read from the files as-is, + # without conversion, and it is already assumed that the files use utf-8. + # - if the diagnostics were internationalized, they would use utf-8: + # > Adding Translations to Clang + # > + # > Not possible yet! + # > Diagnostic strings should be written in UTF-8, + # > the client can translate to the relevant code page if needed. + # > Each translation completely replaces the format string + # > for the diagnostic. 
+ # > -- http://clang.llvm.org/docs/InternalsManual.html#internals-diag-translation + + try: + proc = subprocess.Popen( + invocation, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True, + encoding="utf-8", + ) + except OSError as exc: + raise DiffError( + f"Command '{subprocess.list2cmdline(invocation)}' failed to start: {exc}" + ) + proc_stdout = proc.stdout + proc_stderr = proc.stderr + + # hopefully the stderr pipe won't get full and block the process + outs = list(proc_stdout.readlines()) + errs = list(proc_stderr.readlines()) + proc.wait() + if proc.returncode: + raise DiffError( + "Command '{}' returned non-zero exit status {}".format( + subprocess.list2cmdline(invocation), proc.returncode + ), + errs, + ) + return make_diff(file, original, outs), errs + + +def bold_red(s): + return "\x1b[1m\x1b[31m" + s + "\x1b[0m" + + +def colorize(diff_lines): + def bold(s): + return "\x1b[1m" + s + "\x1b[0m" + + def cyan(s): + return "\x1b[36m" + s + "\x1b[0m" + + def green(s): + return "\x1b[32m" + s + "\x1b[0m" + + def red(s): + return "\x1b[31m" + s + "\x1b[0m" + + for line in diff_lines: + if line[:4] in ["--- ", "+++ "]: + yield bold(line) + elif line.startswith("@@ "): + yield cyan(line) + elif line.startswith("+"): + yield green(line) + elif line.startswith("-"): + yield red(line) + else: + yield line + + +def print_diff(diff_lines, use_color): + if use_color: + diff_lines = colorize(diff_lines) + sys.stdout.writelines(diff_lines) + + +def print_trouble(prog, message, use_colors): + error_text = "error:" + if use_colors: + error_text = bold_red(error_text) + print(f"{prog}: {error_text} {message}", file=sys.stderr) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--clang-format-executable", + metavar="EXECUTABLE", + help="path to the clang-format executable", + default="clang-format", + ) + parser.add_argument( + "--extensions", + help=f"comma separated list of file extensions (default: 
{DEFAULT_EXTENSIONS})", + default=DEFAULT_EXTENSIONS, + ) + parser.add_argument( + "-r", + "--recursive", + action="store_true", + help="run recursively over directories", + ) + parser.add_argument("files", metavar="file", nargs="+") + parser.add_argument("-q", "--quiet", action="store_true") + parser.add_argument( + "-j", + metavar="N", + type=int, + default=0, + help="run N clang-format jobs in parallel (default number of cpus + 1)", + ) + parser.add_argument( + "--color", + default="auto", + choices=["auto", "always", "never"], + help="show colored diff (default: auto)", + ) + parser.add_argument( + "-e", + "--exclude", + metavar="PATTERN", + action="append", + default=[], + help="exclude paths matching the given glob-like pattern(s) from recursive search", + ) + + args = parser.parse_args() + + # use default signal handling, like diff return SIGINT value on ^C + # https://bugs.python.org/issue14229#msg156446 + signal.signal(signal.SIGINT, signal.SIG_DFL) + try: + signal.SIGPIPE + except AttributeError: + # compatibility, SIGPIPE does not exist on Windows + pass + else: + signal.signal(signal.SIGPIPE, signal.SIG_DFL) + + colored_stdout = False + colored_stderr = False + if args.color == "always": + colored_stdout = True + colored_stderr = True + elif args.color == "auto": + colored_stdout = sys.stdout.isatty() + colored_stderr = sys.stderr.isatty() + + version_invocation = [args.clang_format_executable, "--version"] + try: + subprocess.check_call(version_invocation, stdout=DEVNULL) + except subprocess.CalledProcessError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + return ExitStatus.TROUBLE + except OSError as e: + print_trouble( + parser.prog, + f"Command '{subprocess.list2cmdline(version_invocation)}' failed to start: {e}", + use_colors=colored_stderr, + ) + return ExitStatus.TROUBLE + + retcode = ExitStatus.SUCCESS + files = list_files( + args.files, + recursive=args.recursive, + exclude=args.exclude, + 
extensions=args.extensions.split(","), + ) + + if not files: + return + + njobs = args.j + if njobs == 0: + njobs = multiprocessing.cpu_count() + 1 + njobs = min(len(files), njobs) + + if njobs == 1: + # execute directly instead of in a pool, + # less overhead, simpler stacktraces + it = (run_clang_format_diff_wrapper(args, file) for file in files) + pool = None + else: + pool = multiprocessing.Pool(njobs) + it = pool.imap_unordered(partial(run_clang_format_diff_wrapper, args), files) + while True: + try: + outs, errs = next(it) + except StopIteration: + break + except DiffError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + retcode = ExitStatus.TROUBLE + sys.stderr.writelines(e.errs) + except UnexpectedError as e: + print_trouble(parser.prog, str(e), use_colors=colored_stderr) + sys.stderr.write(e.formatted_traceback) + retcode = ExitStatus.TROUBLE + # stop at the first unexpected error, + # something could be very wrong, + # don't process all files unnecessarily + if pool: + pool.terminate() + break + else: + sys.stderr.writelines(errs) + if outs == []: + continue + if not args.quiet: + print_diff(outs, use_color=colored_stdout) + if retcode == ExitStatus.SUCCESS: + retcode = ExitStatus.DIFF + return retcode + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/run_test.sh b/.github/unittest/linux_libs/scripts_meltingpot/run_test.sh new file mode 100755 index 00000000000..6f7ec265f74 --- /dev/null +++ b/.github/unittest/linux_libs/scripts_meltingpot/run_test.sh @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +set -e + +eval "$(./conda/bin/conda shell.bash hook)" +conda activate ./env +apt-get update && apt-get install -y git wget + + +export PYTORCH_TEST_WITH_SLOW='1' +export LAZY_LEGACY_OP=False +python -m torch.utils.collect_env +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' + +root_dir="$(git rev-parse --show-toplevel)" +env_dir="${root_dir}/env" 
+lib_dir="${env_dir}/lib" + +# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found +export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir +export MKL_THREADING_LAYER=GNU +# more logging +export MAGNUM_LOG=verbose MAGNUM_GPU_VALIDATION=ON + +# this workflow only tests the libs +python -c "import meltingpot" + +python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestMeltingpot --error-for-skips +coverage combine +coverage xml -i diff --git a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh new file mode 100755 index 00000000000..dc524958e5e --- /dev/null +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +# This script is for setting up environment in which unit test is ran. +# To speed up the CI time, the resulting environment is cached. +# +# Do not install PyTorch and torchvision here, otherwise they also get cached. + +set -e + +this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" +# Avoid error: "fatal: unsafe repository" +git config --global --add safe.directory '*' +root_dir="$(git rev-parse --show-toplevel)" +conda_dir="${root_dir}/conda" +env_dir="${root_dir}/env" + +cd "${root_dir}" + +case "$(uname -s)" in + Darwin*) os=MacOSX;; + *) os=Linux +esac + +# 1. Install conda at ./conda +if [ ! -d "${conda_dir}" ]; then + printf "* Installing conda\n" + wget -O miniconda.sh "http://repo.continuum.io/miniconda/Miniconda3-latest-${os}-x86_64.sh" + bash ./miniconda.sh -b -f -p "${conda_dir}" +fi +eval "$(${conda_dir}/bin/conda shell.bash hook)" + +# 2. Create test environment at ./env +printf "python: ${PYTHON_VERSION}\n" +if [ ! -d "${env_dir}" ]; then + printf "* Creating a test environment\n" + conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION" +fi +conda activate "${env_dir}" + +# 4. 
Install Conda dependencies +printf "* Installing dependencies (except PyTorch)\n" +echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml" +cat "${this_dir}/environment.yml" + +pip install pip --upgrade + +conda env update --file "${this_dir}/environment.yml" --prune diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index cc0e4a4f54e..4a83e86c650 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -562,3 +562,35 @@ jobs: bash .github/unittest/linux_libs/scripts_vmas/install.sh bash .github/unittest/linux_libs/scripts_vmas/run_test.sh bash .github/unittest/linux_libs/scripts_vmas/post_process.sh + + unittests-meltingpot: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="12.1" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh + bash .github/unittest/linux_libs/scripts_meltingpot/install.sh + bash .github/unittest/linux_libs/scripts_meltingpot/run_test.sh + bash .github/unittest/linux_libs/scripts_meltingpot/post_process.sh diff --git a/test/test_libs.py b/test/test_libs.py index c72fcc562e6..3bab721e7e6 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -103,6 +103,7 @@ ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv +from torchrl.envs.libs.meltingpot import MeltingpotEnv from torchrl.envs.libs.openml import OpenMLEnv from 
torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv @@ -139,6 +140,8 @@ assert gym_backend() is gym +_has_meltingpot = importlib.util.find_spec("meltingpot") is not None + def get_gym_pixel_wrapper(): try: @@ -278,7 +281,6 @@ def _make_spec( # noqa: F811 @pytest.mark.parametrize("categorical", [True, False]) def test_gym_spec_cast(self, categorical): - batch_size = [3, 4] cat = DiscreteTensorSpec if categorical else OneHotDiscreteTensorSpec cat_shape = batch_size if categorical else (*batch_size, 5) @@ -543,7 +545,6 @@ def test_torchrl_to_gym(self, backend, numpy): ], ) def test_gym(self, env_name, frame_skip, from_pixels, pixels_only): - if env_name == PONG_VERSIONED() and not from_pixels: # raise pytest.skip("already pixel") # we don't skip because that would raise an exception @@ -3126,7 +3127,6 @@ class TestPettingZoo: def test_pistonball( self, parallel, continuous_actions, use_mask, return_state, group_map ): - kwargs = {"n_pistons": 21, "continuous": continuous_actions} env = PettingZooEnv( @@ -3156,7 +3156,6 @@ def test_tic_tac_toe(self, wins_player_0): ) class Policy: - action = 0 t = 0 @@ -3329,7 +3328,7 @@ def test_collector(self, task, parallel): break -@pytest.mark.skipif(not _has_robohive, reason="SMACv2 not found") +@pytest.mark.skipif(not _has_robohive, reason="Robohive not found") class TestRoboHive: # unfortunately we must import robohive to get the available envs # and this import will occur whenever pytest is run on this file. 
@@ -3452,6 +3451,14 @@ def test_collector(self): collector.shutdown() +@pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found") +class TestMeltingpot: + @pytest.mark.parametrize("substrate", MeltingpotEnv.available_envs) + def test_all_envs(self, substrate): + env = MeltingpotEnv(substrate=substrate) + check_env_specs(env) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/prova.py b/torchrl/prova.py new file mode 100644 index 00000000000..9e93793f91a --- /dev/null +++ b/torchrl/prova.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024. +# ProrokLab (https://www.proroklab.org/) +# All rights reserved. +from torchrl.envs.libs.meltingpot import MeltingpotEnv + +if __name__ == "__main__": + from meltingpot import substrate as mp_substrate + + substrate_config = mp_substrate.get_config("commons_harvest__open") + env_torchrl = MeltingpotEnv(substrate_config) + td_reset = env_torchrl.reset() + td_in = env_torchrl.rand_action(td_reset.clone()) + td = env_torchrl.step(td_in.clone()) From bbcb0b2b920407ce075efaf9932dd5adbe2c4fa5 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 14:52:09 +0200 Subject: [PATCH 08/59] docs --- docs/source/reference/envs.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index 1523a5348ba..e2637bfffae 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -813,6 +813,8 @@ the following function will return ``1`` when queried: IsaacGymWrapper JumanjiEnv JumanjiWrapper + MeltingpotEnv + MeltingpotWrapper MOGymEnv MOGymWrapper MultiThreadedEnv From 4ef5197697867d441f22ee83316ca0642bc943f9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 14:54:46 +0200 Subject: [PATCH 09/59] oops --- torchrl/prova.py | 13 ------------- 1 file changed, 13 deletions(-) delete mode 100644 
torchrl/prova.py diff --git a/torchrl/prova.py b/torchrl/prova.py deleted file mode 100644 index 9e93793f91a..00000000000 --- a/torchrl/prova.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2024. -# ProrokLab (https://www.proroklab.org/) -# All rights reserved. -from torchrl.envs.libs.meltingpot import MeltingpotEnv - -if __name__ == "__main__": - from meltingpot import substrate as mp_substrate - - substrate_config = mp_substrate.get_config("commons_harvest__open") - env_torchrl = MeltingpotEnv(substrate_config) - td_reset = env_torchrl.reset() - td_in = env_torchrl.rand_action(td_reset.clone()) - td = env_torchrl.step(td_in.clone()) From 114ed2f00b6827e6a23350434f4669370a9191fd Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 15:02:48 +0200 Subject: [PATCH 10/59] render --- torchrl/envs/libs/meltingpot.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 03662e480ad..88c2f6be62c 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -376,6 +376,19 @@ def _step( return td + def render(self, mode="human", filename=None): + from matplotlib import pyplot as plt + + rgb_arr = self._env.observation()[0]["WORLD.RGB"] + if mode == "human": + plt.cla() + plt.imshow(rgb_arr, interpolation="nearest") + if filename is None: + plt.show(block=False) + else: + plt.savefig(filename) + return rgb_arr + class MeltingpotEnv(MeltingpotWrapper): """Meltingpot environment wrapper. 
From aeedda098032274e88f003653a517551a2233b30 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 15:06:58 +0200 Subject: [PATCH 11/59] docstring render --- torchrl/envs/libs/meltingpot.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 88c2f6be62c..99397d945f1 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -8,6 +8,7 @@ from typing import Dict, List, Mapping, Optional, Sequence +import numpy as np import torch from tensordict import TensorDict, TensorDictBase @@ -277,9 +278,6 @@ def _init_env(self): def _set_seed(self, seed: Optional[int]): raise NotImplementedError - def close(self) -> None: - self._env.close() - def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs ) -> TensorDictBase: @@ -376,7 +374,19 @@ def _step( return td - def render(self, mode="human", filename=None): + def render(self, mode="human", filename=None) -> np.ndarray: + """Renders the environment using matplotlib. + + Args: + mode (str, optional): One of ``"human"``, ``"rgb_array"``. If ``"human"`` it renders the + environment in the GUI. In any case the function returns an RGB array. Defaults to ``"human"``. + filename (str, optional): Filename to save the render to. Defaults to ``None``, + in which case no file is saved.
+ + Returns: + np.ndarray: The rendered image + + """ from matplotlib import pyplot as plt rgb_arr = self._env.observation()[0]["WORLD.RGB"] From 6974d8485294e0c1aac88faadb462b4b6e9ae783 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 15:30:13 +0200 Subject: [PATCH 12/59] more tests --- test/test_libs.py | 32 +++++++++++++++++++++++++++++++- torchrl/envs/libs/meltingpot.py | 5 ++++- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 3bab721e7e6..bc40de82399 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -103,7 +103,7 @@ ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv -from torchrl.envs.libs.meltingpot import MeltingpotEnv +from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv @@ -3458,6 +3458,36 @@ def test_all_envs(self, substrate): env = MeltingpotEnv(substrate=substrate) check_env_specs(env) + def test_passing_config(self, substrate="commons_harvest__open"): + from meltingpot import substrate as mp_substrate + + substrate_config = mp_substrate.get_config(substrate) + env_torchrl = MeltingpotEnv(substrate_config) + env_torchrl.rollout(max_steps=5) + + def test_wrapper(self, substrate="commons_harvest__open"): + from meltingpot import substrate as mp_substrate + + substrate_config = mp_substrate.get_config(substrate) + mp_env = mp_substrate.build_from_config( + substrate_config, roles=substrate_config.default_player_roles + ) + env_torchrl = MeltingpotWrapper(env=mp_env) + env_torchrl.rollout(max_steps=5) + + @pytest.mark.parametrize("max_steps", [1, 5]) + def test_max_steps(self, max_steps): + env = MeltingpotEnv(substrate="commons_harvest__open", max_steps=max_steps) + td = 
env.rollout(max_steps=100, break_when_any_done=True) + assert td.batch_size[0] == max_steps + + @pytest.mark.parametrize("categorical_actions", [True, False]) + def test_categorical_actions(self, categorical_actions): + env = MeltingpotEnv( + substrate="commons_harvest__open", categorical_actions=categorical_actions + ) + check_env_specs(env) + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 99397d945f1..1dd2d30952c 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -276,7 +276,10 @@ def _init_env(self): self.cached_full_done_spec_zero = self.full_done_spec.zero() def _set_seed(self, seed: Optional[int]): - raise NotImplementedError + raise NotImplementedError( + "It is unclear how to set a seed in Meltingpot" + " (https://github.com/google-deepmind/meltingpot/issues/129)" + ) def _reset( self, tensordict: Optional[TensorDictBase] = None, **kwargs From 1db8cb244c4811aabab1f1bafb48b593d8c0c1d9 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 3 Apr 2024 15:34:58 +0200 Subject: [PATCH 13/59] examples in docs --- torchrl/envs/libs/meltingpot.py | 103 ++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 1dd2d30952c..c37a380f408 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -91,6 +91,60 @@ class MeltingpotWrapper(_EnvWrapper): If you deem the ``truncation`` signal necessary, set ``max_steps`` to ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform. + Examples: + >>> from meltingpot import substrate + >>> from torchrl.envs.libs.meltingpot import MeltingpotWrapper + >>> substrate_config = substrate.get_config("commons_harvest__open") + >>> mp_env = substrate.build_from_config( + ... substrate_config, roles=substrate_config.default_player_roles + ... 
) + >>> env_torchrl = MeltingpotWrapper(env=mp_env) + >>> print(env_torchrl.rollout(max_steps=5)) + TensorDict( + fields={ + WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False), + observation: TensorDict( + fields={ + COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + agents: TensorDict( + fields={ + observation: TensorDict( + fields={ + COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False), + reward: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, 
dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + """ git_url = "https://github.com/google-deepmind/meltingpot" @@ -443,6 +497,55 @@ class MeltingpotEnv(MeltingpotWrapper): If you deem the ``truncation`` signal necessary, set ``max_steps`` to ``None`` and use a :class:`~torchrl.envs.transforms.StepCounter` transform. + Examples: + >>> from torchrl.envs.libs.meltingpot import MeltingpotEnv + >>> env_torchrl = MeltingpotEnv("commons_harvest__open") + >>> print(env_torchrl.rollout(max_steps=5)) + TensorDict( + fields={ + WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + agents: TensorDict( + fields={ + action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False), + observation: TensorDict( + fields={ + COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + next: TensorDict( + fields={ + WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + agents: TensorDict( + fields={ + observation: TensorDict( + fields={ + COLLECTIVE_REWARD: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + READY_TO_SHOOT: Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 7, 88, 88, 3]), device=cpu, dtype=torch.uint8, is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False), + reward: 
Tensor(shape=torch.Size([5, 7, 1]), device=cpu, dtype=torch.float64, is_shared=False)}, + batch_size=torch.Size([5, 7]), + device=cpu, + is_shared=False), + done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False), + terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, + batch_size=torch.Size([5]), + device=cpu, + is_shared=False) + """ From ea4388cac8c2b43a676a14f547daca70dea5fdb7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Thu, 4 Apr 2024 10:26:31 +0200 Subject: [PATCH 14/59] add to init --- torchrl/envs/__init__.py | 2 ++ torchrl/envs/libs/__init__.py | 1 + 2 files changed, 3 insertions(+) diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 2b02ba2feea..3085a2a06d6 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -20,6 +20,8 @@ IsaacGymWrapper, JumanjiEnv, JumanjiWrapper, + MeltingpotEnv, + MeltingpotWrapper, MOGymEnv, MOGymWrapper, MultiThreadedEnv, diff --git a/torchrl/envs/libs/__init__.py b/torchrl/envs/libs/__init__.py index 9121ea4c677..e322c2cbf01 100644 --- a/torchrl/envs/libs/__init__.py +++ b/torchrl/envs/libs/__init__.py @@ -17,6 +17,7 @@ from .habitat import HabitatEnv from .isaacgym import IsaacGymEnv, IsaacGymWrapper from .jumanji import JumanjiEnv, JumanjiWrapper +from .meltingpot import MeltingpotEnv, MeltingpotWrapper from .openml import OpenMLEnv from .pettingzoo import PettingZooEnv, PettingZooWrapper from .robohive import RoboHiveEnv From 6cfda1baf271ce9d1c67df0a22bb8cec8e5ae882 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Sun, 7 Apr 2024 17:29:48 +0100 Subject: [PATCH 15/59] Update torchrl/envs/libs/meltingpot.py Co-authored-by: Vincent Moens --- torchrl/envs/libs/meltingpot.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index c37a380f408..6cb13bb0295 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -331,7 +331,7 @@ def _init_env(self): def _set_seed(self, seed: Optional[int]): raise NotImplementedError( - "It is unclear how to set a seed in Meltingpot" + "It is currently unclear how to set a seed in Meltingpot" " (https://github.com/google-deepmind/meltingpot/issues/129)" ) From 096bc94356a69638fab8bb5cec5d216b91c1f9f7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com> Date: Sun, 7 Apr 2024 17:29:58 +0100 Subject: [PATCH 16/59] Update torchrl/envs/libs/meltingpot.py Co-authored-by: Vincent Moens --- torchrl/envs/libs/meltingpot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 6cb13bb0295..b2fd955100f 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -332,7 +332,7 @@ def _init_env(self): def _set_seed(self, seed: Optional[int]): raise NotImplementedError( "It is currently unclear how to set a seed in Meltingpot" - " (https://github.com/google-deepmind/meltingpot/issues/129)" + ", see https://github.com/google-deepmind/meltingpot/issues/129 to track the issue."
) def _reset( From 0ea5047a51233263ef8b8bb7930cebf21d8e6f0a Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 7 Apr 2024 18:55:45 +0200 Subject: [PATCH 17/59] remove Optional type --- torchrl/envs/libs/meltingpot.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index c37a380f408..bb202a5e99e 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -6,7 +6,7 @@ import importlib -from typing import Dict, List, Mapping, Optional, Sequence +from typing import Dict, List, Mapping, Sequence import numpy as np import torch @@ -329,14 +329,14 @@ def _init_env(self): # Caching self.cached_full_done_spec_zero = self.full_done_spec.zero() - def _set_seed(self, seed: Optional[int]): + def _set_seed(self, seed: int | None): raise NotImplementedError( "It is unclear how to set a seed in Meltingpot" " (https://github.com/google-deepmind/meltingpot/issues/129)" ) def _reset( - self, tensordict: Optional[TensorDictBase] = None, **kwargs + self, tensordict: TensorDictBase | None = None, **kwargs ) -> TensorDictBase: self.num_cycles = 0 timestep = self._env.reset() @@ -553,7 +553,7 @@ def __init__( self, substrate: str | "ml_collections.config_dict.ConfigDict", # noqa *, - max_steps: Optional[int] = None, + max_steps: int | None = None, categorical_actions: bool = True, group_map: MarlGroupMapType | Dict[str, List[str]] = MarlGroupMapType.ALL_IN_ONE_GROUP, From 8ef8cb1c21e05cd8e661eaf1d0f48a1eee35dfd3 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 7 Apr 2024 18:57:22 +0200 Subject: [PATCH 18/59] alphabetic order tests --- .github/workflows/test-linux-libs.yml | 64 +++++++++++++-------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 4a83e86c650..f86da7a9557 100644 --- a/.github/workflows/test-linux-libs.yml +++ 
b/.github/workflows/test-linux-libs.yml @@ -255,6 +255,38 @@ jobs: bash .github/unittest/linux_libs/scripts_jumanji/run_test.sh bash .github/unittest/linux_libs/scripts_jumanji/post_process.sh + unittests-meltingpot: + uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + with: + repository: pytorch/rl + runner: "linux.g5.4xlarge.nvidia.gpu" + gpu-arch-type: cuda + gpu-arch-version: "11.7" + timeout: 120 + script: | + if [[ "${{ github.ref }}" =~ release/* ]]; then + export RELEASE=1 + export TORCH_VERSION=stable + else + export RELEASE=0 + export TORCH_VERSION=nightly + fi + + set -euo pipefail + export PYTHON_VERSION="3.9" + export CU_VERSION="12.1" + export TAR_OPTIONS="--no-same-owner" + export UPLOAD_CHANNEL="nightly" + export TF_CPP_MIN_LOG_LEVEL=0 + export BATCHED_PIPE_TIMEOUT=60 + + nvidia-smi + + bash .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh + bash .github/unittest/linux_libs/scripts_meltingpot/install.sh + bash .github/unittest/linux_libs/scripts_meltingpot/run_test.sh + bash .github/unittest/linux_libs/scripts_meltingpot/post_process.sh + unittests-minari: strategy: matrix: @@ -562,35 +594,3 @@ jobs: bash .github/unittest/linux_libs/scripts_vmas/install.sh bash .github/unittest/linux_libs/scripts_vmas/run_test.sh bash .github/unittest/linux_libs/scripts_vmas/post_process.sh - - unittests-meltingpot: - uses: pytorch/test-infra/.github/workflows/linux_job.yml@main - with: - repository: pytorch/rl - runner: "linux.g5.4xlarge.nvidia.gpu" - gpu-arch-type: cuda - gpu-arch-version: "11.7" - timeout: 120 - script: | - if [[ "${{ github.ref }}" =~ release/* ]]; then - export RELEASE=1 - export TORCH_VERSION=stable - else - export RELEASE=0 - export TORCH_VERSION=nightly - fi - - set -euo pipefail - export PYTHON_VERSION="3.9" - export CU_VERSION="12.1" - export TAR_OPTIONS="--no-same-owner" - export UPLOAD_CHANNEL="nightly" - export TF_CPP_MIN_LOG_LEVEL=0 - export BATCHED_PIPE_TIMEOUT=60 - - nvidia-smi - - bash 
.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh - bash .github/unittest/linux_libs/scripts_meltingpot/install.sh - bash .github/unittest/linux_libs/scripts_meltingpot/run_test.sh - bash .github/unittest/linux_libs/scripts_meltingpot/post_process.sh From f81261273b331523ed24c192b8d1624af1ea199e Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 7 Apr 2024 18:59:03 +0200 Subject: [PATCH 19/59] amend --- torchrl/envs/libs/meltingpot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index c4812e0408d..75f66f179c1 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -160,9 +160,10 @@ def lib(self): def available_envs(cls): if not _has_meltingpot: return [] - from meltingpot.substrate import SUBSTRATES + else: + from meltingpot.substrate import SUBSTRATES - return list(SUBSTRATES) + return list(SUBSTRATES) def __init__( self, From ef7ff2d44d3bc78c17babca0d8f2d3cc4f0dfff7 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 7 Apr 2024 19:00:49 +0200 Subject: [PATCH 20/59] add meltingpot description to docs --- torchrl/envs/libs/meltingpot.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 75f66f179c1..7cf2ba90cdd 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -60,6 +60,12 @@ class MeltingpotWrapper(_EnvWrapper): Paper: https://arxiv.org/abs/2211.13746 + Melting Pot assesses generalization to novel social situations involving both familiar and unfamiliar individuals, + and has been designed to test a broad range of social interactions such as: cooperation, competition, deception, + reciprocation, trust, stubbornness and so on. 
Melting Pot offers researchers a set of over 50 multi-agent + reinforcement learning substrates (multi-agent games) on which to train agents, and over 256 unique test scenarios + on which to evaluate these trained agents. + Args: env (``meltingpot.utils.substrates.substrate.Substrate``): the meltingpot substrate to wrap. @@ -465,6 +471,12 @@ class MeltingpotEnv(MeltingpotWrapper): Paper: https://arxiv.org/abs/2211.13746 + Melting Pot assesses generalization to novel social situations involving both familiar and unfamiliar individuals, + and has been designed to test a broad range of social interactions such as: cooperation, competition, deception, + reciprocation, trust, stubbornness and so on. Melting Pot offers researchers a set of over 50 multi-agent + reinforcement learning substrates (multi-agent games) on which to train agents, and over 256 unique test scenarios + on which to evaluate these trained agents. + Args: substrate(str or ml_collections.config_dict.ConfigDict): the meltingpot substrate to build. 
Can be a string from :attr:`~.available_envs` or a ConfigDict for the substrate From 4ca4555abb9a859e9440ac2a0fd7d72ff8c98213 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 7 Apr 2024 19:03:57 +0200 Subject: [PATCH 21/59] review comment -> improved efficiency --- torchrl/envs/libs/meltingpot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 7cf2ba90cdd..a937cdce9e2 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -368,7 +368,7 @@ def _reset( agent_tds.append(agent_td) agent_tds = torch.stack(agent_tds, dim=0) - td.update({group: agent_tds}) + td.set(group, agent_tds) # Global state td.update(_filter_global_state_from_dict(obs[0], world=True)) @@ -434,7 +434,7 @@ def _step( agent_tds.append(agent_td) agent_tds = torch.stack(agent_tds, dim=0) - td.update({group: agent_tds}) + td.set(group, agent_tds) return td From dbca972cc2109c581196c666632eaf2700e1f9c2 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Sun, 7 Apr 2024 19:11:14 +0200 Subject: [PATCH 22/59] get_rgb_image --- test/test_libs.py | 10 ++++++++++ torchrl/envs/libs/meltingpot.py | 21 +++------------------ 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index bc40de82399..0f1e905d5fc 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3488,6 +3488,16 @@ def test_categorical_actions(self, categorical_actions): ) check_env_specs(env) + @pytest.mark.parametrize("rollout_steps", [1, 3]) + def test_render(self, rollout_steps): + env = MeltingpotEnv(substrate="commons_harvest__open") + td = env.rollout(2) + rollout_penultimate_image = td[-1].get("WORLD.RGB") + rollout_last_image = td[-1].get(("next", "WORLD.RGB")) + image_from_env = torch.from_numpy(env.get_rgb_image()) + assert torch.equal(rollout_last_image, image_from_env) + assert not torch.equal(rollout_penultimate_image, image_from_env) + if __name__ == 
"__main__": args, unknown = argparse.ArgumentParser().parse_known_args() diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index a937cdce9e2..762ac6a936b 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -438,29 +438,14 @@ def _step( return td - def render(self, mode="human", filename=None) -> np.ndarray: - """Renders the environment using matplotlib. - - Args: - mode (str, optional): One of ``"human"``, ``"rgb_array"``. If ``"human"`` it renders the - environment in the GUI. In any case the function returns a RGB array. Defaults to ``"human"`` - filename (str, optional): Filename to save the render to. Defaults to ``None``, - in which case no file is saved. + def get_rgb_image(self) -> np.ndarray: + """Returns an RGB image of the environment. Returns: - np.ndarray: The rendered image + np.ndarray: The image """ - from matplotlib import pyplot as plt - rgb_arr = self._env.observation()[0]["WORLD.RGB"] - if mode == "human": - plt.cla() - plt.imshow(rgb_arr, interpolation="nearest") - if filename is None: - plt.show(block=False) - else: - plt.savefig(filename) return rgb_arr From 0b0333bc47744a965f90bb013f4cf13eb8f96d40 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 10:35:29 +0200 Subject: [PATCH 23/59] get_rgb_image --- test/test_libs.py | 2 +- torchrl/envs/libs/meltingpot.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 0f1e905d5fc..3be246d11ff 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3494,7 +3494,7 @@ def test_render(self, rollout_steps): td = env.rollout(2) rollout_penultimate_image = td[-1].get("WORLD.RGB") rollout_last_image = td[-1].get(("next", "WORLD.RGB")) - image_from_env = torch.from_numpy(env.get_rgb_image()) + image_from_env = env.get_rgb_image() assert torch.equal(rollout_last_image, image_from_env) assert not torch.equal(rollout_penultimate_image, image_from_env) diff --git 
a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 762ac6a936b..efa7398a141 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -8,7 +8,6 @@ from typing import Dict, List, Mapping, Sequence -import numpy as np import torch from tensordict import TensorDict, TensorDictBase @@ -438,14 +437,14 @@ def _step( return td - def get_rgb_image(self) -> np.ndarray: + def get_rgb_image(self) -> torch.Tensor: """Returns an RGB image of the environment. Returns: - np.ndarray: The image + torch.Tensor: The image """ - rgb_arr = self._env.observation()[0]["WORLD.RGB"] + rgb_arr = torch.from_numpy(self._env.observation()[0]["WORLD.RGB"]) return rgb_arr From 1593753affd6415457a6bd4f6646a189b93cd381 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 10:39:04 +0200 Subject: [PATCH 24/59] lazy import --- torchrl/envs/libs/meltingpot.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index efa7398a141..2165bed0809 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -23,6 +23,14 @@ _WORLD_PREFIX = "WORLD." 
+def _get_envs(): + if not _has_meltingpot: + raise ImportError("meltingpot is not installed in your virtual environment.") + from meltingpot.substrate import SUBSTRATES + + return list(SUBSTRATES) + + def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict: # noqa return { key: value @@ -165,10 +173,7 @@ def lib(self): def available_envs(cls): if not _has_meltingpot: return [] - else: - from meltingpot.substrate import SUBSTRATES - - return list(SUBSTRATES) + return _get_envs() def __init__( self, From 22dbc97102712f3f7c10e5f2380275c4f27e41d9 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 8 Apr 2024 09:41:46 +0100 Subject: [PATCH 25/59] Update torchrl/envs/libs/meltingpot.py --- torchrl/envs/libs/meltingpot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 2165bed0809..601685a7f66 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -446,11 +446,10 @@ def get_rgb_image(self) -> torch.Tensor: """Returns an RGB image of the environment. Returns: - torch.Tensor: The image + a ``torch.Tensor`` containing image in format WHC. 
""" - rgb_arr = torch.from_numpy(self._env.observation()[0]["WORLD.RGB"]) - return rgb_arr + return torch.from_numpy(self._env.observation()[0]["WORLD.RGB"]) class MeltingpotEnv(MeltingpotWrapper): From 5e070cef07a32a04f735013695c4c8dddfe12b4c Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 10:47:14 +0200 Subject: [PATCH 26/59] pypi name --- .github/unittest/linux_libs/scripts_meltingpot/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index 8eb8faf8e64..f376d3ab1a7 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -17,4 +17,4 @@ dependencies: - pytest-error-for-skips - expecttest - pyyaml - - meltingpot + - dm-meltingpot From c0ce67df14318fd191e348e5377c7885f8b25b33 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 10:53:17 +0200 Subject: [PATCH 27/59] world prefix --- test/test_libs.py | 6 +++--- torchrl/envs/libs/meltingpot.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 3be246d11ff..74f4e1dee66 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -103,7 +103,7 @@ ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv -from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper +from torchrl.envs.libs.meltingpot import _WORLD_PREFIX, MeltingpotEnv, MeltingpotWrapper from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv @@ -3492,8 +3492,8 @@ def test_categorical_actions(self, categorical_actions): def test_render(self, rollout_steps): env = 
MeltingpotEnv(substrate="commons_harvest__open") td = env.rollout(2) - rollout_penultimate_image = td[-1].get("WORLD.RGB") - rollout_last_image = td[-1].get(("next", "WORLD.RGB")) + rollout_penultimate_image = td[-1].get(_WORLD_PREFIX + "RGB") + rollout_last_image = td[-1].get(("next", _WORLD_PREFIX + "RGB")) image_from_env = env.get_rgb_image() assert torch.equal(rollout_last_image, image_from_env) assert not torch.equal(rollout_penultimate_image, image_from_env) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 601685a7f66..bc3db37d42d 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -53,9 +53,9 @@ def _global_state_spec_from_obs_spec( ) -> Mapping[str, "dm_env.specs.Array"]: # noqa # We only look at agent 0 since world entries are the same for all agents world_entries = _filter_global_state_from_dict(observation_spec[0], world=True) - if len(world_entries) != 1 and "WORLD.RGB" not in world_entries: + if len(world_entries) != 1 and _WORLD_PREFIX + "RGB" not in world_entries: raise ValueError( - f"Expected only one world entry named WORLD.RGB in observation_spec, but got {world_entries}" + f"Expected only one world entry named {_WORLD_PREFIX}RGB in observation_spec, but got {world_entries}" ) return world_entries @@ -449,7 +449,7 @@ def get_rgb_image(self) -> torch.Tensor: a ``torch.Tensor`` containing image in format WHC. 
""" - return torch.from_numpy(self._env.observation()[0]["WORLD.RGB"]) + return torch.from_numpy(self._env.observation()[0][_WORLD_PREFIX + "RGB"]) class MeltingpotEnv(MeltingpotWrapper): From 80036ae4b7d2854d7b655ac84f66853baee25bfb Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 11:27:38 +0200 Subject: [PATCH 28/59] add to torchrl[marl] --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 44e772528a7..3c844400306 100644 --- a/setup.py +++ b/setup.py @@ -224,7 +224,7 @@ def _main(argv): "h5py", "pillow", ], - "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1"], + "marl": ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot"], } extra_requires["all"] = set() for key in list(extra_requires.keys()): From 8b776b4e1c9e3d738fad2ada511c3c63011082ed Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 11:41:08 +0200 Subject: [PATCH 29/59] remove world prefix in specs and tds --- test/test_libs.py | 6 +++--- torchrl/envs/libs/meltingpot.py | 24 ++++++++++++++++-------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 74f4e1dee66..5c38ece5079 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -103,7 +103,7 @@ ) from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv -from torchrl.envs.libs.meltingpot import _WORLD_PREFIX, MeltingpotEnv, MeltingpotWrapper +from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper from torchrl.envs.libs.openml import OpenMLEnv from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv from torchrl.envs.libs.robohive import _has_robohive, RoboHiveEnv @@ -3492,8 +3492,8 @@ def test_categorical_actions(self, categorical_actions): def test_render(self, rollout_steps): env = MeltingpotEnv(substrate="commons_harvest__open") td = env.rollout(2) - rollout_penultimate_image = td[-1].get(_WORLD_PREFIX + 
"RGB") - rollout_last_image = td[-1].get(("next", _WORLD_PREFIX + "RGB")) + rollout_penultimate_image = td[-1].get("RGB") + rollout_last_image = td[-1].get(("next", "RGB")) image_from_env = env.get_rgb_image() assert torch.equal(rollout_last_image, image_from_env) assert not torch.equal(rollout_penultimate_image, image_from_env) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index bc3db37d42d..fd6fd376994 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -57,7 +57,11 @@ def _global_state_spec_from_obs_spec( raise ValueError( f"Expected only one world entry named {_WORLD_PREFIX}RGB in observation_spec, but got {world_entries}" ) - return world_entries + return _remove_world_prefix(world_entries) + + +def _remove_world_prefix(world_entries: Dict) -> Dict: + return {key[len(_WORLD_PREFIX) :]: value for key, value in world_entries.items()} class MeltingpotWrapper(_EnvWrapper): @@ -113,9 +117,9 @@ class MeltingpotWrapper(_EnvWrapper): ... 
) >>> env_torchrl = MeltingpotWrapper(env=mp_env) >>> print(env_torchrl.rollout(max_steps=5)) - TensorDict( + TensorDict( fields={ - WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), agents: TensorDict( fields={ action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False), @@ -133,7 +137,7 @@ class MeltingpotWrapper(_EnvWrapper): done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), agents: TensorDict( fields={ observation: TensorDict( @@ -375,7 +379,9 @@ def _reset( td.set(group, agent_tds) # Global state - td.update(_filter_global_state_from_dict(obs[0], world=True)) + td.update( + _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True)) + ) tensordict_out = TensorDict( source=td, @@ -415,7 +421,9 @@ def _step( batch_size=self.batch_size, ) # Global state - td.update(_filter_global_state_from_dict(obs[0], world=True)) + td.update( + _remove_world_prefix(_filter_global_state_from_dict(obs[0], world=True)) + ) for group, agent_names in self.group_map.items(): agent_tds = [] @@ -504,7 +512,7 @@ class MeltingpotEnv(MeltingpotWrapper): >>> print(env_torchrl.rollout(max_steps=5)) TensorDict( fields={ - WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), agents: TensorDict( fields={ action: Tensor(shape=torch.Size([5, 7]), device=cpu, dtype=torch.int64, is_shared=False), @@ -522,7 +530,7 @@ class MeltingpotEnv(MeltingpotWrapper): done: Tensor(shape=torch.Size([5, 1]), 
device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ - WORLD.RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), + RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), agents: TensorDict( fields={ observation: TensorDict( From 6188003c49148e6d1f1d5127a777e23cb3e62eb4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 11:42:04 +0200 Subject: [PATCH 30/59] typo --- torchrl/envs/libs/meltingpot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index fd6fd376994..b45776080e4 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -117,7 +117,7 @@ class MeltingpotWrapper(_EnvWrapper): ... ) >>> env_torchrl = MeltingpotWrapper(env=mp_env) >>> print(env_torchrl.rollout(max_steps=5)) - TensorDict( + TensorDict( fields={ RGB: Tensor(shape=torch.Size([5, 144, 192, 3]), device=cpu, dtype=torch.uint8, is_shared=False), agents: TensorDict( From 2a3c4655ec68cd07a41082035ac72b94816ab218 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 12:48:09 +0200 Subject: [PATCH 31/59] amend --- test/test_libs.py | 2 +- torchrl/envs/libs/meltingpot.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 5c38ece5079..2912b3e0f7c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3453,7 +3453,7 @@ def test_collector(self): @pytest.mark.skipif(not _has_meltingpot, reason="Meltingpot not found") class TestMeltingpot: - @pytest.mark.parametrize("substrate", MeltingpotEnv.available_envs) + @pytest.mark.parametrize("substrate", MeltingpotWrapper.available_envs) def test_all_envs(self, substrate): env = MeltingpotEnv(substrate=substrate) check_env_specs(env) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index b45776080e4..2b797e06c03 100644 
--- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -26,9 +26,9 @@ def _get_envs(): if not _has_meltingpot: raise ImportError("meltingpot is not installed in your virtual environment.") - from meltingpot.substrate import SUBSTRATES + import meltingpot - return list(SUBSTRATES) + return list(meltingpot.substrate.SUBSTRATES) def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict: # noqa From 2370be9e03ae3a62f7e25aa5066cfa711fa86564 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 15:08:16 +0200 Subject: [PATCH 32/59] change python version --- .github/workflows/test-linux-libs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index f86da7a9557..362e31f8365 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -273,7 +273,7 @@ jobs: fi set -euo pipefail - export PYTHON_VERSION="3.9" + export PYTHON_VERSION="3.11" export CU_VERSION="12.1" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" From 24d54287a43cfc425b697e056dd9d896e0ac7b08 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Mon, 8 Apr 2024 15:41:09 +0200 Subject: [PATCH 33/59] remove dependencies --- .github/unittest/linux_libs/scripts_meltingpot/environment.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index f376d3ab1a7..4aacdbcf977 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -5,10 +5,8 @@ dependencies: - pip - pip: - cloudpickle - - importlib-metadata - numpy - torch - - zipp - pytest - pytest-cov - pytest-mock From 467a48014e7134d57b07530e6c1be95a56ce9cc4 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 9 Apr 2024 08:15:20 +0200 Subject: 
[PATCH 34/59] explicit dmlab2d? --- .github/unittest/linux_libs/scripts_meltingpot/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index 4aacdbcf977..f79a9bec745 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -15,4 +15,5 @@ dependencies: - pytest-error-for-skips - expecttest - pyyaml + - dmlab2d - dm-meltingpot From ee35caec3f89b42ddd273df4ac677fe341f81684 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 13:04:34 +0200 Subject: [PATCH 35/59] amend --- .../unittest/linux_libs/scripts_meltingpot/environment.yml | 1 - .github/unittest/linux_libs/scripts_meltingpot/install.sh | 4 ++++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index 8eb8faf8e64..f509cf2046b 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -17,4 +17,3 @@ dependencies: - pytest-error-for-skips - expecttest - pyyaml - - meltingpot diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index 5ac346c95c5..5805c8e1468 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -56,3 +56,7 @@ python -c "import tensordict" printf "* Installing torchrl\n" python setup.py develop python -c "import torchrl" + +# Install meltingpot from git +LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') +pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From c6f516d313fb19cd2827537517b37601e8bdbc8f 
Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 9 Apr 2024 14:11:47 +0200 Subject: [PATCH 36/59] python 3.10? --- .github/workflows/test-linux-libs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 362e31f8365..803902ab60a 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -273,7 +273,7 @@ jobs: fi set -euo pipefail - export PYTHON_VERSION="3.11" + export PYTHON_VERSION="3.10" export CU_VERSION="12.1" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" From 4f6ffc324e66e1cc2eaf6c2645f1605569d7b0be Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 9 Apr 2024 14:16:35 +0200 Subject: [PATCH 37/59] less explicit deps --- .github/unittest/linux_libs/scripts_meltingpot/environment.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index f79a9bec745..87a2537898b 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -5,7 +5,6 @@ dependencies: - pip - pip: - cloudpickle - - numpy - torch - pytest - pytest-cov @@ -14,6 +13,4 @@ dependencies: - pytest-rerunfailures - pytest-error-for-skips - expecttest - - pyyaml - - dmlab2d - dm-meltingpot From 18d7bce2610d6baab5f98d66dc482bc8d25527b3 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Tue, 9 Apr 2024 14:21:34 +0200 Subject: [PATCH 38/59] attempt different import --- torchrl/envs/libs/meltingpot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/meltingpot.py b/torchrl/envs/libs/meltingpot.py index 2b797e06c03..446b3dac292 100644 --- a/torchrl/envs/libs/meltingpot.py +++ b/torchrl/envs/libs/meltingpot.py @@ -26,9 +26,9 @@ def _get_envs(): if not _has_meltingpot: raise 
ImportError("meltingpot is not installed in your virtual environment.") - import meltingpot + from meltingpot.configs import substrates as substrate_configs - return list(meltingpot.substrate.SUBSTRATES) + return list(substrate_configs.SUBSTRATES) def _filter_global_state_from_dict(obs_dict: Dict, world: bool) -> Dict: # noqa From 8d095315007790cdb95e0cbaec44edc9eebc40e6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 14:38:49 +0200 Subject: [PATCH 39/59] amend --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index 5805c8e1468..470808b2206 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -58,5 +58,7 @@ python setup.py develop python -c "import torchrl" # Install meltingpot from git +LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/lab2d/tags" | jq -r '.[0].name') +pip3 install git+https://github.com/google-deepmind/lab2d@${LATEST_TAG} LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From 6c8ff357beae5cdf17434e0ab4ae4a29ca1b2947 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 14:44:45 +0200 Subject: [PATCH 40/59] amend --- .github/workflows/test-linux-libs.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 803902ab60a..4c79bcd9767 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -257,6 +257,7 @@ jobs: unittests-meltingpot: uses: pytorch/test-infra/.github/workflows/linux_job.yml@main + if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 
'Environments') }} with: repository: pytorch/rl runner: "linux.g5.4xlarge.nvidia.gpu" From 0b4c7d56c62b52faae84efd56fd4b2e4cbe2fc83 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:04:53 +0200 Subject: [PATCH 41/59] amend --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index 470808b2206..c89966f5d2f 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -57,6 +57,7 @@ printf "* Installing torchrl\n" python setup.py develop python -c "import torchrl" +conda install conda-forge::jq -y # Install meltingpot from git LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/lab2d/tags" | jq -r '.[0].name') pip3 install git+https://github.com/google-deepmind/lab2d@${LATEST_TAG} From 883b04451b9f3f3daafd45fb22ab35f773160dfa Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:21:11 +0200 Subject: [PATCH 42/59] amend --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index c89966f5d2f..01ff0596339 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -59,7 +59,6 @@ python -c "import torchrl" conda install conda-forge::jq -y # Install meltingpot from git -LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/lab2d/tags" | jq -r '.[0].name') -pip3 install git+https://github.com/google-deepmind/lab2d@${LATEST_TAG} +pip3 install git+https://github.com/google-deepmind/lab2d LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') pip3 install 
git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From b8d7527675621bc9539910e8c57e8a160f36cfc8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 9 Apr 2024 15:41:38 +0100 Subject: [PATCH 43/59] Update .github/unittest/linux_libs/scripts_meltingpot/install.sh --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index 01ff0596339..2966c3246c3 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -59,6 +59,6 @@ python -c "import torchrl" conda install conda-forge::jq -y # Install meltingpot from git -pip3 install git+https://github.com/google-deepmind/lab2d +pip3 install dmlab2d LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From 73873924f48c07adc1ea682d2210b905edecc953 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 10 Apr 2024 10:19:56 +0200 Subject: [PATCH 44/59] attempt 3.9 python --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 8 ++++---- .github/workflows/test-linux-libs.yml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index 2966c3246c3..b48637dad36 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -57,8 +57,8 @@ printf "* Installing torchrl\n" python setup.py develop python -c "import torchrl" -conda install conda-forge::jq -y +#conda install conda-forge::jq -y # Install meltingpot from git -pip3 install dmlab2d -LATEST_TAG=$(curl 
"https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') -pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} +#pip3 install dmlab2d +#LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') +#pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 4c79bcd9767..71984fbe369 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -274,7 +274,7 @@ jobs: fi set -euo pipefail - export PYTHON_VERSION="3.10" + export PYTHON_VERSION="3.9" export CU_VERSION="12.1" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly" From b51f688042e34c7c60762d45dab9817caf9b4e46 Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 10 Apr 2024 10:20:02 +0200 Subject: [PATCH 45/59] attempt 3.9 python --- .github/unittest/linux_libs/scripts_meltingpot/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index 31e8c01407f..87a2537898b 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -13,3 +13,4 @@ dependencies: - pytest-rerunfailures - pytest-error-for-skips - expecttest + - dm-meltingpot From 179f280190bbe239a6f685935792b9b7574a3db3 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 13:00:28 +0200 Subject: [PATCH 46/59] amend --- .../unittest/linux_libs/scripts_meltingpot/environment.yml | 1 - .github/unittest/linux_libs/scripts_meltingpot/install.sh | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml index 
87a2537898b..31e8c01407f 100644 --- a/.github/unittest/linux_libs/scripts_meltingpot/environment.yml +++ b/.github/unittest/linux_libs/scripts_meltingpot/environment.yml @@ -13,4 +13,3 @@ dependencies: - pytest-rerunfailures - pytest-error-for-skips - expecttest - - dm-meltingpot diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index b48637dad36..6f80a601947 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -60,5 +60,5 @@ python -c "import torchrl" #conda install conda-forge::jq -y # Install meltingpot from git #pip3 install dmlab2d -#LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') -#pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} +LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') +pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From 4eef4d50329192411c2e9e0456a01de5a456dbb0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 14:54:42 +0200 Subject: [PATCH 47/59] amend --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index 6f80a601947..c72e35023ef 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -57,7 +57,7 @@ printf "* Installing torchrl\n" python setup.py develop python -c "import torchrl" -#conda install conda-forge::jq -y +conda install conda-forge::jq -y # Install meltingpot from git #pip3 install dmlab2d LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') From 
bf1ff61a70ec6619607ca94c64ebafe3ee0c51ff Mon Sep 17 00:00:00 2001 From: Matteo Bettini Date: Wed, 10 Apr 2024 15:11:45 +0200 Subject: [PATCH 48/59] new attempt --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index c72e35023ef..e3a067a428b 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -61,4 +61,4 @@ conda install conda-forge::jq -y # Install meltingpot from git #pip3 install dmlab2d LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') -pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} +pip3 install dm-meltingpot From 87937ca6104a9cda3f400eb7249c99f780aacd1c Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 14:54:57 +0100 Subject: [PATCH 49/59] Update .github/unittest/linux_libs/scripts_meltingpot/install.sh --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index e3a067a428b..c72e35023ef 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -61,4 +61,4 @@ conda install conda-forge::jq -y # Install meltingpot from git #pip3 install dmlab2d LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') -pip3 install dm-meltingpot +pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From ea445056ef00d847c96cbf841de2541f245a229e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 15:01:04 +0100 Subject: [PATCH 50/59] Update 
.github/unittest/linux_libs/scripts_meltingpot/install.sh --- .github/unittest/linux_libs/scripts_meltingpot/install.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/install.sh b/.github/unittest/linux_libs/scripts_meltingpot/install.sh index c72e35023ef..7c13fbf54d1 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/install.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/install.sh @@ -61,4 +61,7 @@ conda install conda-forge::jq -y # Install meltingpot from git #pip3 install dmlab2d LATEST_TAG=$(curl "https://api.github.com/repos/google-deepmind/meltingpot/tags" | jq -r '.[0].name') + +echo $(ldd --version) + pip3 install git+https://github.com/google-deepmind/meltingpot@${LATEST_TAG} From c04f821bc5715e81782b9307b004cf12c57af3ec Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 15:45:47 +0100 Subject: [PATCH 51/59] Update .github/workflows/test-linux-libs.yml --- .github/workflows/test-linux-libs.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 71984fbe369..0b48032418c 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -262,7 +262,8 @@ jobs: repository: pytorch/rl runner: "linux.g5.4xlarge.nvidia.gpu" gpu-arch-type: cuda - gpu-arch-version: "11.7" + gpu-arch-version: "12.1" + docker-image: pytorch/manylinux-cuda121 timeout: 120 script: | if [[ "${{ github.ref }}" =~ release/* ]]; then From ea1367b64e7e7225513a3b3402ae8008ca08268e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 16:04:03 +0100 Subject: [PATCH 52/59] Update setup_env.sh --- .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh index dc524958e5e..e1da162ef3c 
100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -14,6 +14,10 @@ root_dir="$(git rev-parse --show-toplevel)" conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" +echo $(rpm -qa | grep glibc) +yum update +yum install glibc.x86_64 + cd "${root_dir}" case "$(uname -s)" in From 8376e47c8483e52955bdb5f782e684b550653bea Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 16:27:10 +0100 Subject: [PATCH 53/59] Update .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh --- .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh index e1da162ef3c..c6468a8d3cb 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -15,8 +15,8 @@ conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" echo $(rpm -qa | grep glibc) -yum update -yum install glibc.x86_64 +yum update -y +yum install glibc.x86_64 -y cd "${root_dir}" From 99b5ff674e4df9fb46ac39b0d0cb938cff1d2a60 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 10 Apr 2024 17:00:15 +0100 Subject: [PATCH 54/59] Apply suggestions from code review --- .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh | 5 +---- .github/workflows/test-linux-libs.yml | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh index c6468a8d3cb..85b9975198d 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -5,7 +5,7 @@ # # Do not install PyTorch and torchvision here, otherwise they also get cached. 
-set -e +set -e -v this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" @@ -14,9 +14,6 @@ root_dir="$(git rev-parse --show-toplevel)" conda_dir="${root_dir}/conda" env_dir="${root_dir}/env" -echo $(rpm -qa | grep glibc) -yum update -y -yum install glibc.x86_64 -y cd "${root_dir}" diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 0b48032418c..1023440642c 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -263,7 +263,7 @@ jobs: runner: "linux.g5.4xlarge.nvidia.gpu" gpu-arch-type: cuda gpu-arch-version: "12.1" - docker-image: pytorch/manylinux-cuda121 + docker-image: cuda:12.4.1-runtime-ubuntu22.04 timeout: 120 script: | if [[ "${{ github.ref }}" =~ release/* ]]; then From 6196e5540d7fd31cc35be478b1b5dc28d9604030 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 18:01:48 +0100 Subject: [PATCH 55/59] Update test-linux-libs.yml --- .github/workflows/test-linux-libs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 1023440642c..fdbe84147c9 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -263,7 +263,7 @@ jobs: runner: "linux.g5.4xlarge.nvidia.gpu" gpu-arch-type: cuda gpu-arch-version: "12.1" - docker-image: cuda:12.4.1-runtime-ubuntu22.04 + docker-image: "nvidia/cuda:12.4.1-runtime-ubuntu22.04" timeout: 120 script: | if [[ "${{ github.ref }}" =~ release/* ]]; then From 418fd3f6806f7e4523c9cf3f44243a1afc2280c0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 18:13:15 +0100 Subject: [PATCH 56/59] Update setup_env.sh --- .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh 
b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh index 85b9975198d..68babd68986 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -7,6 +7,11 @@ set -e -v +apt-get install -y git wget g++ gcc + +apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libglvnd0 libgl1 libglx0 libegl1 libgles2 +apt-get upgrade -y libstdc++6 + this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" # Avoid error: "fatal: unsafe repository" git config --global --add safe.directory '*' From bcc7da26ee29dd41b7bd3a28cc29b05b4d0b408a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 18:19:31 +0100 Subject: [PATCH 57/59] Update setup_env.sh --- .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh index 68babd68986..717c1978a7c 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -7,9 +7,8 @@ set -e -v -apt-get install -y git wget g++ gcc - -apt-get install -y libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libglvnd0 libgl1 libglx0 libegl1 libgles2 +apt-get update && apt-get upgrade -y +apt-get install -y git wget g++ gcc libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libglvnd0 libgl1 libglx0 libegl1 libgles2 apt-get upgrade -y libstdc++6 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" From 74ca60fb831f391126f05a63b341ff3743c3176f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 20:45:01 +0100 Subject: [PATCH 58/59] Update .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh --- .github/unittest/linux_libs/scripts_meltingpot/setup_env.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh index 717c1978a7c..b342c57f099 100755 --- a/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_meltingpot/setup_env.sh @@ -8,7 +8,7 @@ set -e -v apt-get update && apt-get upgrade -y -apt-get install -y git wget g++ gcc libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libglvnd0 libgl1 libglx0 libegl1 libgles2 +apt-get install -y git wget g++ gcc libglfw3 libgl1-mesa-glx libosmesa6 libglew-dev libglvnd0 libgl1 libglx0 libegl1 libgles2 curl apt-get upgrade -y libstdc++6 this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" From 4953bac63139e5d32ef7f74fc5939a390b54471a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 17 Apr 2024 21:06:35 +0100 Subject: [PATCH 59/59] amend --- .github/workflows/test-linux-libs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-linux-libs.yml b/.github/workflows/test-linux-libs.yml index 71984fbe369..f9c7ab6304d 100644 --- a/.github/workflows/test-linux-libs.yml +++ b/.github/workflows/test-linux-libs.yml @@ -274,7 +274,7 @@ jobs: fi set -euo pipefail - export PYTHON_VERSION="3.9" + export PYTHON_VERSION="3.11" export CU_VERSION="12.1" export TAR_OPTIONS="--no-same-owner" export UPLOAD_CHANNEL="nightly"