Skip to content

Commit

Permalink
BUGFIX: Adding kwargs to all resets
Browse files Browse the repository at this point in the history
  • Loading branch information
vikashplus committed Jan 1, 2024
1 parent 047f158 commit e4414a1
Show file tree
Hide file tree
Showing 12 changed files with 26 additions and 26 deletions.
4 changes: 2 additions & 2 deletions robohive/envs/claws/reorient_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_reward_dict(self, obs_dict):
rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
return rwd_dict

def reset(self):
def reset(self, **kwargs):
desired_pos = self.np_random.uniform(high=self.target_xyz_range['high'], low=self.target_xyz_range['low'])
self.sim.model.site_pos[self.target_sid] = desired_pos
self.sim_obsd.model.site_pos[self.target_sid] = desired_pos
Expand All @@ -108,5 +108,5 @@ def reset(self):
self.sim.model.site_quat[self.target_sid] = euler2quat(desired_orien)
self.sim_obsd.model.site_quat[self.target_sid] = euler2quat(desired_orien)

obs = super().reset(self.init_qpos, self.init_qvel)
obs = super().reset(self.init_qpos, self.init_qvel, **kwargs)
return obs
4 changes: 2 additions & 2 deletions robohive/envs/fm/franka_ee_pose_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ def get_target_pose(self):
return self.np_random.uniform(low=self.sim.model.actuator_ctrlrange[:,0], high=self.sim.model.actuator_ctrlrange[:,1])


def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
self.target_pose = self.get_target_pose()
obs = super().reset(reset_qpos, reset_qvel)
obs = super().reset(reset_qpos, reset_qvel, **kwargs)
return obs

class FrankaRobotiqPose(FrankaEEPose):
Expand Down
8 changes: 4 additions & 4 deletions robohive/envs/hands/baoding_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,15 +256,15 @@ def get_reward_dict(self, obs_dict):
rwd_dict['dense'] = np.sum([wt*rwd_dict[key] for key, wt in self.rwd_keys_wt.items()], axis=0)
return rwd_dict

def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6):
def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=6, **kwargs):
# reset counters
self.counter=0

# reset goal
self.goal = self.create_goal_trajectory(time_period=time_period) if reset_goal is None else reset_goal.copy()

# reset scene
obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel)
obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs)
return obs

def create_goal_trajectory(self, time_step=.1, time_period=6):
Expand Down Expand Up @@ -325,6 +325,6 @@ def create_goal_trajectory(self, time_step=.1, time_period=6):

class BaodingRandomEnvV1(BaodingFixedEnvV1):

def reset(self):
obs = super().reset(time_period = self.np_random.uniform(high=5, low=7))
def reset(self, **kwargs):
obs = super().reset(time_period = self.np_random.uniform(high=5, low=7), **kwargs)
return obs
4 changes: 2 additions & 2 deletions robohive/envs/multi_task/common/franka_appliance_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _setup(
**kwargs,
)

def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
# randomize object bodies, if requested
if self.obj_body_randomize:
for body_name in self.obj_body_randomize:
Expand All @@ -78,4 +78,4 @@ def reset(self, reset_qpos=None, reset_qvel=None):
* (self.robot_ranges[:, 1] - self.robot_ranges[:, 0])
)

return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel)
return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
4 changes: 2 additions & 2 deletions robohive/envs/multi_task/common/franka_kitchen_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _setup(
)


def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):
if reset_qpos is None:
reset_qpos = self.init_qpos.copy()

Expand All @@ -128,4 +128,4 @@ def reset(self, reset_qpos=None, reset_qvel=None):
if self.robot_base_range:
self.sim.model.body_pos[self.robot_base_bid] = self.robot_base_pos + self.np_random.uniform(**self.robot_base_range)

return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel)
return super().reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs)
4 changes: 2 additions & 2 deletions robohive/envs/myo/myobase/baoding_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def evaluate_success(self, paths, logger=None, successful_steps=5):
logger.log_kv('effort', effort)
return success_percentage

def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None):
def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=None, **kwargs):
# reset counters
self.counter=0
self.x_radius=self.np_random.uniform(low=self.goal_xrange[0], high=self.goal_xrange[1])
Expand All @@ -273,7 +273,7 @@ def reset(self, reset_pose=None, reset_vel=None, reset_goal=None, time_period=No
self.goal = self.create_goal_trajectory(time_step=self.dt, time_period=time_period) if reset_goal is None else reset_goal.copy()

# reset scene
obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel)
obs = super().reset(reset_qpos=reset_pose, reset_qvel=reset_vel, **kwargs)
return obs

def create_goal_trajectory(self, time_step=.1, time_period=6):
Expand Down
4 changes: 2 additions & 2 deletions robohive/envs/myo/myodm/myodm_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ def playback(self):
return idxs[0] < self.ref.horizon-1


def reset(self):
def reset(self, **kwargs):
# print("Reset")
self.ref.reset()
obs = super().reset(self.init_qpos, self.init_qvel)
obs = super().reset(self.init_qpos, self.init_qvel, **kwargs)
# print(self.time, self.sim.data.qpos)
return obs

Expand Down
4 changes: 2 additions & 2 deletions robohive/envs/myo/myomimic/myomimic_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ def playback(self):
return idxs[0] < self.ref.horizon-1


def reset(self):
def reset(self, **kwargs):
# print("Reset")
self.ref.reset()
obs = super().reset(self.init_qpos, self.init_qvel)
obs = super().reset(self.init_qpos, self.init_qvel, **kwargs)
# print(self.time, self.sim.data.qpos)
return obs
4 changes: 2 additions & 2 deletions robohive/envs/quadrupeds/orient_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_reward_dict(self, obs_dict):
return rwd_dict


def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):

reset_qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos
reset_qpos[6:] += np.pi/8*self.np_random.uniform(low=-1, high=1, size=self.sim.model.nq-6)
Expand All @@ -182,5 +182,5 @@ def reset(self, reset_qpos=None, reset_qvel=None):
self.sim.model.site_pos[self.target_sid] = target_dist * np.array([np.cos(target_theta), np.sin(target_theta), 0])
# Heading target is a bit farther away to avoid heading oscillations when quad is near xy_target
self.sim.model.site_pos[self.heading_sid] = (target_dist+0.5) * np.array([np.cos(target_theta), np.sin(target_theta), 0])
obs = super().reset(reset_qpos, reset_qvel)
obs = super().reset(reset_qpos, reset_qvel, **kwargs)
return obs
4 changes: 2 additions & 2 deletions robohive/envs/quadrupeds/stand_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_reward_dict(self, obs_dict):
return rwd_dict


def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):

if reset_qpos is None:
reset_qpos = self.init_qpos.copy()
Expand All @@ -189,5 +189,5 @@ def reset(self, reset_qpos=None, reset_qvel=None):
else:
raise TypeError(f"Unknown reset type: {self.reset_type}")

obs = super().reset(reset_qpos, reset_qvel)
obs = super().reset(reset_qpos, reset_qvel, **kwargs)
return obs
4 changes: 2 additions & 2 deletions robohive/envs/quadrupeds/walk_v0.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def get_reward_dict(self, obs_dict):
return rwd_dict


def reset(self, reset_qpos=None, reset_qvel=None):
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs):

reset_qpos = self.init_qpos.copy() if reset_qpos is None else reset_qpos
reset_qpos[6:] += np.pi/8*self.np_random.uniform(low=-1, high=1, size=self.sim.model.nq-6)
Expand All @@ -182,5 +182,5 @@ def reset(self, reset_qpos=None, reset_qvel=None):
self.sim.model.site_pos[self.target_sid] = target_dist * np.array([np.cos(target_theta), np.sin(target_theta), 0])
# Heading target is a bit farther away to avoid heading oscillations when quad is near xy_target
self.sim.model.site_pos[self.heading_sid] = (target_dist+0.5) * np.array([np.cos(target_theta), np.sin(target_theta), 0])
obs = super().reset(reset_qpos, reset_qvel)
obs = super().reset(reset_qpos, reset_qvel, **kwargs)
return obs
4 changes: 2 additions & 2 deletions robohive/envs/tcdm/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ def playback(self):
return idxs[0] < self.ref.horizon-1


def reset(self):
def reset(self, **kwargs):
# print("Reset")
self.ref.reset()
obs = super().reset(self.init_qpos, self.init_qvel)
obs = super().reset(self.init_qpos, self.init_qvel, **kwargs)
# print(self.time, self.sim.data.qpos)
return obs

Expand Down

0 comments on commit e4414a1

Please sign in to comment.