Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making all tests pass with Implement_for functionality. Enables support most gym/gymnasium versions #119

Merged
merged 11 commits into from
Nov 27, 2023
37 changes: 26 additions & 11 deletions robohive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,63 +16,78 @@


# Register RoboHive Envs
import gym
_current_gym_envs = gym.envs.registration.registry.env_specs.keys()
from robohive.utils.import_utils import import_gym; gym = import_gym()
from robohive.utils.implement_for import implement_for

#TODO: check versions
@implement_for("gym", None, "0.24")
def gym_registry_specs():
return gym.envs.registry.env_specs

@implement_for("gym", "0.24", None)
def gym_registry_specs():
return gym.envs.registry

@implement_for("gymnasium")
def gym_registry_specs():
return gym.envs.registry

_current_gym_envs = gym_registry_specs().keys()
_current_gym_envs = set(_current_gym_envs)
robohive_env_suite = set()

# Register Arms Suite
import robohive.envs.arms # noqa
robohive_arm_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_arm_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_arm_suite = set(sorted(robohive_arm_suite))
robohive_env_suite = robohive_env_suite | robohive_arm_suite

# Register MyoBase Suite
import robohive.envs.myo.myobase # noqa
robohive_myobase_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_myobase_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_myobase_suite
robohive_myobase_suite = sorted(robohive_myobase_suite)

# Register MyoChal Suite
import robohive.envs.myo.myochallenge # noqa
robohive_myochal_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_myochal_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_myochal_suite
robohive_myochal_suite = sorted(robohive_myochal_suite)

# Register MyoDM Suite
import robohive.envs.myo.myodm # noqa
robohive_myodm_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_myodm_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_myodm_suite
robohive_myodm_suite = sorted(robohive_myodm_suite)

# Register FM suite
import robohive.envs.fm # noqa
robohive_fm_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_fm_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_fm_suite
robohive_fm_suite = sorted(robohive_fm_suite)

# Register Hands Suite
import robohive.envs.hands # noqa
# import robohive.envs.tcdm # noqa # WIP
robohive_hand_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_hand_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_hand_suite
robohive_hand_suite = sorted(robohive_hand_suite)

# Register Claw suite
import robohive.envs.claws # noqa
robohive_claw_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_claw_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_claw_suite
robohive_claw_suite = sorted(robohive_claw_suite)

# Register Multi-task Suite
import robohive.envs.multi_task # noqa
robohive_multitask_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_multitask_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_multitask_suite
robohive_multitask_suite = sorted(robohive_multitask_suite)

# Register Locomotion Suite
import robohive.envs.quadrupeds # noqa
robohive_quad_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs
robohive_quad_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs
robohive_env_suite = robohive_env_suite | robohive_quad_suite
robohive_quad_suite = sorted(robohive_quad_suite)

Expand Down
4 changes: 3 additions & 1 deletion robohive/envs/arms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================= """

from gym.envs.registration import register
# from gym.envs.registration import register
from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register

from robohive.envs.env_variants import register_env_variant
import os
curr_dir = os.path.dirname(os.path.abspath(__file__))
Expand Down
12 changes: 9 additions & 3 deletions robohive/envs/arms/pick_place_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
================================================= """

import collections
import gym
from robohive.utils.import_utils import import_gym; gym = import_gym()
import numpy as np

from robohive.envs import env_base
Expand Down Expand Up @@ -70,6 +70,12 @@ def _setup(self,
self.randomize = randomize
self.geom_sizes = geom_sizes

# Save body init pos
self.init_body_pos = {}
for body in ["obj0", "obj1", "obj2"]:
bid = self.sim.model.body_name2id(body)
self.init_body_pos[body] = self.sim.model.body_pos[bid].copy()

super()._setup(obs_keys=obs_keys,
weighted_reward_keys=weighted_reward_keys,
reward_mode=reward_mode,
Expand Down Expand Up @@ -119,12 +125,12 @@ def reset(self):
# object shapes and locations
for body in ["obj0", "obj1", "obj2"]:
bid = self.sim.model.body_name2id(body)
self.sim.model.body_pos[bid] += self.np_random.uniform(low=[-.010, -.010, -.010], high=[-.010, -.010, -.010])# random pos
self.sim.model.body_pos[bid] = self.init_body_pos[body] + self.np_random.uniform(low=[-.010, -.010, -.010], high=[-.010, -.010, -.010])# random pos
self.sim.model.body_quat[bid] = euler2quat(self.np_random.uniform(low=(-np.pi/2, -np.pi/2, -np.pi/2), high=(np.pi/2, np.pi/2, np.pi/2)) ) # random quat

for gid in range(self.sim.model.body_geomnum[bid]):
gid+=self.sim.model.body_geomadr[bid]
self.sim.model.geom_type[gid]=self.np_random.randint(low=2, high=7) # random shape
self.sim.model.geom_type[gid]=self.np_random.choice([2,3,4,5,6,7]) # random shape
self.sim.model.geom_size[gid]=self.np_random.uniform(low=self.geom_sizes['low'], high=self.geom_sizes['high']) # random size
self.sim.model.geom_pos[gid]=self.np_random.uniform(low=-1*self.sim.model.geom_size[gid], high=self.sim.model.geom_size[gid]) # random pos
self.sim.model.geom_quat[gid]=euler2quat(self.np_random.uniform(low=(-np.pi/2, -np.pi/2, -np.pi/2), high=(np.pi/2, np.pi/2, np.pi/2)) ) # random quat
Expand Down
3 changes: 2 additions & 1 deletion robohive/envs/arms/push_base_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
================================================= """

import collections
import gym
# from robohive.utils.import_utils import import_gym; gym = import_gym()
from robohive.utils.import_utils import import_gym; gym = import_gym()
import numpy as np

from robohive.envs import env_base
Expand Down
3 changes: 2 additions & 1 deletion robohive/envs/arms/reach_base_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
================================================= """

import collections
import gym
from robohive.utils.import_utils import import_gym; gym = import_gym()
# from robohive.utils.import_utils import import_gym; gym = import_gym()
import numpy as np

from robohive.envs import env_base
Expand Down
4 changes: 2 additions & 2 deletions robohive/envs/claws/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================= """

from gym.envs.registration import register
from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register
import os
curr_dir = os.path.dirname(os.path.abspath(__file__))
from robohive.envs.env_variants import register_env_variant
Expand Down Expand Up @@ -38,7 +38,7 @@
'model_path': curr_dir+'/trifinger/trifinger_reorient.xml',
'object_site_name': "object",
'target_site_name': "target",
'target_xyz_range': {'high':[.05, .05, 0.9], 'low':[-.05, -.05, 0.99]},
'target_xyz_range': {'high':[.05, .05, 0.99], 'low':[-.05, -.05, 0.9]},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is unrelated?

'target_euler_range': {'high':[1, 1, 1], 'low':[-1, -1, -1]}
}
)
Expand Down
2 changes: 1 addition & 1 deletion robohive/envs/claws/reorient_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
================================================= """

import collections
import gym
from robohive.utils.import_utils import import_gym; gym = import_gym()
import numpy as np

from robohive.envs import env_base
Expand Down
46 changes: 36 additions & 10 deletions robohive/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
================================================= """

import gym
# TODO: find how to make this compatible with gymnasium. Maybe a global variable that indicates what to use as backend?
# from robohive.utils.import_utils import import_gym; gym = import_gym()
from robohive.utils.import_utils import import_gym; gym = import_gym()
import numpy as np
import os
import time as timer

from robohive.envs.obs_vec_dict import ObsVecDict
from robohive.utils import tensor_utils
from robohive.robot.robot import Robot
from robohive.utils.implement_for import implement_for
from robohive.utils.prompt_utils import prompt, Prompt
import skvideo.io
from sys import platform
Expand Down Expand Up @@ -130,7 +133,7 @@ def _setup(self,
self._setup_rgb_encoders(self.visual_keys, device=None)

# reset to get the env ready
observation, _reward, done, _info = self.step(np.zeros(self.sim.model.nu))
observation, _reward, done, *_, _info = self.step(np.zeros(self.sim.model.nu))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could decorate this too but it's also ok if you don't need the truncated oc

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems to work for now. Lets revisit in the future incase this isn't sufficient

# Question: Should we replace above with following? Its specially helpful for hardware as it forces a env reset before continuing, without which the hardware will make a big jump from its position to the position asked by step.
# observation = self.reset()
assert not done, "Check initialization. Simulation starts in a done state."
Expand Down Expand Up @@ -263,8 +266,23 @@ def step(self, a, **kwargs):
render_cbk=self.mj_render if self.mujoco_render_frames else None)
return self.forward(**kwargs)

@implement_for("gym", None, "0.24")
def forward(self, **kwargs):
return self._forward(**kwargs)

@implement_for("gym", "0.24", None)
def forward(self, **kwargs):
obs, reward, done, info = self._forward(**kwargs)
terminal = done
return obs, reward, terminal, False, info

@implement_for("gymnasium")
def forward(self, **kwargs):
obs, reward, done, info = self._forward(**kwargs)
terminal = done
return obs, reward, terminal, False, info

def _forward(self, **kwargs):
"""
Forward propagate env to recover env details
Returns current obs(t), rwd(t), done(t), info(t)
Expand Down Expand Up @@ -476,7 +494,7 @@ def get_input_seed(self):
return self.input_seed


def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
"""
Reset the environment
Default implemention provided. Override if env needs custom reset
Expand All @@ -485,11 +503,19 @@ def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel
self.robot.reset(qpos, qvel, **kwargs)
return self.get_obs()
@implement_for("gym", None, "0.26")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
@implement_for("gym", "0.26", None)
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}
@implement_for("gymnasium")
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {}


@property
def _step(self, a):
return self.step(a)
# @property
# def _step(self, a):
# return self.step(a)


@property
Expand Down Expand Up @@ -702,7 +728,7 @@ def examine_policy(self,
ep_rwd = 0.0
while t < horizon and done is False:
a = policy.get_action(o)[0] if mode == 'exploration' else policy.get_action(o)[1]['evaluation']
next_o, rwd, done, env_info = self.step(a)
next_o, rwd, done, *_, env_info = self.step(a)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

ep_rwd += rwd
# render offscreen visuals
if render =='offscreen':
Expand Down Expand Up @@ -794,7 +820,7 @@ def examine_policy_new(self,
ep_rwd = 0.0

# Rollout --------------------------------
obs, rwd, done, env_info = self.forward(update_exteroception=True) # t=0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

obs, rwd, done, *_, env_info = self.forward(update_exteroception=True) # t=0
while t < horizon and done is False:

# print(t, t*self.dt, self.time, t*self.dt-self.time)
Expand Down Expand Up @@ -825,7 +851,7 @@ def examine_policy_new(self,


# step env using actions from t=>t+1 ----------------------
obs, rwd, done, env_info = self.step(act, update_exteroception=True)
obs, rwd, done, *_, env_info = self.step(act, update_exteroception=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

t = t+1
ep_rwd += rwd

Expand Down
Loading