[FEATURE] Vsys feature: massively parallel domain randomization #458
Hey @Velythyl I've been looking forward to this feature for a while now, thanks a lot for sharing this!
@lebrice Hey! Sorry, I realized I had some cleanup to do, and it was way past 5pm so I wanted to go home. I reopened it now.
(I'm not a maintainer, this is just a fix for some spacing typos, this is very clean!)
@@ -173,12 +173,12 @@ def reset(self, rng: jax.Array) -> State:
         'reward_ctrl': zero,
         'reward_run': zero,
     }
-    return State(pipeline_state, obs, reward, done, metrics)
+    return State(pipeline_state, obs, reward, done,sys, metrics)
-    return State(pipeline_state, obs, reward, done,sys, metrics)
+    return State(pipeline_state, obs, reward, done, sys, metrics)
@@ -218,12 +218,12 @@ def reset(self, rng: jax.Array) -> State:
         'x_position': zero,
         'x_velocity': zero,
     }
-    return State(pipeline_state, obs, reward, done, metrics)
+    return State(pipeline_state, obs, reward, done,sys, metrics)
-    return State(pipeline_state, obs, reward, done,sys, metrics)
+    return State(pipeline_state, obs, reward, done, sys, metrics)
Thanks @Velythyl ! The recent comment made me just realize that maintainers hadn't commented on the PR. There were a few design decisions that went into DomainRandomizationVmapWrapper:
The cons of the impl at HEAD are that:
What I think would make sense to merge, is to add a wrapper with the same API as |
Hello!
For an unrelated research project, I needed a massively parallel RL environment with domain randomization capabilities. Isaac Sim/Gym/Omniverse fit the bill, but I also needed the simulator to be differentiable w.r.t. each domain randomization parameter.
So I set out to implement DR in brax. This is research code, so it's obviously a little janky and ad-hoc. But I thought maybe the brax community could find this interesting, and perhaps (with a lot of tuning) even merge it into brax main.
Special thanks to this github issue from which I stole some code ;) here
Note that this domain randomization method is more powerful than this. With this code, we can randomize every single simulation step, if we so wish.
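To make the "randomize on a schedule" idea concrete, here is a minimal JAX sketch. It does not use brax or the code in this PR: the toy point-mass dynamics, `sample_mass`, `RESAMPLE_EVERY`, and the mass range are all hypothetical stand-ins. It only illustrates how a randomized parameter could ride along in the per-environment state and be resampled every N steps.

```python
import jax
import jax.numpy as jnp

RESAMPLE_EVERY = 50  # hypothetical rule: "resample every 50 steps"


def sample_mass(key):
    # Hypothetical randomization target: one mass drawn from [0.5, 1.5].
    return jax.random.uniform(key, (), minval=0.5, maxval=1.5)


def step(state, _):
    mass, pos, vel, t, key = state
    key, sub = jax.random.split(key)
    # Swap in a fresh sample only when the step counter hits the interval.
    mass = jnp.where(t % RESAMPLE_EVERY == 0, sample_mass(sub), mass)
    # Toy dynamics standing in for the physics step: unit force on a point mass.
    vel = vel + 0.01 * (1.0 / mass)
    pos = pos + 0.01 * vel
    return (mass, pos, vel, t + 1, key), mass


def rollout(key, num_steps=200):
    init = (jnp.float32(1.0), jnp.float32(0.0), jnp.float32(0.0), jnp.int32(0), key)
    _, masses = jax.lax.scan(step, init, None, length=num_steps)
    return masses  # the mass that was in effect at every step


# One independent key per parallel environment.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
masses = jax.vmap(rollout)(keys)  # shape (4, 200)
```

Setting `RESAMPLE_EVERY = 1` in this sketch would resample on every single step.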
The summary of the implementation is simple: we just augment the simulation state to contain `sys`, thereby allowing every single parallel environment access to its own separate `sys`. Also, this enables us to resample `sys` according to some rule (for example, "resample every 50 steps").

Features:
The vsys wrapper allows for a vectorized `sys` variable that might contain different domain randomization values for each vectorized env
Domain randomization is controlled via a simple yaml file format that describes the path to a domain randomization target. Example:
This randomizes over the 7 links of the robot. For the mass, the base is "r", so the value is "read" from the default value defined in the URDF file. The min-max ranges are both relative to the base, so the current setup randomizes from [r-0.5, r+0.5]. For the damping, no base is given, which defaults to "r". One could also set the base to a float value. Another possible value for the base is "n" ("none"), which disables randomization for this index.
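The base/range rule described above can be sketched as follows. This is not the PR's parser: `sample_param` and its arguments are hypothetical names, and treating "n" as "keep the default value" is an assumption about what disabling randomization means.

```python
import jax
import jax.numpy as jnp


def sample_param(key, default, base="r", lo=-0.5, hi=0.5, num_envs=8):
    """Draws one value per environment for a single randomization target."""
    if base == "n":
        # "none": randomization disabled, every env keeps the default (assumed).
        return jnp.full((num_envs,), default)
    # "r": read the base from the model default; otherwise use the given float.
    center = default if base == "r" else float(base)
    # The min/max offsets are relative to the base.
    return jax.random.uniform(
        key, (num_envs,), minval=center + lo, maxval=center + hi
    )


key = jax.random.PRNGKey(0)
masses = sample_param(key, default=2.0)            # drawn from [1.5, 2.5]
fixed = sample_param(key, default=2.0, base="n")   # all equal to 2.0
custom = sample_param(key, default=2.0, base=1.0)  # drawn from [0.5, 1.5]
```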
Domain randomization is differentiable (!)
For example, running a simple optax optimizer, we can obtain the true domain randomization parameters in play for a specific timestep.
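As a rough illustration of what "differentiable" buys here, the sketch below recovers a hidden damping coefficient from an observed trajectory by gradient descent. It is a toy, not the PR's code: `simulate` is a hypothetical damped point mass rather than a brax pipeline, and plain gradient descent stands in for the optax optimizer to keep the example dependency-free.

```python
import jax
import jax.numpy as jnp


def simulate(damping, num_steps=50, dt=0.05):
    # Toy differentiable dynamics: a unit mass with velocity damping.
    def step(carry, _):
        pos, vel = carry
        vel = vel - dt * damping * vel
        pos = pos + dt * vel
        return (pos, vel), pos

    init = (jnp.float32(0.0), jnp.float32(1.0))
    _, traj = jax.lax.scan(step, init, None, length=num_steps)
    return traj


true_damping = 0.8                  # the "hidden" randomized value
observed = simulate(true_damping)   # trajectory we get to observe


def loss(damping):
    # Mean squared error between simulated and observed positions.
    return jnp.mean((simulate(damping) - observed) ** 2)


@jax.jit
def update(damping):
    # One gradient-descent step through the whole rollout.
    return damping - 0.3 * jax.grad(loss)(damping)


damping = jnp.float32(0.2)  # initial guess
for _ in range(500):
    damping = update(damping)
```

After the loop, `damping` should sit close to `true_damping`. In a real setup the scalar would be replaced by the randomized fields of `sys` and the update by an optax optimizer step.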
Known issues:
`sys` to be included in the state, a few of the python type hints are broken
`if __name__ == "__main__":` function. Specifically, here: https://github.com/Velythyl/brax/blob/b6cab6449ba677108e37739286e0521f7c226a9e/brax/envs/wrappers/vsys.py#L553

Again, I don't expect this to be merged as-is. But perhaps the implementation might be interesting to the community, hence the reason for this PR.