diff --git a/robohive/envs/env_base.py b/robohive/envs/env_base.py
index 1072b132..efc2d540 100644
--- a/robohive/envs/env_base.py
+++ b/robohive/envs/env_base.py
@@ -495,8 +495,8 @@ def get_input_seed(self):

     def _reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs):
         """
-        Reset the environment
-        Default implemention provided. Override if env needs custom reset
+        Reset the environment (default implementation provided).
+        Override if the env needs a custom reset. Carefully handle the return type for gym/gymnasium compatibility.
         """
         qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos
         qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel
diff --git a/robohive/envs/hands/baoding_v1.py b/robohive/envs/hands/baoding_v1.py
index ce68419e..1b6d97da 100644
--- a/robohive/envs/hands/baoding_v1.py
+++ b/robohive/envs/hands/baoding_v1.py
@@ -256,7 +256,7 @@ def get_reward_dict(self, obs_dict):
         rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
         return rwd_dict


-    def _reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs):
+    def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs):
         # reset counters
         self.counter=0
@@ -264,8 +264,7 @@ def _reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6
         self.goal = self.create_goal_trajectory(time_period=time_period) if reset_goal is None else reset_goal.copy()

         # reset scene
-        obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs)
-        return obs
+        return super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs)

     def create_goal_trajectory(self, time_step=.1, time_period=6):
         len_of_goals = 1000 # assumes that its greator than env horizon
@@ -326,5 +325,4 @@ def create_goal_trajectory(self, time_step=.1, time_period=6):

 class BaodingRandomEnvV1(BaodingFixedEnvV1):
     def reset(self, **kwargs):
-        obs = super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs)
-        return obs
+        return super().reset(time_period=self.np_random.uniform(low=5, high=7), **kwargs)
\ No newline at end of file
diff --git a/robohive/envs/hands/door_v1.py b/robohive/envs/hands/door_v1.py
index d0f80b9a..4e756e9f 100644
--- a/robohive/envs/hands/door_v1.py
+++ b/robohive/envs/hands/door_v1.py
@@ -96,17 +96,13 @@ def get_reward_dict(self, obs_dict):

         return rwd_dict

-    def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+    def reset(self, **kwargs):
         self.sim.reset()
-        qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
-        qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
-        self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)
-
         self.sim.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2)
         self.sim.model.body_pos[self.door_bid, 1] = self.np_random.uniform(low=0.25, high=0.35)
         self.sim.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35)
         self.sim.forward()
-        return self.get_obs()
+        return super().reset(**kwargs)


     def get_env_state(self):
diff --git a/robohive/envs/hands/hammer_v1.py b/robohive/envs/hands/hammer_v1.py
index 9ce25020..cbdccb92 100644
--- a/robohive/envs/hands/hammer_v1.py
+++ b/robohive/envs/hands/hammer_v1.py
@@ -120,15 +120,12 @@ def get_obs_dict(self, sim):

         return obs_dict

-    def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+    def reset(self, **kwargs):
         self.sim.reset()
-        qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
-        qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
-        self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)
-
         self.sim.model.body_pos[self.target_bid,2] = self.np_random.uniform(low=0.1, high=0.25)
         self.sim.forward()
-        return self.get_obs()
+        return super().reset(**kwargs)
+

     def get_env_state(self):
         """
diff --git a/robohive/envs/hands/pen_v1.py b/robohive/envs/hands/pen_v1.py
index 98172bbc..f38c1c1d 100644
--- a/robohive/envs/hands/pen_v1.py
+++ b/robohive/envs/hands/pen_v1.py
@@ -108,19 +108,14 @@ def get_reward_dict(self, obs_dict):

         return rwd_dict

-    def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+    def reset(self, **kwargs):
         self.sim.reset()
-        qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
-        qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
-        self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)
-
         desired_orien = np.zeros(3)
         desired_orien[0] = self.np_random.uniform(low=-1, high=1)
         desired_orien[1] = self.np_random.uniform(low=-1, high=1)
         self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien)
         self.sim.forward()
-
-        return self.get_obs()
+        return super().reset(**kwargs)


     def get_env_state(self):
diff --git a/robohive/envs/hands/relocate_v1.py b/robohive/envs/hands/relocate_v1.py
index 7b8da68c..be2d8ea0 100644
--- a/robohive/envs/hands/relocate_v1.py
+++ b/robohive/envs/hands/relocate_v1.py
@@ -136,20 +136,16 @@ def get_obs_dict(self, sim):

         return obs_dict

-    def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+    def reset(self, **kwargs):
         self.sim.reset()
-        qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
-        qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
-        self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)
-
-
         self.sim.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15)
         self.sim.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3)
         self.sim.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2)
         self.sim.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2)
         self.sim.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35)
         self.sim.forward()
-        return self.get_obs()
+        return super().reset(**kwargs)
+

     def get_env_state(self):
         """
diff --git a/robohive/tests/test_envs.py b/robohive/tests/test_envs.py
index 78137191..cb01fcae 100644
--- a/robohive/tests/test_envs.py
+++ b/robohive/tests/test_envs.py
@@ -8,6 +8,7 @@
 import unittest

 from robohive.utils import gym
+from robohive.utils.implement_for import implement_for
 import numpy as np
 import pickle
 import copy
@@ -66,7 +67,8 @@ def check_env(self, environment_id, input_seed):
         rwd_dict1 = env1.get_reward_dict(obs_dict1)
         assert len(rwd_dict1) > 0
         # reset env
-        env1.reset()
+        reset_data = env1.reset()
+        self.check_reset(reset_data)

         # serialize / deserialize env ------------
         env2w = pickle.loads(pickle.dumps(env1w))
@@ -102,6 +104,21 @@ def check_env(self, environment_id, input_seed):
         del(env1)
         del(env2)

+
+    @implement_for("gym", None, "0.26")
+    def check_reset(self, reset_data):
+        assert isinstance(reset_data, np.ndarray), "Reset should return the observation vector"
+
+    @implement_for("gym", "0.26", None)
+    def check_reset(self, reset_data):
+        assert isinstance(reset_data, tuple) and len(reset_data) == 2, "Reset should return a tuple of length 2"
+        assert isinstance(reset_data[1], dict), "second element returned should be a dict"
+
+    @implement_for("gymnasium")
+    def check_reset(self, reset_data):
+        assert isinstance(reset_data, tuple) and len(reset_data) == 2, "Reset should return a tuple of length 2"
+        assert isinstance(reset_data[1], dict), "second element returned should be a dict"
+
     def check_old_envs(self, module_name, env_names, lite=False, seed=1234):
         print("\nTesting module:: ", module_name)
         for env_name in env_names:
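
Note (commentary, not part of the diff): the three check_reset definitions above rely on robohive.utils.implement_for to pick the variant matching the installed backend. A minimal sketch of that dispatch, assuming implement_for keeps the semantics of torchrl's decorator of the same name (each definition is registered for a (module, [from_version, to_version)) range, and the variant matching the installed package/version is the one that runs); ResetChecker and describe are illustrative names:

    from robohive.utils.implement_for import implement_for

    class ResetChecker:
        @implement_for("gym", None, "0.26")   # active when gym < 0.26 is installed
        def describe(self):
            return "reset() -> obs"

        @implement_for("gym", "0.26", None)   # active when gym >= 0.26 is installed
        def describe(self):
            return "reset() -> (obs, info)"

        @implement_for("gymnasium")           # active when gymnasium is installed
        def describe(self):
            return "reset() -> (obs, info)"

    # Prints the convention for the locally installed backend
    # (expected to fail if neither gym nor gymnasium is available).
    print(ResetChecker().describe())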
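
The same return conventions matter for downstream callers that may run under either backend. A caller-side shim, sketched under the same assumptions (reset_compat is an illustrative helper, not a RoboHive API):

    def reset_compat(env, **kwargs):
        """Normalize env.reset() output to (obs, info) across gym/gymnasium APIs."""
        out = env.reset(**kwargs)
        # gym >= 0.26 / gymnasium: reset() already returns (obs, info)
        if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
            return out
        # gym < 0.26: reset() returns obs only; synthesize an empty info dict
        return out, {}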