Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix RoboHiveEnv tests #2062

Merged
merged 17 commits into from
Apr 15, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- protobuf
- pip:
# Initial version is required to install Atari ROMS in setup_env.sh
- gym==0.13
- gymnasium
- hypothesis
- future
- cloudpickle
Expand Down
5 changes: 4 additions & 1 deletion .github/unittest/linux_libs/scripts_robohive/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,7 @@ conda env update --file "${this_dir}/environment.yml" --prune

conda install conda-forge::ffmpeg -y

pip install git+https://github.com/vikashplus/robohive@main
pip install robohive

# make sure only gymnasium is available
# pip uninstall gym -y
55 changes: 23 additions & 32 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3329,7 +3329,7 @@ def test_collector(self, task, parallel):
break


@pytest.mark.skipif(not _has_robohive, reason="SMACv2 not found")
@pytest.mark.skipif(not _has_robohive, reason="RoboHive not found")
class TestRoboHive:
# unfortunately we must import robohive to get the available envs
# and this import will occur whenever pytest is run on this file.
Expand All @@ -3338,37 +3338,28 @@ class TestRoboHive:
# In the CI, robohive should not coexist with other libs so that's fine.
# Locally these imports can be annoying, especially given the amount of

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RoboHive's prints can be contorled using an env variable

ROBOHIVE_VERBOSITY=ALL/INFO/(WARN)/ERROR/ONCE/ALWAYS/SILENT

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure
The goal of this comment is for anyone running the tests and seeing loads of messages from imports. That could be a bit surprising and annoying. We can add that there's a strategy to silence them. I guess one could argue that importing a library should not print anything by default though...

# stuff printed by robohive.
@pytest.mark.parametrize("from_pixels", [True, False])
@set_gym_backend("gym")
def test_robohive(self, from_pixels):
for envname in RoboHiveEnv.available_envs:
try:
if any(
substr in envname
for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")
):
torchrl_logger.info("not testing envs with prebuilt rendering")
return
if "Adroit" in envname:
torchrl_logger.info("tcdm are broken")
return
try:
env = RoboHiveEnv(envname)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
return
else:
raise err
if (
from_pixels
and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0
):
torchrl_logger.info("no camera")
return
check_env_specs(env)
except Exception as err:
raise RuntimeError(f"Test with robohive end {envname} failed.") from err
@pytest.mark.parametrize("from_pixels", [False, True])
@pytest.mark.parametrize("envname", RoboHiveEnv.available_envs)
def test_robohive(self, envname, from_pixels):
torchrl_logger.info(f"{envname}-{from_pixels}")
if any(substr in envname for substr in ("_vr3m", "_vrrl", "_vflat", "_vvc1s")):
torchrl_logger.info("not testing envs with prebuilt rendering")
return
if "Adroit" in envname:
torchrl_logger.info("tcdm are broken")
return
if from_pixels and len(RoboHiveEnv.get_available_cams(env_name=envname)) == 0:
torchrl_logger.info("no camera")
return
try:
env = RoboHiveEnv(envname, from_pixels=from_pixels)
except AttributeError as err:
if "'MjData' object has no attribute 'get_body_xipos'" in str(err):
torchrl_logger.info("tcdm are broken")
return
else:
raise err
check_env_specs(env)


@pytest.mark.skipif(not _has_smacv2, reason="SMACv2 not found")
Expand Down
39 changes: 13 additions & 26 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@
import numpy as np
import torch
from tensordict import make_tensordict, TensorDict
from torchrl._utils import implement_for
from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec
from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv
from torchrl.envs.libs.gym import (
_AsyncMeta,
_gym_to_torchrl_spec_transform,
gym_backend,
GymEnv,
)
from torchrl.envs.utils import _classproperty, make_composite_from_td

_has_gym = importlib.util.find_spec("gym") is not None
_has_gym = (
importlib.util.find_spec("gym") is not None
or importlib.util.find_spec("gymnasium") is not None
)
_has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym

if _has_robohive:
Expand Down Expand Up @@ -126,7 +133,7 @@ def CURR_DIR(cls):
def available_envs(cls):
if not _has_robohive:
return []
RoboHiveEnv.register_envs()
cls.register_envs()
return cls.env_list

@classmethod
Expand All @@ -143,25 +150,6 @@ def register_envs(cls):
if not len(robohive_envs):
raise RuntimeError("did not load any environment.")

@implement_for(
"gymnasium",
) # make sure gym 0.13 is installed, otherwise raise an exception
def _build_env(self, *args, **kwargs): # noqa: F811
raise NotImplementedError(
"Your gym version is too recent, RoboHiveEnv is only compatible with gym==0.13."
)

@implement_for(
"gym", "0.14", None
) # make sure gym 0.13 is installed, otherwise raise an exception
def _build_env(self, *args, **kwargs): # noqa: F811
raise NotImplementedError(
"Your gym version is too recent, RoboHiveEnv is only compatible with gym 0.13."
)

@implement_for(
"gym", None, "0.14"
) # make sure gym 0.13 is installed, otherwise raise an exception
def _build_env( # noqa: F811
self,
env_name: str,
Expand Down Expand Up @@ -233,6 +221,7 @@ def register_visual_env(cls, env_name, cams):

if not len(cams):
raise RuntimeError("Cannot create a visual envs without cameras.")
cams = [i.replace("A:", "A_") for i in cams]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if this is something you think should be changed at RoboHive's end.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see your comment above. I believe A:trackingZ is a valid string for MuJoCo as well as RoboHive.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would encourage using environment (and derived camera etc) names that are conventional and can be used for serialization. I can imagine someone using the camera name within the experiment name and use the experiment name to save things on disk

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have name a note of this for the next release.

cams = sorted(cams)
new_env_name = "-".join([cam[:-3] for cam in cams] + [env_name])
if new_env_name in cls.env_list:
Expand Down Expand Up @@ -382,8 +371,6 @@ def to(self, *args, **kwargs):

@classmethod
def get_available_cams(cls, env_name):
import gym

env = gym.make(env_name)
env = gym_backend().make(env_name)
cams = [env.sim.model.id2name(ic, 7) for ic in range(env.sim.model.ncam)]
return cams
Loading