-
Notifications
You must be signed in to change notification settings - Fork 328
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] Fix RoboHiveEnv tests #2062
Changes from 10 commits
7ee8790
ba59958
6cddb66
b8a2e4d
1414632
2801a30
19a5953
7e956ac
6e5e284
f51dd4c
07c0c11
2a24400
11f7780
4d655f7
073e655
b4a28fa
772e893
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,12 +12,19 @@ | |
import numpy as np | ||
import torch | ||
from tensordict import make_tensordict, TensorDict | ||
from torchrl._utils import implement_for | ||
from torchrl.data.tensor_specs import UnboundedContinuousTensorSpec | ||
from torchrl.envs.libs.gym import _AsyncMeta, _gym_to_torchrl_spec_transform, GymEnv | ||
from torchrl.envs.libs.gym import ( | ||
_AsyncMeta, | ||
_gym_to_torchrl_spec_transform, | ||
gym_backend, | ||
GymEnv, | ||
) | ||
from torchrl.envs.utils import _classproperty, make_composite_from_td | ||
|
||
_has_gym = importlib.util.find_spec("gym") is not None | ||
_has_gym = ( | ||
importlib.util.find_spec("gym") is not None | ||
or importlib.util.find_spec("gymnasium") is not None | ||
) | ||
_has_robohive = importlib.util.find_spec("robohive") is not None and _has_gym | ||
|
||
if _has_robohive: | ||
|
@@ -126,7 +133,7 @@ def CURR_DIR(cls): | |
def available_envs(cls): | ||
if not _has_robohive: | ||
return [] | ||
RoboHiveEnv.register_envs() | ||
cls.register_envs() | ||
return cls.env_list | ||
|
||
@classmethod | ||
|
@@ -143,25 +150,6 @@ def register_envs(cls): | |
if not len(robohive_envs): | ||
raise RuntimeError("did not load any environment.") | ||
|
||
@implement_for( | ||
"gymnasium", | ||
) # make sure gym 0.13 is installed, otherwise raise an exception | ||
def _build_env(self, *args, **kwargs): # noqa: F811 | ||
raise NotImplementedError( | ||
"Your gym version is too recent, RoboHiveEnv is only compatible with gym==0.13." | ||
) | ||
|
||
@implement_for( | ||
"gym", "0.14", None | ||
) # make sure gym 0.13 is installed, otherwise raise an exception | ||
def _build_env(self, *args, **kwargs): # noqa: F811 | ||
raise NotImplementedError( | ||
"Your gym version is too recent, RoboHiveEnv is only compatible with gym 0.13." | ||
) | ||
|
||
@implement_for( | ||
"gym", None, "0.14" | ||
) # make sure gym 0.13 is installed, otherwise raise an exception | ||
def _build_env( # noqa: F811 | ||
self, | ||
env_name: str, | ||
|
@@ -233,6 +221,7 @@ def register_visual_env(cls, env_name, cams): | |
|
||
if not len(cams): | ||
raise RuntimeError("Cannot create a visual envs without cameras.") | ||
cams = [i.replace("A:", "A_") for i in cams] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me know if this is something you think should be changed at RoboHive's end. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see your comment above. I believe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would encourage using environment (and derived camera etc) names that are conventional and can be used for serialization. I can imagine someone using the camera name within the experiment name and use the experiment name to save things on disk There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have name a note of this for the next release. |
||
cams = sorted(cams) | ||
new_env_name = "-".join([cam[:-3] for cam in cams] + [env_name]) | ||
if new_env_name in cls.env_list: | ||
|
@@ -382,8 +371,6 @@ def to(self, *args, **kwargs): | |
|
||
@classmethod | ||
def get_available_cams(cls, env_name): | ||
import gym | ||
|
||
env = gym.make(env_name) | ||
env = gym_backend().make(env_name) | ||
cams = [env.sim.model.id2name(ic, 7) for ic in range(env.sim.model.ncam)] | ||
return cams |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
RoboHive's prints can be contorled using an env variable
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sure
The goal of this comment is for anyone running the tests and seeing loads of messages from imports. That could be a bit surprising and annoying. We can add that there's a strategy to silence them. I guess one could argue that importing a library should not print anything by default though...