BUGFIX: Fixing return type issues of envs with custom reset
- the gym/gymnasium return-type compatibility handled by the base class doesn't carry over when an env provides a custom reset
- we now recommend calling super().reset(**kwargs) to get the appropriate return type, ensuring compatibility (see the sketch below)
vikashplus committed Apr 30, 2024
1 parent 63b44ad commit a9d227e
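A minimal sketch of the recommended pattern (`MyCustomEnv` and its goal randomization are hypothetical, and the `env_base.MujocoEnv` base class is assumed; the commit's recommendation is only the final `super().reset(**kwargs)` delegation):

```python
# Hypothetical custom env illustrating the recommended reset pattern.
from robohive.envs import env_base

class MyCustomEnv(env_base.MujocoEnv):

    def reset(self, **kwargs):
        # Env-specific reset logic goes here (illustrative):
        self.goal = self.np_random.uniform(low=-1.0, high=1.0, size=3)
        # Delegate to the base class so the return type matches the
        # installed API: obs for gym < 0.26, (obs, info) for
        # gym >= 0.26 and for gymnasium.
        return super().reset(**kwargs)
```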
Showing 7 changed files with 32 additions and 34 deletions.
4 changes: 2 additions & 2 deletions robohive/envs/env_base.py
@@ -495,8 +495,8 @@ def get_input_seed(self):

def _reset(self, reset_qpos=None, reset_qvel=None, seed=None, **kwargs):
"""
- Reset the environment
- Default implemention provided. Override if env needs custom reset
+ Reset the environment (Default implementation provided).
+ Override if env needs custom reset. Carefully handle the return type for gym/gymnasium compatibility.
"""
qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos
qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel
8 changes: 3 additions & 5 deletions robohive/envs/hands/baoding_v1.py
@@ -256,16 +256,15 @@ def get_reward_dict(self, obs_dict):
rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
return rwd_dict

- def _reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs):
+ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs):
# reset counters
self.counter=0

# reset goal
self.goal = self.create_goal_trajectory(time_period=time_period) if reset_goal is None else reset_goal.copy()

# reset scene
- obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs)
- return obs
+ return super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs)

def create_goal_trajectory(self, time_step=.1, time_period=6):
len_of_goals = 1000 # assumes that it's greater than env horizon
@@ -326,5 +325,4 @@ def create_goal_trajectory(self, time_step=.1, time_period=6):
class BaodingRandomEnvV1(BaodingFixedEnvV1):

def reset(self, **kwargs):
- obs = super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs)
- return obs
+ return super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs)
8 changes: 2 additions & 6 deletions robohive/envs/hands/door_v1.py
@@ -96,17 +96,13 @@ def get_reward_dict(self, obs_dict):
return rwd_dict


- def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+ def reset(self, **kwargs):
self.sim.reset()
- qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
- qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
- self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)

self.sim.model.body_pos[self.door_bid,0] = self.np_random.uniform(low=-0.3, high=-0.2)
self.sim.model.body_pos[self.door_bid, 1] = self.np_random.uniform(low=0.25, high=0.35)
self.sim.model.body_pos[self.door_bid,2] = self.np_random.uniform(low=0.252, high=0.35)
self.sim.forward()
- return self.get_obs()
+ return super().reset(**kwargs)


def get_env_state(self):
9 changes: 3 additions & 6 deletions robohive/envs/hands/hammer_v1.py
@@ -120,15 +120,12 @@ def get_obs_dict(self, sim):
return obs_dict


- def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+ def reset(self, **kwargs):
self.sim.reset()
- qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
- qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
- self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)

self.sim.model.body_pos[self.target_bid,2] = self.np_random.uniform(low=0.1, high=0.25)
self.sim.forward()
- return self.get_obs()
+ return super().reset(**kwargs)


def get_env_state(self):
"""
9 changes: 2 additions & 7 deletions robohive/envs/hands/pen_v1.py
@@ -108,19 +108,14 @@ def get_reward_dict(self, obs_dict):
return rwd_dict


- def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+ def reset(self, **kwargs):
self.sim.reset()
- qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
- qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
- self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)

desired_orien = np.zeros(3)
desired_orien[0] = self.np_random.uniform(low=-1, high=1)
desired_orien[1] = self.np_random.uniform(low=-1, high=1)
self.sim.model.body_quat[self.target_obj_bid] = euler2quat(desired_orien)
self.sim.forward()

- return self.get_obs()
+ return super().reset(**kwargs)


def get_env_state(self):
10 changes: 3 additions & 7 deletions robohive/envs/hands/relocate_v1.py
@@ -136,20 +136,16 @@ def get_obs_dict(self, sim):
return obs_dict


- def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
+ def reset(self, **kwargs):
self.sim.reset()
- qp = self.init_qpos.copy() if reset_qpos is None else reset_qpos
- qv = self.init_qvel.copy() if reset_qvel is None else reset_qvel
- self.robot.reset(reset_pos=qp, reset_vel=qv, **kwargs)


self.sim.model.body_pos[self.obj_bid,0] = self.np_random.uniform(low=-0.15, high=0.15)
self.sim.model.body_pos[self.obj_bid,1] = self.np_random.uniform(low=-0.15, high=0.3)
self.sim.model.site_pos[self.target_obj_sid, 0] = self.np_random.uniform(low=-0.2, high=0.2)
self.sim.model.site_pos[self.target_obj_sid,1] = self.np_random.uniform(low=-0.2, high=0.2)
self.sim.model.site_pos[self.target_obj_sid,2] = self.np_random.uniform(low=0.15, high=0.35)
self.sim.forward()
- return self.get_obs()
+ return super().reset(**kwargs)


def get_env_state(self):
"""
18 changes: 17 additions & 1 deletion robohive/tests/test_envs.py
@@ -8,6 +8,7 @@
import unittest

from robohive.utils import gym
+ from robohive.utils.implement_for import implement_for
import numpy as np
import pickle
import copy
@@ -66,7 +67,8 @@ def check_env(self, environment_id, input_seed):
rwd_dict1 = env1.get_reward_dict(obs_dict1)
assert len(rwd_dict1) > 0
# reset env
- env1.reset()
+ reset_data = env1.reset()
+ self.check_reset(reset_data)

# serialize / deserialize env ------------
env2w = pickle.loads(pickle.dumps(env1w))
@@ -102,6 +104,20 @@ def check_env(self, environment_id, input_seed):
del(env1)
del(env2)


+ @implement_for("gym", None, "0.26")
+ def check_reset(self, reset_data):
+     assert isinstance(reset_data, np.ndarray), "Reset should return the observation vector"
+
+ @implement_for("gym", "0.26", None)
+ def check_reset(self, reset_data):
+     assert isinstance(reset_data, tuple) and len(reset_data) == 2, "Reset should return a tuple of length 2"
+     assert isinstance(reset_data[1], dict), "second element returned should be a dict"
+
+ @implement_for("gymnasium")
+ def check_reset(self, reset_data):
+     assert isinstance(reset_data, tuple) and len(reset_data) == 2, "Reset should return a tuple of length 2"
+     assert isinstance(reset_data[1], dict), "second element returned should be a dict"

def check_old_envs(self, module_name, env_names, lite=False, seed=1234):
print("\nTesting module:: ", module_name)
for env_name in env_names:
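For context, `implement_for` registers multiple implementations of one callable and dispatches to the one matching the installed package and version range, which is how `check_reset` above adapts to old gym, new gym, and gymnasium. A hedged sketch of the same pattern as a standalone helper (`unpack_reset` is hypothetical, and this assumes `implement_for` decorates free functions the way it decorates the methods above):

```python
# Hypothetical helper using the same version-gated dispatch as check_reset.
# Only the variant matching the installed gym/gymnasium version is active.
from robohive.utils.implement_for import implement_for

@implement_for("gym", None, "0.26")
def unpack_reset(reset_data):
    # Old gym API: reset() returns only the observation.
    return reset_data, {}

@implement_for("gym", "0.26", None)
def unpack_reset(reset_data):
    # New gym API: reset() returns an (obs, info) tuple.
    obs, info = reset_data
    return obs, info

@implement_for("gymnasium")
def unpack_reset(reset_data):
    # gymnasium API: reset() also returns (obs, info).
    obs, info = reset_data
    return obs, info
```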
