From 7ee8790db3f8175add7f21c13bd60881c9f66f2e Mon Sep 17 00:00:00 2001 From: Sriram S K Date: Sun, 7 Apr 2024 17:49:12 +0530 Subject: [PATCH 01/14] make robohive tests actually run --- test/test_libs.py | 58 ++++++++++++++++++++++------------------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index c72fcc562e6..7cfad1712a9 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3329,7 +3329,7 @@ def test_collector(self, task, parallel): break -@pytest.mark.skipif(not _has_robohive, reason="SMACv2 not found") +@pytest.mark.skipif(not _has_robohive, reason="RoboHive not found") class TestRoboHive: # unfortunately we must import robohive to get the available envs # and this import will occur whenever pytest is run on this file. @@ -3339,37 +3339,33 @@ class TestRoboHive: # Locally these imports can be annoying, especially given the amount of # stuff printed by robohive. @pytest.mark.parametrize("from_pixels", [True, False]) + @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) @set_gym_backend("gym") - def test_robohive(self, from_pixels): - for envname in RoboHiveEnv.available_envs: - try: - if any( - substr in envname - for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") - ): - torchrl_logger.info("not testing envs with prebuilt rendering") - return - if "Adroit" in envname: - torchrl_logger.info("tcdm are broken") - return - try: - env = RoboHiveEnv(envname) - except AttributeError as err: - if "'MjData' object has no attribute 'get_body_xipos'" in str(err): - torchrl_logger.info("tcdm are broken") - return - else: - raise err - if ( - from_pixels - and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 - ): - torchrl_logger.info("no camera") - return - check_env_specs(env) - except Exception as err: - raise RuntimeError(f"Test with robohive end {envname} failed.") from err - + def test_robohive(self, envname, from_pixels): + if any( + substr in envname + for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") + ): + torchrl_logger.info("not testing envs with prebuilt rendering") + return + if "Adroit" in envname: + torchrl_logger.info("tcdm are broken") + return + if ( + from_pixels + and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 + ): + torchrl_logger.info("no camera") + return + try: + env = RoboHiveEnv(envname, from_pixels=from_pixels) + except AttributeError as err: + if "'MjData' object has no attribute 'get_body_xipos'" in str(err): + torchrl_logger.info("tcdm are broken") + return + else: + raise err + check_env_specs(env) @pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") class TestSmacv2: From ba59958309487b7cf6ea61156c82c443b70dbcca Mon Sep 17 00:00:00 2001 From: Sriram S K Date: Sun, 7 Apr 2024 21:05:52 +0530 Subject: [PATCH 02/14] lint changes --- test/test_libs.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 7cfad1712a9..321aa9135e7 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3339,22 +3339,16 @@ class TestRoboHive: # Locally these imports can be annoying, especially given the amount of # stuff printed by robohive. @pytest.mark.parametrize("from_pixels", [True, False]) - @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) + @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs[:10]) @set_gym_backend("gym") def test_robohive(self, envname, from_pixels): - if any( - substr in envname - for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") - ): + if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): torchrl_logger.info("not testing envs with prebuilt rendering") return if "Adroit" in envname: torchrl_logger.info("tcdm are broken") return - if ( - from_pixels - and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 - ): + if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0: torchrl_logger.info("no camera") return try: @@ -3367,6 +3361,7 @@ def test_robohive(self, envname, from_pixels): raise err check_env_specs(env) + @pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") class TestSmacv2: def test_env_procedural(self): From b8a2e4d444314aa0518a92ca43429fbea7795bf8 Mon Sep 17 00:00:00 2001 From: Sriram S K Date: Mon, 8 Apr 2024 20:31:12 +0530 Subject: [PATCH 03/14] hack to fix camera name for DKitty envs --- test/test_libs.py | 2 +- torchrl/envs/libs/robohive.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_libs.py b/test/test_libs.py index 321aa9135e7..68d8f2b4f86 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3339,7 +3339,7 @@ class TestRoboHive: # Locally these imports can be annoying, especially given the amount of # stuff printed by robohive. @pytest.mark.parametrize("from_pixels", [True, False]) - @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs[:10]) + @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) @set_gym_backend("gym") def test_robohive(self, envname, from_pixels): if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 4d4998eb721..66e176226a9 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -233,6 +233,7 @@ def register_visual_env(cls, env_name, cams): if not len(cams): raise RuntimeError("Cannot create a visual envs without cameras.") + cams = [i.replace("A:", "A_") for i in cams] cams = sorted(cams) new_env_name = "-".join([cam[:-3] for cam in cams] + [env_name]) if new_env_name in cls.env_list: From 1414632f1196ed8927767af44026aa96158a5f04 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 09:47:14 +0100 Subject: [PATCH 04/14] amend --- test/test_libs.py | 4 ++-- torchrl/envs/libs/robohive.py | 37 ++++++++++++----------------------- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 321aa9135e7..f834c550f6a 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3338,10 +3338,10 @@ class TestRoboHive: # In the CI, robohive should not coexist with other libs so that's fine. # Locally these imports can be annoying, especially given the amount of # stuff printed by robohive. - @pytest.mark.parametrize("from_pixels", [True, False]) + @pytest.mark.parametrize("from_pixels", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs[:10]) - @set_gym_backend("gym") def test_robohive(self, envname, from_pixels): + torchrl_logger.info(f"{envname}-{from_pixels}") if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): torchrl_logger.info("not testing envs with prebuilt rendering") return diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 4d4998eb721..4b0454b3f1c 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -14,10 +14,18 @@ from tensordict import make_tensordict, TensorDict from torchrl._utils import implement_for from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec -from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv +from torchrl.envs.libs.gym import ( + _AsyncMeta, + _gym_to_torchrl_spec_transform, + gym_backend, + GymEnv, +) from torchrl.envs.utils import _classproperty, make_composite_from_td -_has_gym = importlib.util.find_spec("gym") is not None +_has_gym = ( + importlib.util.find_spec("gym") is not None + or importlib.util.find_spec("gymnasium") is not None +) _has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym if _has_robohive: @@ -126,7 +134,7 @@ def CURR_DIR(cls): def available_envs(cls): if not _has_robohive: return [] - RoboHiveEnv.register_envs() + cls.register_envs() return cls.env_list @classmethod @@ -143,25 +151,6 @@ def register_envs(cls): if not len(robohive_envs): raise RuntimeError("did not load any environment.") - @implement_for( - "gymnasium", - ) # make sure gym 0.13 is installed, otherwise raise an exception - def _build_env(self, *args, **kwargs): # noqa: F811 - raise NotImplementedError( - "Your gym version is too recent, RoboHiveEnv is only compatible with gym==0.13." - ) - - @implement_for( - "gym", "0.14", None - ) # make sure gym 0.13 is installed, otherwise raise an exception - def _build_env(self, *args, **kwargs): # noqa: F811 - raise NotImplementedError( - "Your gym version is too recent, RoboHiveEnv is only compatible with gym 0.13." - ) - - @implement_for( - "gym", None, "0.14" - ) # make sure gym 0.13 is installed, otherwise raise an exception def _build_env( # noqa: F811 self, env_name: str, @@ -382,8 +371,6 @@ def to(self, *args, **kwargs): @classmethod def get_available_cams(cls, env_name): - import gym - - env = gym.make(env_name) + env = gym_backend().make(env_name) cams = [env.sim.model.id2name(ic, 7) for ic in range(env.sim.model.ncam)] return cams From 19a5953f9af7a0c30111ab6c44b933c308abd82e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 09:47:37 +0100 Subject: [PATCH 05/14] amend --- torchrl/envs/libs/robohive.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index cb957336ce9..16af12a0264 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -12,7 +12,6 @@ import numpy as np import torch from tensordict import make_tensordict, TensorDict -from torchrl._utils import implement_for from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec from torchrl.envs.libs.gym import ( _AsyncMeta, From 7e956ac422bb5a8a25d7c4499d894cc4a6f55b89 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 09:49:51 +0100 Subject: [PATCH 06/14] amend --- .github/unittest/linux_libs/scripts_robohive/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_robohive/environment.yml b/.github/unittest/linux_libs/scripts_robohive/environment.yml index 705b8522c92..cff88245d1e 100644 --- a/.github/unittest/linux_libs/scripts_robohive/environment.yml +++ b/.github/unittest/linux_libs/scripts_robohive/environment.yml @@ -6,7 +6,7 @@ dependencies: - protobuf - pip: # Initial version is required to install Atari ROMS in setup_env.sh - - gym==0.13 + - gymnasium - hypothesis - future - cloudpickle From 6e5e28405afb04c1f238507b25a7580701ea1485 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 10:59:26 +0100 Subject: [PATCH 07/14] amend --- .github/unittest/linux_libs/scripts_robohive/setup_env.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh index 66a496eb3ff..73339a25a64 100755 --- a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -78,4 +78,6 @@ conda env update --file "${this_dir}/environment.yml" --prune conda install conda-forge::ffmpeg -y -pip install git+https://github.com/vikashplus/robohive@main +pip install robohive +# make sure only gymnasium is available +pip uninstall gym -y From f51dd4c6fa08855da4247027a47b7d755ba35b4a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 11:16:15 +0100 Subject: [PATCH 08/14] amend --- .github/unittest/linux_libs/scripts_robohive/setup_env.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh index 73339a25a64..f98087074f5 100755 --- a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -79,5 +79,6 @@ conda env update --file "${this_dir}/environment.yml" --prune conda install conda-forge::ffmpeg -y pip install robohive + # make sure only gymnasium is available -pip uninstall gym -y +# pip uninstall gym -y From 07c0c11c4d3ae59838a121e336efff9023b6a9d7 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 13:00:24 +0100 Subject: [PATCH 09/14] amend --- .github/unittest/linux_libs/scripts_robohive/setup_env.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh index f98087074f5..38e6d350354 100755 --- a/.github/unittest/linux_libs/scripts_robohive/setup_env.sh +++ b/.github/unittest/linux_libs/scripts_robohive/setup_env.sh @@ -80,5 +80,7 @@ conda install conda-forge::ffmpeg -y pip install robohive +python3 -m robohive_init + # make sure only gymnasium is available # pip uninstall gym -y From 2a244007bf60269a3c1ed0a8fc68122c8aec9677 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 13:18:04 +0100 Subject: [PATCH 10/14] amend --- test/test_libs.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index a14e2626b14..5055fa51688 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -7,6 +7,8 @@ from contextlib import nullcontext from pathlib import Path +import robohive + from torchrl._utils import logger as torchrl_logger from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay @@ -3341,25 +3343,26 @@ class TestRoboHive: @pytest.mark.parametrize("from_pixels", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) def test_robohive(self, envname, from_pixels): - torchrl_logger.info(f"{envname}-{from_pixels}") - if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): - torchrl_logger.info("not testing envs with prebuilt rendering") - return - if "Adroit" in envname: - torchrl_logger.info("tcdm are broken") - return - if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0: - torchrl_logger.info("no camera") - return - try: - env = RoboHiveEnv(envname, from_pixels=from_pixels) - except AttributeError as err: - if "'MjData' object has no attribute 'get_body_xipos'" in str(err): + with set_gym_backend("gymnasium"): + torchrl_logger.info(f"{envname}-{from_pixels}") + if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): + torchrl_logger.info("not testing envs with prebuilt rendering") + return + if "Adroit" in envname: torchrl_logger.info("tcdm are broken") return - else: - raise err - check_env_specs(env) + if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0: + torchrl_logger.info("no camera") + return + try: + env = RoboHiveEnv(envname, from_pixels=from_pixels) + except AttributeError as err: + if "'MjData' object has no attribute 'get_body_xipos'" in str(err): + torchrl_logger.info("tcdm are broken") + return + else: + raise err + check_env_specs(env) @pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found") From 4d655f70e54bd6e43e14195dc29ef9e4849a8aba Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 13:33:59 +0100 Subject: [PATCH 11/14] amend --- torchrl/envs/libs/robohive.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 16af12a0264..60403e2ec47 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -273,14 +273,16 @@ def get_obs(): _dict = {} _dict.update(get_obs()) _dict["action"] = action = env.action_space.sample() - _, r, d, _ = env.step(action) + _, r, trunc, term, done, _ = self._output_transform(env.step(action)) _dict[("next", "reward")] = r.reshape(1) _dict[("next", "done")] = [1] + _dict[("next", "terminated")] = [1] + _dict[("next", "truncated")] = [1] _dict["next"] = get_obs() rollout[i] = TensorDict(_dict, []) observation_spec = make_composite_from_td( - rollout.get("next").exclude("done", "reward")[0] + rollout.get("next").exclude("done", "reward", "terminated", "truncated")[0] ) self.observation_spec = observation_spec From 073e6557d6b8bbffffc5874a56c866efff81a92d Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 13:52:51 +0100 Subject: [PATCH 12/14] amend --- test/test_libs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_libs.py b/test/test_libs.py index 30d667167e5..8e9f84a0bd6 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -3412,6 +3412,7 @@ def test_robohive(self, envname, from_pixels): return else: raise err + torchrl_logger.info("rollout", env.rollout(4)) check_env_specs(env) From b4a28fac45c9a99b17e225364fccbfa9bf6511cc Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 14:10:47 +0100 Subject: [PATCH 13/14] amend --- test/test_libs.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 8e9f84a0bd6..81c0bfecd2e 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -7,8 +7,6 @@ from contextlib import nullcontext from pathlib import Path -import robohive - from torchrl._utils import logger as torchrl_logger from torchrl.data.datasets.gen_dgrl import GenDGRLExperienceReplay From 772e8935ea95f1027f3b3adbaeac55677a9c142a Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 15 Apr 2024 15:15:46 +0100 Subject: [PATCH 14/14] amend --- test/test_libs.py | 24 ++++++++++++---- torchrl/envs/common.py | 3 +- torchrl/envs/libs/robohive.py | 54 ++++++++++++++++++++++++++--------- 3 files changed, 60 insertions(+), 21 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index 81c0bfecd2e..3f14a63850c 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -45,7 +45,12 @@ rollout_consistency_assertion, ) from packaging import version -from tensordict import assert_allclose_td, LazyStackedTensorDict, TensorDict +from tensordict import ( + assert_allclose_td, + is_tensor_collection, + LazyStackedTensorDict, + TensorDict, +) from tensordict.nn import ( ProbabilisticTensorDictModule, TensorDictModule, @@ -3386,20 +3391,24 @@ class TestRoboHive: # The other option would be not to use parametrize but that also # means less informative error trace stacks. # In the CI, robohive should not coexist with other libs so that's fine. - # Locally these imports can be annoying, especially given the amount of - # stuff printed by robohive. + # Robohive logging behaviour can be controlled via ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT @pytest.mark.parametrize("from_pixels", [False, True]) @pytest.mark.parametrize("envname", RoboHiveEnv.available_envs) def test_robohive(self, envname, from_pixels): with set_gym_backend("gymnasium"): torchrl_logger.info(f"{envname}-{from_pixels}") - if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")): + if any( + substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s") + ): torchrl_logger.info("not testing envs with prebuilt rendering") return if "Adroit" in envname: torchrl_logger.info("tcdm are broken") return - if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0: + if ( + from_pixels + and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0 + ): torchrl_logger.info("no camera") return try: @@ -3410,7 +3419,10 @@ def test_robohive(self, envname, from_pixels): return else: raise err - torchrl_logger.info("rollout", env.rollout(4)) + # Make sure that the stack is dense + for val in env.rollout(4).values(True): + if is_tensor_collection(val): + assert not isinstance(val, LazyStackedTensorDict) check_env_specs(env) diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index c536d308db8..6e724bf9245 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2781,6 +2781,7 @@ def empty_cache(self): they may change during the execution of the code (eg, when adding a transform). """ + self.__dict__["_step_mdp_value"] = None self.__dict__["_reward_keys"] = None self.__dict__["_done_keys"] = None self.__dict__["_action_keys"] = None @@ -2813,7 +2814,7 @@ def reset_keys(self) -> List[NestedKey]: @property def _filtered_reset_keys(self): - """Returns the only the effective reset keys, discarding nested resets if they're not being used.""" + """Returns only the effective reset keys, discarding nested resets if they're not being used.""" reset_keys = self.reset_keys result = [] diff --git a/torchrl/envs/libs/robohive.py b/torchrl/envs/libs/robohive.py index 60403e2ec47..6c39a30e2a5 100644 --- a/torchrl/envs/libs/robohive.py +++ b/torchrl/envs/libs/robohive.py @@ -11,7 +11,7 @@ import numpy as np import torch -from tensordict import make_tensordict, TensorDict +from tensordict import TensorDict from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec from torchrl.envs.libs.gym import ( _AsyncMeta, @@ -214,6 +214,15 @@ def _build_env( # noqa: F811 self.set_info_dict_reader(self.read_info) return env + def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821 + out = super()._make_specs(env=env, batch_size=batch_size) + self.env.reset() + *_, info = self.env.step(self.env.action_space.sample()) + info = self.read_info(info, TensorDict({}, [])) + info = info.get("info") + self.observation_spec["observation"] = make_composite_from_td(info) + return out + @classmethod def register_visual_env(cls, env_name, cams): with set_directory(cls.CURR_DIR): @@ -221,9 +230,9 @@ def register_visual_env(cls, env_name, cams): if not len(cams): raise RuntimeError("Cannot create a visual envs without cameras.") - cams = [i.replace("A:", "A_") for i in cams] cams = sorted(cams) - new_env_name = "-".join([cam[:-3] for cam in cams] + [env_name]) + cams_rep = [i.replace("A:", "A_") for i in cams] + new_env_name = "-".join([cam[:-3] for cam in cams_rep] + [env_name]) if new_env_name in cls.env_list: return new_env_name visual_keys = [f"rgb:{c}:224x224:2d" for c in cams] @@ -298,6 +307,7 @@ def get_obs(): rollout = rollout[..., 0] spec = make_composite_from_td(rollout) self.observation_spec.update(spec) + self.empty_cache() def set_from_pixels(self, from_pixels: bool) -> None: """Sets the from_pixels attribute to an existing environment. @@ -344,18 +354,34 @@ def read_obs(self, observation): def read_info(self, info, tensordict_out): out = {} - for key, value in info.items(): - if key in ("obs_dict", "done", "reward", *self._env.obs_keys, "act"): - continue - if isinstance(value, dict): - value = {key: _val for key, _val in value.items() if _val is not None} - value = make_tensordict(value, batch_size=[]) - if value is not None: - out[key] = value - tensordict_out.update(out) - tensordict_out.update( - tensordict_out.apply(lambda x: x.reshape((1,)) if not x.shape else x) + if not info: + info_spec = self.observation_spec.get("info", None) + if info_spec is None: + return tensordict_out + tensordict_out.set("info", info_spec.zero()) + return tensordict_out + out = ( + TensorDict(info, []) + .filter_non_tensor_data() + .exclude("obs_dict", "done", "reward", *self._env.obs_keys, "act") ) + if "info" in self.observation_spec.keys(): + info_spec = self.observation_spec["info"] + + def func(name, x): + spec = info_spec.get(name, None) + if spec is None: + return None + return x.reshape(info_spec[name].shape) + + out.update(out.named_apply(func, nested_keys=True, filter_empty=True)) + else: + out.update( + out.apply( + lambda x: x.reshape((1,)) if not x.shape else x, filter_empty=True + ) + ) + tensordict_out.set("info", out) return tensordict_out def _init_env(self):