Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703555269
Change-Id: Iab364e8db6be86e4632cbc962f1eeef2356c8d56
  • Loading branch information
Brax Team authored and btaba committed Dec 6, 2024
1 parent 49f03c0 commit 8becede
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 6 deletions.
8 changes: 2 additions & 6 deletions brax/envs/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def reset(self, rng: jax.Array) -> State:
'pixels/view_1': jp.zeros((4, 4, 3)),
}

if self._obs_mode == ObservationMode.DICT_STATE:
obs = obs
elif self._obs_mode == ObservationMode.DICT_PIXELS:
if self._obs_mode == ObservationMode.DICT_PIXELS:
obs = pixels
elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE:
obs = {**obs, **pixels}
Expand Down Expand Up @@ -104,9 +102,7 @@ def step(self, state: State, action: jax.Array) -> State:
'pixels/view_1': jp.zeros((4, 4, 3)),
}

if self._obs_mode == ObservationMode.DICT_STATE:
obs = obs
elif self._obs_mode == ObservationMode.DICT_PIXELS:
if self._obs_mode == ObservationMode.DICT_PIXELS:
obs = pixels
elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE:
obs = {**obs, **pixels}
Expand Down
1 change: 1 addition & 0 deletions brax/training/agents/ppo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def _random_translate_pixels(
Returns:
A dictionary of observations with translated pixels
"""
obs = core.FrozenDict(obs)

@jax.vmap
def rt_all_views(
Expand Down
1 change: 1 addition & 0 deletions brax/training/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ def make_policy_network_vision(
)

def apply(processor_params, policy_params, obs):
obs = core.FrozenDict(obs)
if state_obs_key:
state_obs = preprocess_observations_fn(
obs[state_obs_key], normalizer_select(processor_params, state_obs_key)
Expand Down

0 comments on commit 8becede

Please sign in to comment.