-
Notifications
You must be signed in to change notification settings - Fork 86
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Making all tests pass with Implement_for functionality. Enables support for most gym/gymnasium versions #119
Making all tests pass with Implement_for functionality. Enables support for most gym/gymnasium versions #119
Changes from all commits
742e4dd
009a373
2b6a164
051d167
a7b61bd
f0f128d
b4545e0
32fd8d1
7739c6a
261300e
5de406a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,14 +5,17 @@ | |
License :: Under Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. | ||
================================================= """ | ||
|
||
import gym | ||
# TODO: find how to make this compatible with gymnasium. Maybe a global variable that indicates what to use as backend? | ||
# from robohive.utils.import_utils import import_gym; gym = import_gym() | ||
from robohive.utils.import_utils import import_gym; gym = import_gym() | ||
import numpy as np | ||
import os | ||
import time as timer | ||
|
||
from robohive.envs.obs_vec_dict import ObsVecDict | ||
from robohive.utils import tensor_utils | ||
from robohive.robot.robot import Robot | ||
from robohive.utils.implement_for import implement_for | ||
from robohive.utils.prompt_utils import prompt, Prompt | ||
import skvideo.io | ||
from sys import platform | ||
|
@@ -130,7 +133,7 @@ def _setup(self, | |
self._setup_rgb_encoders(self.visual_keys, device=None) | ||
|
||
# reset to get the env ready | ||
observation, _reward, done, _info = self.step(np.zeros(self.sim.model.nu)) | ||
observation, _reward, done, *_, _info = self.step(np.zeros(self.sim.model.nu)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could decorate this too, but it's also OK if you don't need the truncated flag. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Seems to work for now. Let's revisit this in the future in case it isn't sufficient. |
||
# Question: Should we replace the above with the following? It's especially helpful for hardware, as it forces an env reset before continuing; without it, the hardware will make a big jump from its current position to the position requested by step. | ||
# observation = self.reset() | ||
assert not done, "Check initialization. Simulation starts in a done state." | ||
|
@@ -263,8 +266,23 @@ def step(self, a, **kwargs): | |
render_cbk=self.mj_render if self.mujoco_render_frames else None) | ||
return self.forward(**kwargs) | ||
|
||
@implement_for("gym", None, "0.24") | ||
def forward(self, **kwargs): | ||
return self._forward(**kwargs) | ||
|
||
@implement_for("gym", "0.24", None) | ||
def forward(self, **kwargs): | ||
obs, reward, done, info = self._forward(**kwargs) | ||
terminal = done | ||
return obs, reward, terminal, False, info | ||
|
||
@implement_for("gymnasium") | ||
def forward(self, **kwargs): | ||
obs, reward, done, info = self._forward(**kwargs) | ||
terminal = done | ||
return obs, reward, terminal, False, info | ||
|
||
def _forward(self, **kwargs): | ||
""" | ||
Forward propagate env to recover env details | ||
Returns current obs(t), rwd(t), done(t), info(t) | ||
|
@@ -476,7 +494,7 @@ def get_input_seed(self): | |
return self.input_seed | ||
|
||
|
||
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): | ||
def _reset(self, reset_qpos=None, reset_qvel=None, **kwargs): | ||
""" | ||
Reset the environment | ||
Default implemention provided. Override if env needs custom reset | ||
|
@@ -485,11 +503,19 @@ def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): | |
qvel = self.init_qvel.copy() if reset_qvel is None else reset_qvel | ||
self.robot.reset(qpos, qvel, **kwargs) | ||
return self.get_obs() | ||
@implement_for("gym", None, "0.26") | ||
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): | ||
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs) | ||
@implement_for("gym", "0.26", None) | ||
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): | ||
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {} | ||
@implement_for("gymnasium") | ||
def reset(self, reset_qpos=None, reset_qvel=None, **kwargs): | ||
return self._reset(reset_qpos=reset_qpos, reset_qvel=reset_qvel, **kwargs), {} | ||
|
||
|
||
@property | ||
def _step(self, a): | ||
return self.step(a) | ||
# @property | ||
# def _step(self, a): | ||
# return self.step(a) | ||
|
||
|
||
@property | ||
|
@@ -702,7 +728,7 @@ def examine_policy(self, | |
ep_rwd = 0.0 | ||
while t < horizon and done is False: | ||
a = policy.get_action(o)[0] if mode == 'exploration' else policy.get_action(o)[1]['evaluation'] | ||
next_o, rwd, done, env_info = self.step(a) | ||
next_o, rwd, done, *_, env_info = self.step(a) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto |
||
ep_rwd += rwd | ||
# render offscreen visuals | ||
if render =='offscreen': | ||
|
@@ -794,7 +820,7 @@ def examine_policy_new(self, | |
ep_rwd = 0.0 | ||
|
||
# Rollout -------------------------------- | ||
obs, rwd, done, env_info = self.forward(update_exteroception=True) # t=0 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto |
||
obs, rwd, done, *_, env_info = self.forward(update_exteroception=True) # t=0 | ||
while t < horizon and done is False: | ||
|
||
# print(t, t*self.dt, self.time, t*self.dt-self.time) | ||
|
@@ -825,7 +851,7 @@ def examine_policy_new(self, | |
|
||
|
||
# step env using actions from t=>t+1 ---------------------- | ||
obs, rwd, done, env_info = self.step(act, update_exteroception=True) | ||
obs, rwd, done, *_, env_info = self.step(act, update_exteroception=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ditto |
||
t = t+1 | ||
ep_rwd += rwd | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess this is unrelated?