Skip to content

Commit

Permalink
return depth from RoboHiveEnv
Browse files Browse the repository at this point in the history
  • Loading branch information
sriramsk1999 committed Apr 17, 2024
1 parent d2cfd28 commit 4e40877
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions torchrl/envs/libs/robohive.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ class RoboHiveEnv(GymEnv, metaclass=_RoboHiveBuild):
be returned (by default under the ``"pixels"`` entry in the output tensordict).
If ``False``, observations (eg, states) and pixels will be returned
whenever ``from_pixels=True``. Defaults to ``True``.
from_depths (bool, optional): if ``True``, an attempt to return the depth
observations from the env will be performed. By default, these observations
will be written under the ``"depths"`` entry. Requires ``from_pixels`` to be ``True``.
Defaults to ``False``.
frame_skip (int, optional): if provided, indicates for how many steps the
same action is to be repeated. The observation returned will be the
last observation of the sequence, whereas the reward will be the sum
Expand Down Expand Up @@ -155,6 +159,7 @@ def _build_env( # noqa: F811
env_name: str,
from_pixels: bool = False,
pixels_only: bool = False,
from_depths: bool = False,
**kwargs,
) -> "gym.core.Env": # noqa: F821
if from_pixels:
Expand All @@ -168,7 +173,9 @@ def _build_env( # noqa: F811
)
kwargs["cameras"] = self.get_available_cams(env_name)
cams = list(kwargs.pop("cameras"))
env_name = self.register_visual_env(cams=cams, env_name=env_name)
env_name = self.register_visual_env(
cams=cams, env_name=env_name, from_depths=from_depths
)

elif "cameras" in kwargs and kwargs["cameras"]:
raise RuntimeError("Got a list of cameras but from_pixels is set to False.")
Expand Down Expand Up @@ -209,6 +216,7 @@ def _build_env( # noqa: F811
# except Exception as err:
# raise RuntimeError(f"Failed to build env {env_name}.") from err
self.from_pixels = from_pixels
self.from_depths = from_depths
self.render_device = render_device
if kwargs.get("read_info", True):
self.set_info_dict_reader(self.read_info)
Expand All @@ -224,7 +232,7 @@ def _make_specs(self, env: "gym.Env", batch_size=None) -> None: # noqa: F821
return out

@classmethod
def register_visual_env(cls, env_name, cams):
def register_visual_env(cls, env_name, cams, from_depths):
with set_directory(cls.CURR_DIR):
from robohive.envs.env_variants import register_env_variant

Expand All @@ -236,6 +244,8 @@ def register_visual_env(cls, env_name, cams):
if new_env_name in cls.env_list:
return new_env_name
visual_keys = [f"rgb:{c}:224x224:2d" for c in cams]
if from_depths:
visual_keys.extend([f"d:{c}:224x224:2d" for c in cams])
register_env_variant(
env_name,
variants={
Expand All @@ -262,20 +272,26 @@ def get_obs():
if self.from_pixels:
visual = self.env.get_exteroception()
obs_dict.update(visual)
pixel_list = []
pixel_list, depth_list = [], []
for obs_key in obs_dict:
if obs_key.startswith("rgb"):
pix = obs_dict[obs_key]
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif obs_key.startswith("d:"):
dep = obs_dict[obs_key]
dep = dep[None]
depth_list.append(dep)
elif obs_key in env.obs_keys:
value = env.obs_dict[obs_key]
if not value.shape:
value = value[None]
_dict[obs_key] = value
if pixel_list:
_dict["pixels"] = np.concatenate(pixel_list, 0)
if depth_list:
_dict["depths"] = np.concatenate(depth_list, 0)
return _dict

for i in range(3):
Expand Down Expand Up @@ -335,7 +351,7 @@ def read_obs(self, observation):
pass
# recover vec
obsdict = {}
pixel_list = []
pixel_list, depth_list = [], []
if self.from_pixels:
visual = self.env.get_exteroception()
observations.update(visual)
Expand All @@ -345,6 +361,10 @@ def read_obs(self, observation):
if not pix.shape[0] == 1:
pix = pix[None]
pixel_list.append(pix)
elif key.startswith("d:"):
dep = observations[key]
dep = dep[None]
depth_list.append(dep)
elif key in self._env.obs_keys:
value = observations[key]
if not value.shape:
Expand All @@ -354,6 +374,8 @@ def read_obs(self, observation):
# obsvec = np.concatenate(obsvec, 0)
if self.from_pixels:
obsdict.update({"pixels": np.concatenate(pixel_list, 0)})
if self.from_pixels and self.from_depths:
obsdict.update({"depths": np.concatenate(depth_list, 0)})
out = obsdict
return super().read_obs(out)

Expand Down

0 comments on commit 4e40877

Please sign in to comment.