diff --git a/robohive/__init__.py b/robohive/__init__.py index 49a71a1a..845af553 100644 --- a/robohive/__init__.py +++ b/robohive/__init__.py @@ -16,63 +16,78 @@ # Register RoboHive Envs -import gym -_current_gym_envs = gym.envs.registration.registry.env_specs.keys() +from robohive.utils.import_utils import import_gym; gym = import_gym() +from robohive.utils.implement_for import implement_for + +#TODO: check versions +@implement_for("gym", None, "0.24") +def gym_registry_specs(): + return gym.envs.registry.env_specs + +@implement_for("gym", "0.24", None) +def gym_registry_specs(): + return gym.envs.registry + +@implement_for("gymnasium") +def gym_registry_specs(): + return gym.envs.registry + +_current_gym_envs = gym_registry_specs().keys() _current_gym_envs = set(_current_gym_envs) robohive_env_suite = set() # Register Arms Suite import robohive.envs.arms # noqa -robohive_arm_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_arm_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_arm_suite = set(sorted(robohive_arm_suite)) robohive_env_suite = robohive_env_suite | robohive_arm_suite # Register MyoBase Suite import robohive.envs.myo.myobase # noqa -robohive_myobase_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_myobase_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_myobase_suite robohive_myobase_suite = sorted(robohive_myobase_suite) # Register MyoChal Suite import robohive.envs.myo.myochallenge # noqa -robohive_myochal_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_myochal_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_myochal_suite robohive_myochal_suite = sorted(robohive_myochal_suite) # Register MyoDM Suite import robohive.envs.myo.myodm # noqa -robohive_myodm_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_myodm_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_myodm_suite robohive_myodm_suite = sorted(robohive_myodm_suite) # Register FM suite import robohive.envs.fm # noqa -robohive_fm_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_fm_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_fm_suite robohive_fm_suite = sorted(robohive_fm_suite) # Register Hands Suite import robohive.envs.hands # noqa # import robohive.envs.tcdm # noqa # WIP -robohive_hand_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_hand_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_hand_suite robohive_hand_suite = sorted(robohive_hand_suite) # Register Claw suite import robohive.envs.claws # noqa -robohive_claw_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_claw_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_claw_suite robohive_claw_suite = sorted(robohive_claw_suite) # Register Multi-task Suite import robohive.envs.multi_task # noqa -robohive_multitask_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_multitask_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_multitask_suite robohive_multitask_suite = sorted(robohive_multitask_suite) # Register Locomotion Suite import robohive.envs.quadrupeds # noqa -robohive_quad_suite = set(gym.envs.registration.registry.env_specs.keys())-robohive_env_suite-_current_gym_envs +robohive_quad_suite = set(gym_registry_specs().keys())-robohive_env_suite-_current_gym_envs robohive_env_suite = robohive_env_suite | robohive_quad_suite robohive_quad_suite = sorted(robohive_quad_suite) diff --git a/robohive/envs/arms/__init__.py b/robohive/envs/arms/__init__.py index c479017b..01c4949a 100644 --- a/robohive/envs/arms/__init__.py +++ b/robohive/envs/arms/__init__.py @@ -5,7 +5,9 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +# from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + from robohive.envs.env_variants import register_env_variant import os curr_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/arms/pick_place_v0.py b/robohive/envs/arms/pick_place_v0.py index 298dc187..f7fb1fd4 100644 --- a/robohive/envs/arms/pick_place_v0.py +++ b/robohive/envs/arms/pick_place_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base @@ -70,6 +70,12 @@ def _setup(self, self.randomize = randomize self.geom_sizes = geom_sizes + # Save body init pos + self.init_body_pos = {} + for body in ["obj0", "obj1", "obj2"]: + bid = self.sim.model.body_name2id(body) + self.init_body_pos[body] = self.sim.model.body_pos[bid].copy() + super()._setup(obs_keys=obs_keys, weighted_reward_keys=weighted_reward_keys, reward_mode=reward_mode, @@ -119,12 +125,12 @@ def reset(self): # object shapes and locations for body in ["obj0", "obj1", "obj2"]: bid = self.sim.model.body_name2id(body) - self.sim.model.body_pos[bid] += self.np_random.uniform(low=[-.010, -.010, -.010], high=[-.010, -.010, -.010])# random pos + self.sim.model.body_pos[bid] = self.init_body_pos[body] + self.np_random.uniform(low=[-.010, -.010, -.010], high=[-.010, -.010, -.010])# random pos self.sim.model.body_quat[bid] = euler2quat(self.np_random.uniform(low=(-np.pi/2, -np.pi/2, -np.pi/2), high=(np.pi/2, np.pi/2, np.pi/2)) ) # random quat for gid in range(self.sim.model.body_geomnum[bid]): gid+=self.sim.model.body_geomadr[bid] - self.sim.model.geom_type[gid]=self.np_random.randint(low=2, high=7) # random shape + self.sim.model.geom_type[gid]=self.np_random.choice([2,3,4,5,6,7]) # random shape self.sim.model.geom_size[gid]=self.np_random.uniform(low=self.geom_sizes['low'], high=self.geom_sizes['high']) # random size self.sim.model.geom_pos[gid]=self.np_random.uniform(low=-1*self.sim.model.geom_size[gid], high=self.sim.model.geom_size[gid]) # random pos self.sim.model.geom_quat[gid]=euler2quat(self.np_random.uniform(low=(-np.pi/2, -np.pi/2, -np.pi/2), high=(np.pi/2, np.pi/2, np.pi/2)) ) # random quat diff --git a/robohive/envs/arms/push_base_v0.py b/robohive/envs/arms/push_base_v0.py index fe0ba699..f7ff8957 100644 --- a/robohive/envs/arms/push_base_v0.py +++ b/robohive/envs/arms/push_base_v0.py @@ -6,7 +6,8 @@ ================================================= """ import collections -import gym +# from robohive.utils.import_utils import import_gym; gym = import_gym() +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/arms/reach_base_v0.py b/robohive/envs/arms/reach_base_v0.py index 6a78e031..00c6944c 100644 --- a/robohive/envs/arms/reach_base_v0.py +++ b/robohive/envs/arms/reach_base_v0.py @@ -6,7 +6,8 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() +# from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/claws/__init__.py b/robohive/envs/claws/__init__.py index 1b78c421..4eb65403 100644 --- a/robohive/envs/claws/__init__.py +++ b/robohive/envs/claws/__init__.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register import os curr_dir = os.path.dirname(os.path.abspath(__file__)) from robohive.envs.env_variants import register_env_variant @@ -38,7 +38,7 @@ 'model_path': curr_dir+'/trifinger/trifinger_reorient.xml', 'object_site_name': "object", 'target_site_name': "target", - 'target_xyz_range': {'high':[.05, .05, 0.9], 'low':[-.05, -.05, 0.99]}, + 'target_xyz_range': {'high':[.05, .05, 0.99], 'low':[-.05, -.05, 0.9]}, 'target_euler_range': {'high':[1, 1, 1], 'low':[-1, -1, -1]} } ) diff --git a/robohive/envs/claws/reorient_v0.py b/robohive/envs/claws/reorient_v0.py index ceabeef5..d165faaf 100644 --- a/robohive/envs/claws/reorient_v0.py +++ b/robohive/envs/claws/reorient_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/env_base.py b/robohive/envs/env_base.py index 9bff0158..4910f973 100644 --- a/robohive/envs/env_base.py +++ b/robohive/envs/env_base.py @@ -5,7 +5,9 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +# TODO: find how to make this compatible with gymnasium. Maybe a global variable that indicates what to use as backend? +# from robohive.utils.import_utils import import_gym; gym = import_gym() +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np import os import time as timer @@ -13,6 +15,7 @@ from robohive.envs.obs_vec_dict import ObsVecDict from robohive.utils import tensor_utils from robohive.robot.robot import Robot +from robohive.utils.implement_for import implement_for from robohive.utils.prompt_utils import prompt, Prompt import skvideo.io from sys import platform @@ -130,7 +133,7 @@ def _setup(self, self._setup_rgb_encoders(self.visual_keys, device=None) # reset to get the env ready - observation, _reward, done, _info = self.step(np.zeros(self.sim.model.nu)) + observation, _reward, done, *_, _info = self.step(np.zeros(self.sim.model.nu)) # Question: Should we replace above with following? Its specially helpful for hardware as it forces a env reset before continuing, without which the hardware will make a big jump from its position to the position asked by step. # observation = self.reset() assert not done, "Check initialization. Simulation starts in a done state." @@ -263,8 +266,23 @@ def step(self, a, **kwargs): render_cbk=self.mj_render if self.mujoco_render_frames else None) return self.forward(**kwargs) + @implement_for("gym", None, "0.24") + def forward(self, **kwargs): + return self._forward(**kwargs) + + @implement_for("gym", "0.24", None) + def forward(self, **kwargs): + obs, reward, done, info = self._forward(**kwargs) + terminal = done + return obs, reward, terminal, False, info + @implement_for("gymnasium") def forward(self, **kwargs): + obs, reward, done, info = self._forward(**kwargs) + terminal = done + return obs, reward, terminal, False, info + + def _forward(self, **kwargs): """ Forward propagate env to recover env details Returns current obs(t), rwd(t), done(t), info(t) @@ -476,7 +494,7 @@ def get_input_seed(self): return self.input_seed - def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs): """ Reset the environment Default implemention provided. Override if env needs custom reset @@ -485,11 +503,19 @@ def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel self.robot.reset(qpos, qvel, **kwargs) return self.get_obs() + @implement_for("gym", None, "0.26") + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) + @implement_for("gym", "0.26", None) + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {} + @implement_for("gymnasium") + def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): + return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {} - - @property - def _step(self, a): - return self.step(a) + # @property + # def _step(self, a): + # return self.step(a) @property @@ -702,7 +728,7 @@ def examine_policy(self, ep_rwd = 0.0 while t < horizon and done is False: a = policy.get_action(o)[0] if mode == 'exploration' else policy.get_action(o)[1]['evaluation'] - next_o, rwd, done, env_info = self.step(a) + next_o, rwd, done, *_, env_info = self.step(a) ep_rwd += rwd # render offscreen visuals if render =='offscreen': @@ -794,7 +820,7 @@ def examine_policy_new(self, ep_rwd = 0.0 # Rollout -------------------------------- - obs, rwd, done, env_info = self.forward(update_exteroception=True) # t=0 + obs, rwd, done, *_, env_info = self.forward(update_exteroception=True) # t=0 while t < horizon and done is False: # print(t, t*self.dt, self.time, t*self.dt-self.time) @@ -825,7 +851,7 @@ def examine_policy_new(self, # step env using actions from t=>t+1 ---------------------- - obs, rwd, done, env_info = self.step(act, update_exteroception=True) + obs, rwd, done, *_, env_info = self.step(act, update_exteroception=True) t = t+1 ep_rwd += rwd diff --git a/robohive/envs/env_variants.py b/robohive/envs/env_variants.py index 07bb3374..f675bc54 100644 --- a/robohive/envs/env_variants.py +++ b/robohive/envs/env_variants.py @@ -5,12 +5,65 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register import collections from copy import deepcopy from flatten_dict import flatten, unflatten +from robohive.utils.implement_for import implement_for + +#TODO: check versions +@implement_for("gym", None, "0.24") +def gym_registry_specs(): + return gym.envs.registry.env_specs + +@implement_for("gym", "0.24", None) +def gym_registry_specs(): + return gym.envs.registry + +@implement_for("gymnasium") +def gym_registry_specs(): + return gym.envs.registry + +# TODO: move to within the function? +@implement_for("gym", None, "0.24") +def _update_env_spec_kwarg(env_variant_specs, variants, override_keys): + env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys) + return variants_update_keyval_str + +@implement_for("gym", "0.24", None) +def _update_env_spec_kwarg(env_variant_specs, variants, override_keys): + env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys) + return variants_update_keyval_str + +@implement_for("gymnasium") +def _update_env_spec_kwarg(env_variant_specs, variants, override_keys): + env_variant_specs.kwargs, variants_update_keyval_str = update_dict(env_variant_specs.kwargs, variants, override_keys=override_keys) + return variants_update_keyval_str + +@implement_for("gym", None, "0.24") +def _entry_point(env_variant_specs): + return env_variant_specs._entry_point + +@implement_for("gym", "0.24", None) +def _entry_point(env_variant_specs): + return env_variant_specs.entry_point + +@implement_for("gymnasium") +def _entry_point(env_variant_specs): + return env_variant_specs.entry_point + +@implement_for("gym", None, "0.24") +def _kwargs(env_variant_specs): + return env_variant_specs._kwargs + +@implement_for("gym", "0.24", None) +def _kwargs(env_variant_specs): + return env_variant_specs.kwargs + +@implement_for("gymnasium") +def _kwargs(env_variant_specs): + return env_variant_specs.kwargs # Update base_dict using update_dict def update_dict(base_dict:dict, update_dict:dict, override_keys:list=None): @@ -47,10 +100,10 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals """ # check if the base env is registered - assert env_id in gym.envs.registry.env_specs.keys(), "ERROR: {} not found in env registry".format(env_id) + assert env_id in gym_registry_specs().keys(), "ERROR: {} not found in env registry".format(env_id) # recover the specs of the existing env - env_variant_specs = deepcopy(gym.envs.registry.env_specs[env_id]) + env_variant_specs = deepcopy(gym_registry_specs()[env_id]) env_variant_id = env_variant_specs.id[:-3] # update horizon if requested @@ -60,16 +113,16 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals del variants['max_episode_steps'] # merge specs._kwargs with variants - env_variant_specs._kwargs, variants_update_keyval_str = update_dict(env_variant_specs._kwargs, variants, override_keys=override_keys) + variants_update_keyval_str = _update_env_spec_kwarg(env_variant_specs, variants, override_keys) env_variant_id += variants_update_keyval_str # finalize name and register env env_variant_specs.id = env_variant_id+env_variant_specs.id[-3:] if variant_id is None else variant_id register( id=env_variant_specs.id, - entry_point=env_variant_specs._entry_point, + entry_point=_entry_point(env_variant_specs), max_episode_steps=env_variant_specs.max_episode_steps, - kwargs=env_variant_specs._kwargs + kwargs=_kwargs(env_variant_specs) ) if not silent: print("Registered a new env-variant:", env_variant_specs.id) @@ -96,11 +149,11 @@ def register_env_variant(env_id:str, variants:dict, variant_id=None, silent=Fals # Test variant print("Base-env kwargs: ") - pprint.pprint(gym.envs.registry.env_specs[base_env_name]._kwargs) + pprint.pprint(gym_registry_specs()[base_env_name]._kwargs) print("Env-variant kwargs: ") - pprint.pprint(gym.envs.registry.env_specs[variant_env_name]._kwargs) + pprint.pprint(gym_registry_specs()[variant_env_name]._kwargs) print("Env-variant (with override) kwargs: ") - pprint.pprint(gym.envs.registry.env_specs[variant_overide_env_name]._kwargs) + pprint.pprint(gym_registry_specs()[variant_overide_env_name]._kwargs) # Test one of the newly minted env env = gym.make(variant_env_name) diff --git a/robohive/envs/fm/__init__.py b/robohive/envs/fm/__init__.py index 16dd0562..9311e4cc 100644 --- a/robohive/envs/fm/__init__.py +++ b/robohive/envs/fm/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + import numpy as np import os curr_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/fm/franka_ee_pose_v0.py b/robohive/envs/fm/franka_ee_pose_v0.py index 896ad2f7..9e2f5749 100644 --- a/robohive/envs/fm/franka_ee_pose_v0.py +++ b/robohive/envs/fm/franka_ee_pose_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym(); import numpy as np from robohive.envs import env_base from robohive.physics.sim_scene import SimScene diff --git a/robohive/envs/fm/franka_robotiq_data_v0.py b/robohive/envs/fm/franka_robotiq_data_v0.py index bf7d1463..738357f2 100644 --- a/robohive/envs/fm/franka_robotiq_data_v0.py +++ b/robohive/envs/fm/franka_robotiq_data_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym(); import numpy as np from robohive.envs import env_base from robohive.physics.sim_scene import SimScene diff --git a/robohive/envs/hands/__init__.py b/robohive/envs/hands/__init__.py index 11001702..cf3df605 100644 --- a/robohive/envs/hands/__init__.py +++ b/robohive/envs/hands/__init__.py @@ -5,7 +5,8 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + from robohive.envs.env_variants import register_env_variant import os curr_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/hands/baoding_v1.py b/robohive/envs/hands/baoding_v1.py index 1e643916..c0fbac9d 100644 --- a/robohive/envs/hands/baoding_v1.py +++ b/robohive/envs/hands/baoding_v1.py @@ -7,7 +7,7 @@ import collections import enum -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base @@ -272,7 +272,7 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): # populate go-to task with a target location if self.which_task==Task.MOVE_TO_LOCATION: - goal_pos = np.random.randint(4) + goal_pos = self.np_random.choice([0,1,2,3]) desired_position = [] if goal_pos==0: desired_position.append(0.01) #x diff --git a/robohive/envs/hands/door_v1.py b/robohive/envs/hands/door_v1.py index 877a84de..9aff3455 100644 --- a/robohive/envs/hands/door_v1.py +++ b/robohive/envs/hands/door_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/hands/hammer_v1.py b/robohive/envs/hands/hammer_v1.py index a7607b1d..06b862fd 100644 --- a/robohive/envs/hands/hammer_v1.py +++ b/robohive/envs/hands/hammer_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.utils.quat_math import * diff --git a/robohive/envs/hands/pen_v1.py b/robohive/envs/hands/pen_v1.py index 2a4bcbcb..de2509d4 100644 --- a/robohive/envs/hands/pen_v1.py +++ b/robohive/envs/hands/pen_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.utils.vector_math import calculate_cosine diff --git a/robohive/envs/hands/relocate_v1.py b/robohive/envs/hands/relocate_v1.py index 5c19b326..7513954e 100644 --- a/robohive/envs/hands/relocate_v1.py +++ b/robohive/envs/hands/relocate_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/multi_task/common/franka_appliance_v1.py b/robohive/envs/multi_task/common/franka_appliance_v1.py index 14fc2d1d..eb01784c 100644 --- a/robohive/envs/multi_task/common/franka_appliance_v1.py +++ b/robohive/envs/multi_task/common/franka_appliance_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.utils.quat_math import euler2quat diff --git a/robohive/envs/multi_task/common/franka_kitchen_v2.py b/robohive/envs/multi_task/common/franka_kitchen_v2.py index 50217e5b..b7b32fcc 100644 --- a/robohive/envs/multi_task/common/franka_kitchen_v2.py +++ b/robohive/envs/multi_task/common/franka_kitchen_v2.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.multi_task.multi_task_base_v1 import KitchenBase class FrankaKitchen(KitchenBase): diff --git a/robohive/envs/multi_task/multi_task_base_v1.py b/robohive/envs/multi_task/multi_task_base_v1.py index 3b4a443f..2c15632b 100644 --- a/robohive/envs/multi_task/multi_task_base_v1.py +++ b/robohive/envs/multi_task/multi_task_base_v1.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/multi_task/substeps1/__init__.py b/robohive/envs/multi_task/substeps1/__init__.py index 839a8f43..faa66487 100644 --- a/robohive/envs/multi_task/substeps1/__init__.py +++ b/robohive/envs/multi_task/substeps1/__init__.py @@ -6,7 +6,8 @@ ================================================= """ import os -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + CURR_DIR = os.path.dirname(os.path.abspath(__file__)) diff --git a/robohive/envs/multi_task/substeps1/franka_kitchen.py b/robohive/envs/multi_task/substeps1/franka_kitchen.py index 6305681e..bd4f3610 100644 --- a/robohive/envs/multi_task/substeps1/franka_kitchen.py +++ b/robohive/envs/multi_task/substeps1/franka_kitchen.py @@ -6,7 +6,8 @@ ================================================= """ import os -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + from robohive.envs.multi_task.common.franka_kitchen_v2 import FrankaKitchen import copy diff --git a/robohive/envs/multi_task/utils/parse_demos.py b/robohive/envs/multi_task/utils/parse_demos.py index 9bed1914..2e1763bf 100644 --- a/robohive/envs/multi_task/utils/parse_demos.py +++ b/robohive/envs/multi_task/utils/parse_demos.py @@ -23,7 +23,7 @@ import robohive import time as timer # import skvideo.io -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from tqdm import tqdm diff --git a/robohive/envs/myo/myobase/__init__.py b/robohive/envs/myo/myobase/__init__.py index 3ed59ef0..b6f505ca 100644 --- a/robohive/envs/myo/myobase/__init__.py +++ b/robohive/envs/myo/myobase/__init__.py @@ -3,7 +3,8 @@ Authors :: Vikash Kumar (vikashplus@gmail.com), Vittorio Caggiano (caggiano@gmail.com) ================================================= """ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + from robohive.envs.env_variants import register_env_variant import os @@ -299,7 +300,7 @@ def register_env_with_variants(id, entry_point, max_episode_steps, kwargs): max_episode_steps=150, kwargs={ 'model_path': curr_dir+leg_model, - 'joint_random_range': (0.2, -0.2), #range of joint randomization (jnt = init_qpos + random(range) + 'joint_random_range': (-.2, 0.2), #range of joint randomization (jnt = init_qpos + random(range) 'target_reach_range': { 'pelvis': ((-.05, -.05, 0), (0.05, 0.05, 0)), }, diff --git a/robohive/envs/myo/myobase/baoding_v1.py b/robohive/envs/myo/myobase/baoding_v1.py index 42b7163c..7a2ace77 100644 --- a/robohive/envs/myo/myobase/baoding_v1.py +++ b/robohive/envs/myo/myobase/baoding_v1.py @@ -5,7 +5,7 @@ import collections import enum -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs.myo.base_v0 import BaseV0 @@ -281,7 +281,7 @@ def create_goal_trajectory(self, time_step=.1, time_period=6): # populate go-to task with a target location (pos likely needs update) if self.which_task==Task.MOVE_TO_LOCATION: - goal_pos = np.random.randint(4) + goal_pos = self.np_random.choice([0,1,2,3]) desired_position = [] if goal_pos==0: desired_position.append(-.195) #x diff --git a/robohive/envs/myo/myobase/key_turn_v0.py b/robohive/envs/myo/myobase/key_turn_v0.py index 37bdb988..9adea748 100644 --- a/robohive/envs/myo/myobase/key_turn_v0.py +++ b/robohive/envs/myo/myobase/key_turn_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.myo.base_v0 import BaseV0 diff --git a/robohive/envs/myo/myobase/obj_hold_v0.py b/robohive/envs/myo/myobase/obj_hold_v0.py index 7aa3ae7b..155e94f8 100644 --- a/robohive/envs/myo/myobase/obj_hold_v0.py +++ b/robohive/envs/myo/myobase/obj_hold_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.myo.base_v0 import BaseV0 diff --git a/robohive/envs/myo/myobase/pen_v0.py b/robohive/envs/myo/myobase/pen_v0.py index 6a0f9f87..122171cf 100644 --- a/robohive/envs/myo/myobase/pen_v0.py +++ b/robohive/envs/myo/myobase/pen_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import euler2quat diff --git a/robohive/envs/myo/myobase/pose_v0.py b/robohive/envs/myo/myobase/pose_v0.py index 97798bac..7dae8eb2 100644 --- a/robohive/envs/myo/myobase/pose_v0.py +++ b/robohive/envs/myo/myobase/pose_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs.myo.base_v0 import BaseV0 diff --git a/robohive/envs/myo/myobase/reach_v0.py b/robohive/envs/myo/myobase/reach_v0.py index 1c9896bc..35b2cfc7 100644 --- a/robohive/envs/myo/myobase/reach_v0.py +++ b/robohive/envs/myo/myobase/reach_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs.myo.base_v0 import BaseV0 diff --git a/robohive/envs/myo/myobase/reorient_sar_v0.py b/robohive/envs/myo/myobase/reorient_sar_v0.py index a87f9339..cd33c85b 100644 --- a/robohive/envs/myo/myobase/reorient_sar_v0.py +++ b/robohive/envs/myo/myobase/reorient_sar_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import euler2quat, mulQuat, negQuat, mat2quat diff --git a/robohive/envs/myo/myobase/walk_v0.py b/robohive/envs/myo/myobase/walk_v0.py index 53349c08..fa735a3d 100644 --- a/robohive/envs/myo/myobase/walk_v0.py +++ b/robohive/envs/myo/myobase/walk_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import quat2mat diff --git a/robohive/envs/myo/myochallenge/__init__.py b/robohive/envs/myo/myochallenge/__init__.py index 7fa2d003..45d1f8d5 100644 --- a/robohive/envs/myo/myochallenge/__init__.py +++ b/robohive/envs/myo/myochallenge/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + import os curr_dir = os.path.dirname(os.path.abspath(__file__)) @@ -16,7 +17,7 @@ 'frame_skip': 5, 'pos_th': 0.1, # cover entire base of the receptacle 'rot_th': np.inf, # ignore rotation errors - 'target_xyz_range': {'high':[0.2, -.35, 0.9], 'low':[0.0, -.1, 0.9]}, + 'target_xyz_range': {'high':[0.2, -.1, 0.9], 'low':[0.0, -.35, 0.9]}, 'target_rxryrz_range': {'high':[0.0, 0.0, 0.0], 'low':[0.0, 0.0, 0.0]} } ) diff --git a/robohive/envs/myo/myochallenge/baoding_v1.py b/robohive/envs/myo/myochallenge/baoding_v1.py index a46a15fd..967e1916 100644 --- a/robohive/envs/myo/myochallenge/baoding_v1.py +++ b/robohive/envs/myo/myochallenge/baoding_v1.py @@ -5,7 +5,7 @@ import collections import enum -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs.myo.base_v0 import BaseV0 diff --git a/robohive/envs/myo/myochallenge/chasetag_v0.py b/robohive/envs/myo/myochallenge/chasetag_v0.py index 28568857..cc528725 100644 --- a/robohive/envs/myo/myochallenge/chasetag_v0.py +++ b/robohive/envs/myo/myochallenge/chasetag_v0.py @@ -4,7 +4,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np import pink import os @@ -223,7 +223,8 @@ def _populate_patches(self): self._fill_patch(i, j, terrain_type) # put special terrain only once in 20% of episodes if self.rng.uniform() < 0.2: - i, j = self.rng.randint(0, self.patches_per_side, size=2) + i = self.rng.choice(range(self.patches_per_side)) + j = self.rng.choice(range(self.patches_per_side)) self._fill_patch(i, j, SpecialTerrains.RELIEF) def _fill_patch(self, i, j, terrain_type='FLAT'): diff --git a/robohive/envs/myo/myochallenge/relocate_v0.py b/robohive/envs/myo/myochallenge/relocate_v0.py index 6f4fab00..6da947b8 100644 --- a/robohive/envs/myo/myochallenge/relocate_v0.py +++ b/robohive/envs/myo/myochallenge/relocate_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import mat2euler, euler2quat @@ -153,7 +153,7 @@ def reset(self, reset_qpos=None, reset_qvel=None): for gid in range(self.sim.model.body_geomnum[bid]): gid+=self.sim.model.body_geomadr[bid] # get geom ids # update type, size, and collision bounds - self.sim.model.geom_type[gid]=self.np_random.randint(low=2, high=7) # random shape + self.sim.model.geom_type[gid]=self.np_random.choice([2,3,4,5,6]) # random shape self.sim.model.geom_size[gid]=self.np_random.uniform(low=self.obj_geom_range['low'], high=self.obj_geom_range['high']) # random size self.sim.model.geom_aabb[gid][3:]= self.obj_geom_range['high'] # bounding box, (center, size) self.sim.model.geom_rbound[gid] = 2.0*max(self.obj_geom_range['high']) # radius of bounding sphere diff --git a/robohive/envs/myo/myochallenge/reorient_v0.py b/robohive/envs/myo/myochallenge/reorient_v0.py index fb15ec37..62380767 100644 --- a/robohive/envs/myo/myochallenge/reorient_v0.py +++ b/robohive/envs/myo/myochallenge/reorient_v0.py @@ -5,7 +5,7 @@ import collections import numpy as np -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs.myo.base_v0 import BaseV0 from robohive.utils.quat_math import mat2euler, euler2quat diff --git a/robohive/envs/myo/myodm/__init__.py b/robohive/envs/myo/myodm/__init__.py index a49237e5..b53ecebe 100644 --- a/robohive/envs/myo/myodm/__init__.py +++ b/robohive/envs/myo/myodm/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + import collections import os import numpy as np diff --git a/robohive/envs/myo/myodm/myodm_v0.py b/robohive/envs/myo/myodm/myodm_v0.py index d59feda6..062b74b0 100644 --- a/robohive/envs/myo/myodm/myodm_v0.py +++ b/robohive/envs/myo/myodm/myodm_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs import env_base from robohive.logger.reference_motion import ReferenceMotion from robohive.utils.quat_math import quat2euler, euler2quat, quatDiff2Vel, mat2quat diff --git a/robohive/envs/myo/myomimic/__init__.py b/robohive/envs/myo/myomimic/__init__.py index 88bd1c3e..b92b0744 100644 --- a/robohive/envs/myo/myomimic/__init__.py +++ b/robohive/envs/myo/myomimic/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + import collections import os import numpy as np diff --git a/robohive/envs/myo/myomimic/myomimic_v0.py b/robohive/envs/myo/myomimic/myomimic_v0.py index 4100dde1..a5736e2a 100644 --- a/robohive/envs/myo/myomimic/myomimic_v0.py +++ b/robohive/envs/myo/myomimic/myomimic_v0.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs import env_base from robohive.logger.reference_motion import ReferenceMotion from robohive.utils.quat_math import quat2euler, euler2quat, quatDiff2Vel, mat2quat diff --git a/robohive/envs/quadrupeds/__init__.py b/robohive/envs/quadrupeds/__init__.py index b56c9fb5..5a3af0ef 100644 --- a/robohive/envs/quadrupeds/__init__.py +++ b/robohive/envs/quadrupeds/__init__.py @@ -1,4 +1,5 @@ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + from robohive.envs.env_variants import register_env_variant import numpy as np import os diff --git a/robohive/envs/quadrupeds/orient_v0.py b/robohive/envs/quadrupeds/orient_v0.py index ef92f28c..cc857535 100644 --- a/robohive/envs/quadrupeds/orient_v0.py +++ b/robohive/envs/quadrupeds/orient_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/quadrupeds/stand_v0.py b/robohive/envs/quadrupeds/stand_v0.py index 071b7e17..f5597c69 100644 --- a/robohive/envs/quadrupeds/stand_v0.py +++ b/robohive/envs/quadrupeds/stand_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/quadrupeds/walk_v0.py b/robohive/envs/quadrupeds/walk_v0.py index 4a793c3b..d2d03149 100644 --- a/robohive/envs/quadrupeds/walk_v0.py +++ b/robohive/envs/quadrupeds/walk_v0.py @@ -6,7 +6,7 @@ ================================================= """ import collections -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from robohive.envs import env_base diff --git a/robohive/envs/tcdm/__init__.py b/robohive/envs/tcdm/__init__.py index 7a63ca85..c9d90979 100644 --- a/robohive/envs/tcdm/__init__.py +++ b/robohive/envs/tcdm/__init__.py @@ -5,7 +5,8 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -from gym.envs.registration import register +from robohive.utils.import_utils import import_gym; gym = import_gym(); register=gym.register + import numpy as np import os import collections diff --git a/robohive/envs/tcdm/playback_mocap.py b/robohive/envs/tcdm/playback_mocap.py index 719593fc..d6bf6435 100644 --- a/robohive/envs/tcdm/playback_mocap.py +++ b/robohive/envs/tcdm/playback_mocap.py @@ -1,5 +1,5 @@ import robohive -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np from mocap_utils import MoCapController, MoCapTask import dm_env diff --git a/robohive/envs/tcdm/track.py b/robohive/envs/tcdm/track.py index e865ee80..dda2ac92 100644 --- a/robohive/envs/tcdm/track.py +++ b/robohive/envs/tcdm/track.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.envs import env_base from robohive.logger.reference_motion import ReferenceMotion from robohive.utils.quat_math import quat2euler, euler2quat, quatDiff2Vel, mat2quat diff --git a/robohive/logger/examine_logs.py b/robohive/logger/examine_logs.py index 95df7050..1b09e876 100644 --- a/robohive/logger/examine_logs.py +++ b/robohive/logger/examine_logs.py @@ -17,7 +17,7 @@ from robohive.utils.paths_utils import plot as plotnsave_paths from robohive.utils import tensor_utils -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import click import numpy as np import time @@ -48,7 +48,7 @@ def examine_logs(env_name, rollout_path, rollout_format, mode, horizon, seed, nu # seed and load environments np.random.seed(seed) env = gym.make(env_name) if env_args==None else gym.make(env_name, **(eval(env_args))) - env = env.env + env = env.unwrapped env.seed(seed) # Start a "trace" for recording rollouts @@ -122,7 +122,7 @@ def examine_logs(env_name, rollout_path, rollout_format, mode, horizon, seed, nu trace_horizon = horizon if mode=='record' else path_data['time'].shape[0]-1 # Rollout path -------------------------------- - obs, rwd, done, env_info = env.forward(update_exteroception=include_exteroception) + obs, rwd, done, *_, env_info = env.forward(update_exteroception=include_exteroception) ep_rwd = rwd for i_step in range(trace_horizon+1): @@ -205,10 +205,10 @@ def examine_logs(env_name, rollout_path, rollout_format, mode, horizon, seed, nu env.set_env_state(path_state[i_step+1]) else: raise NotImplementedError("Settings not found") - obs, rwd, done, env_info = env.forward(update_exteroception=include_exteroception) + obs, rwd, done, *_, env_info = env.forward(update_exteroception=include_exteroception) ep_rwd += rwd elif i_step < trace_horizon: # incase last step actions (nans) can cause issues in step - obs, rwd, done, env_info = env.step(act, update_exteroception=include_exteroception) + obs, rwd, done, *_, env_info = env.step(act, update_exteroception=include_exteroception) ep_rwd += rwd # save offscreen buffers as video and clear the dataset diff --git a/robohive/logger/examine_reference.py b/robohive/logger/examine_reference.py index 1aa7c9b6..4560b882 100644 --- a/robohive/logger/examine_reference.py +++ b/robohive/logger/examine_reference.py @@ -1,5 +1,5 @@ import robohive -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() import time import click from tqdm import tqdm diff --git a/robohive/robot/robot.py b/robohive/robot/robot.py index bc1d47d2..ea266dc2 100644 --- a/robohive/robot/robot.py +++ b/robohive/robot/robot.py @@ -792,7 +792,7 @@ def __del__(self): def demo_robot(): - import gym + from robohive.utils.import_utils import import_gym; gym = import_gym() prompt("Starting Robot===================") env = gym.make('FrankaReachFixed-v0') diff --git a/robohive/tests/test_envs.py b/robohive/tests/test_envs.py index 9ce870d9..0b01253e 100644 --- a/robohive/tests/test_envs.py +++ b/robohive/tests/test_envs.py @@ -7,12 +7,26 @@ import unittest -import gym +# from robohive.utils.import_utils import import_gym; gym = import_gym() +from robohive.utils.import_utils import import_gym; gym = import_gym() import numpy as np import pickle import copy -import torch.testing +# import torch.testing import os +from flatten_dict import flatten + +def assert_close(prm1, prm2, atol=1e-05, rtol=1e-08): + if prm1 is None and prm2 is None: + return True + elif isinstance(prm1,dict) and isinstance(prm2, dict): + prm1_dict = flatten(prm1) + prm2_dict = flatten(prm2) + for key in prm1_dict.keys(): + assert_close(prm1_dict[key], prm2_dict[key], atol=atol, rtol=rtol) + else: + np.testing.assert_allclose(prm1, prm2, atol=atol, rtol=rtol) + # torch.testing.assert_close(prm1, prm2, atol=atol, rtol=rtol) class TestEnvs(unittest.TestCase): @@ -35,14 +49,15 @@ def check_env(self, environment_id, input_seed): # test init env1 = gym.make(environment_id, seed=input_seed) assert env1.get_input_seed() == input_seed - # test reset - env1.env.reset() + # test reseed and reset + env1.seed(input_seed) + reset_obs1, *_ = env1.env.reset() # step u = 0.01*np.random.uniform(low=0, high=1, size=env1.env.sim.model.nu) # small controls - obs1, rwd1, done1, infos1 = env1.env.step(u.copy()) + obs1, rwd1, done1, *_, infos1 = env1.env.step(u.copy()) infos1 = copy.deepcopy(infos1) #info points to internal variables. - proprio1 = env1.env.get_proprioception() + proprio1_t, proprio1_vec, proprio1_dict = env1.env.get_proprioception() extero1 = env1.env.get_exteroception() assert len(obs1>0) # assert len(rwd1>0) @@ -57,26 +72,31 @@ def check_env(self, environment_id, input_seed): # serialize / deserialize env ------------ env2 = pickle.loads(pickle.dumps(env1)) - # test reset - env2.reset() # test seed assert env2.get_input_seed() == input_seed assert env1.get_input_seed() == env2.get_input_seed(), {env1.get_input_seed(), env2.get_input_seed()} # check input output spaces assert env1.action_space == env2.action_space, (env1.action_space, env2.action_space) assert env1.observation_space == env2.observation_space, (env1.observation_space, env2.observation_space) + + # test reseed and reset + env2.seed(input_seed) + reset_obs2, *_ = env2.env.reset() + assert_close(reset_obs1, reset_obs2) + # step - obs2, rwd2, done2, infos2 = env2.env.step(u) + obs2, rwd2, done2, *_, infos2 = env2.env.step(u) infos2 = copy.deepcopy(infos2) - proprio2 = env2.env.get_proprioception() + proprio2_t, proprio2_vec, proprio2_dict = env2.env.get_proprioception() extero2 = env2.env.get_exteroception() - torch.testing.assert_close(obs1, obs2) - torch.testing.assert_close(proprio1, proprio2) - torch.testing.assert_close(extero1, extero2, atol=2, rtol=0.04) - torch.testing.assert_close(rwd1, rwd2) + + assert_close(obs1, obs2) + assert_close(proprio1_vec, proprio2_vec)#, f"Difference in Proprio: {proprio1_vec-proprio2_vec}" + assert_close(extero1, extero2, atol=2, rtol=0.04)#, f"Difference in Extero {extero1}, {extero2}" + assert_close(rwd1, rwd2)#, "Difference in Rewards" assert (done1==done2), (done1, done2) assert len(infos1)==len(infos2), (infos1, infos2) - torch.testing.assert_close(infos1, infos2) + assert_close(infos1, infos2) # reset env2.reset() diff --git a/robohive/tests/test_logger.py b/robohive/tests/test_logger.py index 56dbf194..d07e0010 100644 --- a/robohive/tests/test_logger.py +++ b/robohive/tests/test_logger.py @@ -6,6 +6,7 @@ from robohive.logger.examine_logs import examine_logs from robohive.utils.examine_env import main as examine_env import os +import re class TestTrace(unittest.TestCase): def teast_trace(self): @@ -24,7 +25,8 @@ def test_logs_playback(self): "--render", "none",\ "--save_paths", True,\ "--output_name", "door_test_logs"]) - log_name = result.output.strip()[-38:] + log_name_pattern = re.compile(r'Saved: (?:.+\.h5)') + log_name = log_name_pattern.search(result.output)[0][7:] result = runner.invoke(examine_logs, ["--env_name", "door-v1", \ "--rollout_path", log_name, \ diff --git a/robohive/tests/test_versions.sh b/robohive/tests/test_versions.sh new file mode 100755 index 00000000..0d49553a --- /dev/null +++ b/robohive/tests/test_versions.sh @@ -0,0 +1,23 @@ +pip uninstall -y gym +pip uninstall -y gymnasium + +echo "=================== Testing gym==0.13 ===================" +pip install gym==0.13 +python tests/test_arms.py +python tests/test_examine_env.py +python tests/test_examine_robot.py +python tests/test_logger.py +python tests/test_robot.py +pip uninstall -y gym + +echo "=================== Testing gym==0.26.2 ===================" +pip install gym==0.26.2 +python tests/test_arms.py +python tests/test_all.py +pip uninstall -y gym + +echo "=================== Testing gymnasium ===================" +pip install gymnasium +python tests/test_arms.py +python tests/test_all.py +pip uninstall -y gymnasium diff --git a/robohive/tutorials/ee_teleop.py b/robohive/tutorials/ee_teleop.py index e2abbea2..bdcff8d6 100644 --- a/robohive/tutorials/ee_teleop.py +++ b/robohive/tutorials/ee_teleop.py @@ -18,7 +18,7 @@ from robohive.logger.grouped_datasets import Trace as RoboHive_Trace import numpy as np import click -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() try: from vtils.input.keyboard import KeyInput as KeyBoard diff --git a/robohive/tutorials/ee_teleop_oculus.py b/robohive/tutorials/ee_teleop_oculus.py index 01d5fd0e..30d7292a 100644 --- a/robohive/tutorials/ee_teleop_oculus.py +++ b/robohive/tutorials/ee_teleop_oculus.py @@ -15,7 +15,7 @@ import time import numpy as np import click -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.utils.quat_math import euler2quat, euler2mat, mat2quat, diffQuat, mulQuat from robohive.utils.inverse_kinematics import IKResult, qpos_from_site_pose from robohive.logger.roboset_logger import RoboSet_Trace diff --git a/robohive/utils/examine_env.py b/robohive/utils/examine_env.py index d70f6b41..f8986869 100644 --- a/robohive/utils/examine_env.py +++ b/robohive/utils/examine_env.py @@ -5,7 +5,7 @@ License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================= """ -import gym +from robohive.utils.import_utils import import_gym; gym = import_gym() from robohive.utils.paths_utils import plot as plotnsave_paths import click import numpy as np @@ -28,7 +28,7 @@ class rand_policy(): def __init__(self, env, seed): self.env = env - self.env.action_space.np_random.seed(seed) # requires exlicit seeding + self.env.action_space.seed(seed) # requires explicit seeding def get_action(self, obs): # return self.env.np_random.uniform(high=self.env.action_space.high, low=self.env.action_space.low) diff --git a/robohive/utils/implement_for.py b/robohive/utils/implement_for.py new file mode 100644 index 00000000..c1b4e102 --- /dev/null +++ b/robohive/utils/implement_for.py @@ -0,0 +1,211 @@ +from __future__ import annotations +import collections +import inspect +import sys +from copy import copy +from functools import wraps +from importlib import import_module +from typing import Union, Callable, Dict +from packaging.version import parse + +class implement_for: + """A version decorator that checks the version in the environment and implements a function with the fitting one. + + If specified module is missing or there is no fitting implementation, call of the decorated function + will lead to the explicit error. + In case of intersected ranges, last fitting implementation is used. + + This wrapper also works to implement different backends for a same function (eg. gym vs gymnasium, + numpy vs jax-numpy etc). + + Args: + module_name (str or callable): version is checked for the module with this + name (e.g. "gym"). If a callable is provided, it should return the + module. + from_version: version from which implementation is compatible. Can be open (None). + to_version: version from which implementation is no longer compatible. Can be open (None). + + Examples: + >>> @implement_for("gym", "0.13", "0.14") + >>> def fun(self, x): + ... # Older gym versions will return x + 1 + ... return x + 1 + ... + >>> @implement_for("gym", "0.14", "0.23") + >>> def fun(self, x): + ... # More recent gym versions will return x + 2 + ... return x + 2 + ... + >>> @implement_for(lambda: import_module("gym"), "0.23", None) + >>> def fun(self, x): + ... # More recent gym versions will return x + 2 + ... return x + 2 + ... + >>> @implement_for("gymnasium", "0.27", None) + >>> def fun(self, x): + ... # If gymnasium is to be used instead of gym, x+3 will be returned + ... return x + 3 + ... + + This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+. + """ + + # Stores pointers to fitting implementations: dict[func_name] = func_pointer + _implementations = {} + _setters = [] + _cache_modules = {} + + def __init__( + self, + module_name: Union[str, Callable], + from_version: str = None, + to_version: str = None, + ): + self.module_name = module_name + self.from_version = from_version + self.to_version = to_version + implement_for._setters.append(self) + + @staticmethod + def check_version(version, from_version, to_version): + return (from_version is None or parse(version) >= parse(from_version)) and ( + to_version is None or parse(version) < parse(to_version) + ) + + @staticmethod + def get_class_that_defined_method(f): + """Returns the class of a method, if it is defined, and None otherwise.""" + out = f.__globals__.get(f.__qualname__.split(".")[0], None) + return out + + @classmethod + def get_func_name(cls, fn): + # produces a name like torchrl.module.Class.method or torchrl.module.function + first = str(fn).split(".")[0][len(" str: + """Imports module and returns its version.""" + if not callable(module_name): + module = cls._cache_modules.get(module_name, None) + if module is None: + if module_name in sys.modules: + sys.modules[module_name] = module = import_module(module_name) + else: + cls._cache_modules[module_name] = module = import_module( + module_name + ) + else: + module = module_name() + return module.__version__ + + _lazy_impl = collections.defaultdict(list) + + def _delazify(self, func_name): + for local_call in implement_for._lazy_impl[func_name]: + out = local_call() + return out + + def __call__(self, fn): + # function names are unique + self.func_name = self.get_func_name(fn) + self.fn = fn + implement_for._lazy_impl[self.func_name].append(self._call) + + @wraps(fn) + def _lazy_call_fn(*args, **kwargs): + # first time we call the function, we also do the replacement. + # This will cause the imports to occur only during the first call to fn + return self._delazify(self.func_name)(*args, **kwargs) + + return _lazy_call_fn + + def _call(self): + + # If the module is missing replace the function with the mock. + fn = self.fn + func_name = self.func_name + implementations = implement_for._implementations + + @wraps(fn) + def unsupported(*args, **kwargs): + raise ModuleNotFoundError( + f"Supported version of '{func_name}' has not been found." + ) + + self.do_set = False + # Return fitting implementation if it was encountered before. + if func_name in implementations: + try: + # check that backends don't conflict + version = self.import_module(self.module_name) + if self.check_version(version, self.from_version, self.to_version): + self.do_set = True + if not self.do_set: + return implementations[func_name].fn + except ModuleNotFoundError: + # then it's ok, there is no conflict + return implementations[func_name].fn + else: + try: + version = self.import_module(self.module_name) + if self.check_version(version, self.from_version, self.to_version): + self.do_set = True + except ModuleNotFoundError: + return unsupported + if self.do_set: + self.module_set() + return fn + return unsupported + + @classmethod + def reset(cls, setters_dict: Dict[str, implement_for] = None): + """Resets the setters in setter_dict. + + ``setter_dict`` is a copy of implementations. We just need to iterate through its + values and call :meth:`~.module_set` for each. + + """ + if setters_dict is None: + setters_dict = copy(cls._implementations) + for setter in setters_dict.values(): + setter.module_set() + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"module_name={self.module_name}({self.from_version, self.to_version}), " + f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})" + ) diff --git a/robohive/utils/import_utils.py b/robohive/utils/import_utils.py index 7dfa9f12..dfd3475b 100644 --- a/robohive/utils/import_utils.py +++ b/robohive/utils/import_utils.py @@ -3,6 +3,15 @@ from os.path import expanduser import git + +def import_gym(): + if importlib.util.find_spec("gymnasium"): + import gymnasium as gg + elif importlib.util.find_spec("gym"): + import gym as gg + return gg + + def mujoco_py_isavailable(): help = """ Options: