-
Notifications
You must be signed in to change notification settings - Fork 265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[FEATURE] Vsys feature: massively parallel domain randomization #458
Open
Velythyl
wants to merge
20
commits into
google:main
Choose a base branch
from
Velythyl:vsys
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
4d569d4
Merge pull request #1 from google/main
Velythyl ea0c6ae
dr works. onto prod
Velythyl a60207d
done?
Velythyl d8d4914
upd
Velythyl 769d057
added tracking of current vals
Velythyl 0fc424b
fixed vals
Velythyl 2d1fdcd
fixed vals again\?
Velythyl e0317b3
added logging for resampling
Velythyl dda64ef
hmm
Velythyl 3c66f57
fix logging of skr vals
Velythyl 5408fcc
wat
Velythyl 5952601
merge
Velythyl b6cab64
pr
Velythyl 3da30d0
small cleanup
Velythyl 0d835eb
huge refactor
Velythyl ca84df8
fix rngs
Velythyl 3ae61cf
more explicit
Velythyl f952b9f
fix pusher
Velythyl 3ff0565
upd
Velythyl 7767989
upd
Velythyl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -17,7 +17,7 @@ | |||||
|
||||||
from typing import Tuple | ||||||
|
||||||
from brax import base | ||||||
from brax import base, System | ||||||
from brax.envs.base import PipelineEnv, State | ||||||
from brax.io import mjcf | ||||||
from etils import epath | ||||||
|
@@ -195,19 +195,19 @@ def __init__( | |||||
exclude_current_positions_from_observation | ||||||
) | ||||||
|
||||||
def reset(self, rng: jax.Array) -> State: | ||||||
def reset(self, sys: System, rng: jax.Array) -> State: | ||||||
"""Resets the environment to an initial state.""" | ||||||
rng, rng1, rng2 = jax.random.split(rng, 3) | ||||||
|
||||||
low, hi = -self._reset_noise_scale, self._reset_noise_scale | ||||||
qpos = self.sys.init_q + jax.random.uniform( | ||||||
rng1, (self.sys.q_size(),), minval=low, maxval=hi | ||||||
qpos = sys.init_q + jax.random.uniform( | ||||||
rng1, (sys.q_size(),), minval=low, maxval=hi | ||||||
) | ||||||
qvel = jax.random.uniform( | ||||||
rng2, (self.sys.qd_size(),), minval=low, maxval=hi | ||||||
rng2, (sys.qd_size(),), minval=low, maxval=hi | ||||||
) | ||||||
|
||||||
pipeline_state = self.pipeline_init(qpos, qvel) | ||||||
pipeline_state = self.pipeline_init(sys, qpos, qvel) | ||||||
|
||||||
obs = self._get_obs(pipeline_state) | ||||||
reward, done, zero = jp.zeros(3) | ||||||
|
@@ -218,12 +218,12 @@ def reset(self, rng: jax.Array) -> State: | |||||
'x_position': zero, | ||||||
'x_velocity': zero, | ||||||
} | ||||||
return State(pipeline_state, obs, reward, done, metrics) | ||||||
return State(pipeline_state, obs, reward, done,sys, metrics) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
def step(self, state: State, action: jax.Array) -> State: | ||||||
"""Runs one timestep of the environment's dynamics.""" | ||||||
pipeline_state0 = state.pipeline_state | ||||||
pipeline_state = self.pipeline_step(pipeline_state0, action) | ||||||
pipeline_state = self.pipeline_step(state.sys, pipeline_state0, action) | ||||||
|
||||||
x_velocity = ( | ||||||
pipeline_state.x.pos[0, 0] - pipeline_state0.x.pos[0, 0] | ||||||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.