Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[FEATURE] Vsys feature: massively parallel domain randomization #458

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions brax/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from brax.envs import walker2d
from brax.envs.base import Env, PipelineEnv, State, Wrapper
from brax.envs.wrappers import training
from brax.envs.wrappers.vsys import IdentityVSysWrapper

_envs = {
'ant': ant.Ant,
Expand Down Expand Up @@ -79,6 +80,7 @@ def create(
action_repeat: int = 1,
auto_reset: bool = True,
batch_size: Optional[int] = None,
no_vsys: bool = True,
**kwargs,
) -> Env:
"""Creates an environment from the registry.
Expand All @@ -102,5 +104,7 @@ def create(
env = training.VmapWrapper(env, batch_size)
if auto_reset:
env = training.AutoResetWrapper(env)
if no_vsys:
env = IdentityVSysWrapper(env)

return env
16 changes: 8 additions & 8 deletions brax/envs/ant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint:disable=g-multiple-import
"""Trains an ant to run in the +x direction."""

from brax import base
from brax import base, System
from brax import math
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
Expand Down Expand Up @@ -202,17 +202,17 @@ def __init__(
if self._use_contact_forces:
raise NotImplementedError('use_contact_forces not implemented.')

def reset(self, rng: jax.Array) -> State:
def reset(self, sys: System, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale
q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
q = sys.init_q + jax.random.uniform(
rng1, (sys.q_size(),), minval=low, maxval=hi
)
qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))
qd = hi * jax.random.normal(rng2, (sys.qd_size(),))

pipeline_state = self.pipeline_init(q, qd)
pipeline_state = self.pipeline_init(sys, q, qd)
obs = self._get_obs(pipeline_state)

reward, done, zero = jp.zeros(3)
Expand All @@ -228,12 +228,12 @@ def reset(self, rng: jax.Array) -> State:
'y_velocity': zero,
'forward_reward': zero,
}
return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward , done, sys, metrics)

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)
pipeline_state = self.pipeline_step(state.sys, pipeline_state0, action)

velocity = (pipeline_state.x.pos[0] - pipeline_state0.x.pos[0]) / self.dt
forward_reward = velocity[0]
Expand Down
27 changes: 18 additions & 9 deletions brax/envs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import abc
from typing import Any, Dict, List, Optional, Sequence, Union

from brax import base
from jax.random import KeyArray

from brax import base, System
from brax.generalized import pipeline as g_pipeline
from brax.io import image
from brax.mjx import pipeline as m_pipeline
Expand All @@ -30,6 +32,8 @@
from mujoco import mjx
import numpy as np

def none_factory():
  """Returns ``None``; a zero-argument callable usable as a default factory."""
  return None

@struct.dataclass
class State(base.Base):
Expand All @@ -39,15 +43,18 @@ class State(base.Base):
obs: jax.Array
reward: jax.Array
done: jax.Array
sys: System
metrics: Dict[str, jax.Array] = struct.field(default_factory=dict)
info: Dict[str, Any] = struct.field(default_factory=dict)
vsys_rng: Optional[KeyArray] = None #struct.field(pytree_node=False, default_factory=none_factory)
vsys_stepcount: Optional[int] = None #struct.field(pytree_node=True, default_factory=none_factory)


class Env(abc.ABC):
"""Interface for driving training and inference."""

@abc.abstractmethod
def reset(self, rng: jax.Array) -> State:
def reset(self, sys: System, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""

@abc.abstractmethod
Expand Down Expand Up @@ -114,16 +121,18 @@ def __init__(
self._n_frames = n_frames
self._debug = debug

def pipeline_init(self, q: jax.Array, qd: jax.Array) -> base.State:
def pipeline_init(self, sys: System, q: jax.Array, qd: jax.Array) -> base.State:
"""Initializes the pipeline state."""
return self._pipeline.init(self.sys, q, qd, self._debug)
return self._pipeline.init(sys, q, qd, self._debug)

def pipeline_step(self, pipeline_state: Any, action: jax.Array) -> base.State:
def pipeline_step(
self, sys: System, pipeline_state: Any, action: jax.Array
) -> base.State:
"""Takes a physics step using the physics pipeline."""

def f(state, _):
return (
self._pipeline.step(self.sys, state, action, self._debug),
self._pipeline.step(sys, state, action, self._debug),
None,
)

Expand All @@ -137,7 +146,7 @@ def dt(self) -> jax.Array:
@property
def observation_size(self) -> int:
rng = jax.random.PRNGKey(0)
reset_state = self.unwrapped.reset(rng)
reset_state = self.unwrapped.reset(self.sys, rng)
return reset_state.obs.shape[-1]

@property
Expand Down Expand Up @@ -165,8 +174,8 @@ class Wrapper(Env):
def __init__(self, env: Env):
self.env = env

def reset(self, rng: jax.Array) -> State:
return self.env.reset(rng)
def reset(self, sys: System, rng: jax.Array) -> State:
return self.env.reset(sys, rng)

def step(self, state: State, action: jax.Array) -> State:
return self.env.step(state, action)
Expand Down
16 changes: 8 additions & 8 deletions brax/envs/half_cheetah.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint:disable=g-multiple-import
"""Trains a halfcheetah to run in the +x direction."""

from brax import base
from brax import base, System
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from etils import epath
Expand Down Expand Up @@ -153,17 +153,17 @@ def __init__(
exclude_current_positions_from_observation
)

def reset(self, rng: jax.Array) -> State:
def reset(self, sys: System, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
qpos = sys.init_q + jax.random.uniform(
rng1, (sys.q_size(),), minval=low, maxval=hi
)
qvel = hi * jax.random.normal(rng2, (self.sys.qd_size(),))
qvel = hi * jax.random.normal(rng2, (sys.qd_size(),))

pipeline_state = self.pipeline_init(qpos, qvel)
pipeline_state = self.pipeline_init(sys, qpos, qvel)

obs = self._get_obs(pipeline_state)
reward, done, zero = jp.zeros(3)
Expand All @@ -173,12 +173,12 @@ def reset(self, rng: jax.Array) -> State:
'reward_ctrl': zero,
'reward_run': zero,
}
return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward, done,sys, metrics)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return State(pipeline_state, obs, reward, done,sys, metrics)
return State(pipeline_state, obs, reward, done, sys, metrics)


def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)
pipeline_state = self.pipeline_step(state.sys, pipeline_state0, action)

x_velocity = (
pipeline_state.x.pos[0, 0] - pipeline_state0.x.pos[0, 0]
Expand Down
16 changes: 8 additions & 8 deletions brax/envs/hopper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Tuple

from brax import base
from brax import base, System
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
from etils import epath
Expand Down Expand Up @@ -195,19 +195,19 @@ def __init__(
exclude_current_positions_from_observation
)

def reset(self, rng: jax.Array) -> State:
def reset(self, sys: System, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
qpos = sys.init_q + jax.random.uniform(
rng1, (sys.q_size(),), minval=low, maxval=hi
)
qvel = jax.random.uniform(
rng2, (self.sys.qd_size(),), minval=low, maxval=hi
rng2, (sys.qd_size(),), minval=low, maxval=hi
)

pipeline_state = self.pipeline_init(qpos, qvel)
pipeline_state = self.pipeline_init(sys, qpos, qvel)

obs = self._get_obs(pipeline_state)
reward, done, zero = jp.zeros(3)
Expand All @@ -218,12 +218,12 @@ def reset(self, rng: jax.Array) -> State:
'x_position': zero,
'x_velocity': zero,
}
return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward, done,sys, metrics)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return State(pipeline_state, obs, reward, done,sys, metrics)
return State(pipeline_state, obs, reward, done, sys, metrics)


def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)
pipeline_state = self.pipeline_step(state.sys, pipeline_state0, action)

x_velocity = (
pipeline_state.x.pos[0, 0] - pipeline_state0.x.pos[0, 0]
Expand Down
38 changes: 19 additions & 19 deletions brax/envs/humanoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# pylint:disable=g-multiple-import
"""Trains a humanoid to run in the +x direction."""

from brax import actuator
from brax import actuator, System
from brax import base
from brax.envs.base import PipelineEnv, State
from brax.io import mjcf
Expand Down Expand Up @@ -224,21 +224,21 @@ def __init__(
exclude_current_positions_from_observation
)

def reset(self, rng: jax.Array) -> State:
def reset(self, sys: System, rng: jax.Array) -> State:
"""Resets the environment to an initial state."""
rng, rng1, rng2 = jax.random.split(rng, 3)

low, hi = -self._reset_noise_scale, self._reset_noise_scale
qpos = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
qpos = sys.init_q + jax.random.uniform(
rng1, (sys.q_size(),), minval=low, maxval=hi
)
qvel = jax.random.uniform(
rng2, (self.sys.qd_size(),), minval=low, maxval=hi
rng2, (sys.qd_size(),), minval=low, maxval=hi
)

pipeline_state = self.pipeline_init(qpos, qvel)
pipeline_state = self.pipeline_init(sys, qpos, qvel)

obs = self._get_obs(pipeline_state, jp.zeros(self.sys.act_size()))
obs = self._get_obs(sys, pipeline_state, jp.zeros(sys.act_size()))
reward, done, zero = jp.zeros(3)
metrics = {
'forward_reward': zero,
Expand All @@ -251,15 +251,15 @@ def reset(self, rng: jax.Array) -> State:
'x_velocity': zero,
'y_velocity': zero,
}
return State(pipeline_state, obs, reward, done, metrics)
return State(pipeline_state, obs, reward, done, sys, metrics)

def step(self, state: State, action: jax.Array) -> State:
"""Runs one timestep of the environment's dynamics."""
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)
pipeline_state = self.pipeline_step(state.sys, pipeline_state0, action)

com_before, *_ = self._com(pipeline_state0)
com_after, *_ = self._com(pipeline_state)
com_before, *_ = self._com(state.sys, pipeline_state0)
com_after, *_ = self._com(state.sys, pipeline_state)
velocity = (com_after - com_before) / self.dt
forward_reward = self._forward_reward_weight * velocity[0]

Expand All @@ -273,7 +273,7 @@ def step(self, state: State, action: jax.Array) -> State:

ctrl_cost = self._ctrl_cost_weight * jp.sum(jp.square(action))

obs = self._get_obs(pipeline_state, action)
obs = self._get_obs(state.sys, pipeline_state, action)
reward = forward_reward + healthy_reward - ctrl_cost
done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0
state.metrics.update(
Expand All @@ -293,7 +293,7 @@ def step(self, state: State, action: jax.Array) -> State:
)

def _get_obs(
self, pipeline_state: base.State, action: jax.Array
self, sys: System, pipeline_state: base.State, action: jax.Array
) -> jax.Array:
"""Observes humanoid body position, velocities, and angles."""
position = pipeline_state.q
Expand All @@ -302,7 +302,7 @@ def _get_obs(
if self._exclude_current_positions_from_observation:
position = position[2:]

com, inertia, mass_sum, x_i = self._com(pipeline_state)
com, inertia, mass_sum, x_i = self._com(sys, pipeline_state)
cinr = x_i.replace(pos=x_i.pos - com).vmap().do(inertia)
com_inertia = jp.hstack(
[cinr.i.reshape((cinr.i.shape[0], -1)), inertia.mass[:, None]]
Expand All @@ -318,7 +318,7 @@ def _get_obs(
com_velocity = jp.hstack([com_vel, com_ang])

qfrc_actuator = actuator.to_tau(
self.sys, action, pipeline_state.q, pipeline_state.qd)
sys, action, pipeline_state.q, pipeline_state.qd)

# external_contact_forces are excluded
return jp.concatenate([
Expand All @@ -329,15 +329,15 @@ def _get_obs(
qfrc_actuator,
])

def _com(self, pipeline_state: base.State) -> jax.Array:
inertia = self.sys.link.inertia
def _com(self, sys: System, pipeline_state: base.State) -> jax.Array:
inertia = sys.link.inertia
if self.backend in ['spring', 'positional']:
inertia = inertia.replace(
i=jax.vmap(jp.diag)(
jax.vmap(jp.diagonal)(inertia.i)
** (1 - self.sys.spring_inertia_scale)
** (1 - sys.spring_inertia_scale)
),
mass=inertia.mass ** (1 - self.sys.spring_mass_scale),
mass=inertia.mass ** (1 - sys.spring_mass_scale),
)
mass_sum = jp.sum(inertia.mass)
x_i = pipeline_state.x.vmap().do(inertia.transform)
Expand Down
Loading