
Commit 69637a3
Internal change
PiperOrigin-RevId: 705290306
Change-Id: I2b9d2ceb1442a653381aa1f4b80e44b947f535fe
Brax Team authored and btaba committed Dec 12, 2024
1 parent a371f9f commit 69637a3
Showing 4 changed files with 20 additions and 13 deletions.
6 changes: 3 additions & 3 deletions brax/training/agents/ars/train.py
@@ -47,7 +47,7 @@ class TrainingState:

normalizer_params: running_statistics.RunningStatisticsState
policy_params: Params
- num_env_steps: int
+ num_env_steps: jax.Array


# TODO: Pass the network as argument.
@@ -289,7 +289,7 @@ def training_epoch(
TrainingState( # type: ignore # jnp-type
normalizer_params=normalizer_params,
policy_params=policy_params,
- num_env_steps=num_env_steps,
+ num_env_steps=jnp.array(num_env_steps, dtype=jnp.int64),
),
metrics,
)
@@ -323,7 +323,7 @@ def training_epoch_with_timing(
training_state = TrainingState(
normalizer_params=normalizer_params,
policy_params=policy_params,
- num_env_steps=0,
+ num_env_steps=jnp.array(0, dtype=jnp.int64),
)

if not eval_env:
8 changes: 4 additions & 4 deletions brax/training/agents/es/train.py
@@ -49,7 +49,7 @@ class TrainingState:
normalizer_params: running_statistics.RunningStatisticsState
optimizer_state: optax.OptState
policy_params: Params
- num_env_steps: int
+ num_env_steps: jax.Array


# Centered rank from: https://arxiv.org/pdf/1703.03864.pdf
@@ -336,7 +336,7 @@ def training_epoch(

num_env_steps = (
training_state.num_env_steps
- + jnp.sum(obs_weights, dtype=jnp.int32) * action_repeat
+ + jnp.sum(obs_weights, dtype=jnp.int64) * action_repeat
)

metrics = {
@@ -350,7 +350,7 @@ def training_epoch(
normalizer_params=normalizer_params,
optimizer_state=optimizer_state,
policy_params=policy_params,
- num_env_steps=num_env_steps,
+ num_env_steps=jnp.array(num_env_steps, dtype=jnp.int64),
),
metrics,
)
@@ -386,7 +386,7 @@ def training_epoch_with_timing(
normalizer_params=normalizer_params,
optimizer_state=optimizer_state,
policy_params=policy_params,
- num_env_steps=0,
+ num_env_steps=jnp.array(0, dtype=jnp.int64),
)

if not eval_env:
6 changes: 4 additions & 2 deletions brax/training/agents/ppo/train.py
@@ -463,7 +463,9 @@ def f(carry, unused_t):
optimizer_state=optimizer_state,
params=params,
normalizer_params=normalizer_params,
- env_steps=training_state.env_steps + env_step_per_training_step,
+ env_steps=jnp.array(
+     training_state.env_steps + env_step_per_training_step,
+     dtype=jnp.int64),
)
return (new_training_state, state, new_key), metrics

@@ -523,7 +525,7 @@ def training_epoch_with_timing(
normalizer_params=running_statistics.init_state(
_remove_pixels(obs_shape)
),
- env_steps=0,
+ env_steps=jnp.array(0, dtype=jnp.int64),
)

if (
13 changes: 9 additions & 4 deletions brax/training/agents/sac/train.py
@@ -101,7 +101,7 @@ def _init_training_state(
q_params=q_params,
target_q_params=q_params,
gradient_steps=jnp.zeros(()),
- env_steps=jnp.zeros(()),
+ env_steps=jnp.zeros((), dtype=jnp.int64),
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=log_alpha,
normalizer_params=normalizer_params,
@@ -314,7 +314,7 @@ def sgd_step(
q_params=q_params,
target_q_params=new_target_q_params,
gradient_steps=training_state.gradient_steps + 1,
- env_steps=training_state.env_steps,
+ env_steps=jnp.array(training_state.env_steps, dtype=jnp.int64),
alpha_optimizer_state=alpha_optimizer_state,
alpha_params=alpha_params,
normalizer_params=training_state.normalizer_params,
@@ -367,7 +367,9 @@ def training_step(
)
training_state = training_state.replace(
normalizer_params=normalizer_params,
- env_steps=training_state.env_steps + env_steps_per_actor_step,
+ env_steps=jnp.array(
+     training_state.env_steps + env_steps_per_actor_step, dtype=jnp.int64
+ ),
)

buffer_state, transitions = replay_buffer.sample(buffer_state)
@@ -404,7 +406,10 @@ def f(carry, unused):
)
new_training_state = training_state.replace(
normalizer_params=new_normalizer_params,
- env_steps=training_state.env_steps + env_steps_per_actor_step,
+ env_steps=jnp.array(
+     training_state.env_steps + env_steps_per_actor_step,
+     dtype=jnp.int64,
+ ),
)
return (new_training_state, env_state, buffer_state, new_key), ()

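The common thread across the hunks above is that environment-step counters are now carried as explicit int64 JAX arrays rather than Python ints or default-precision arrays. Below is a minimal sketch of the idea, not taken from this commit: it assumes 64-bit types are enabled via jax_enable_x64 (otherwise JAX falls back to int32), and the step sizes are made up for illustration.

import jax
import jax.numpy as jnp

# JAX uses 32-bit types by default; without this flag, a request for
# jnp.int64 is silently downgraded to int32.
jax.config.update("jax_enable_x64", True)

env_steps = jnp.array(0, dtype=jnp.int64)  # step counter held as a JAX array
steps_per_iteration = 2048 * 1000          # hypothetical env steps per training epoch

# An int32 counter wraps around after 2**31 - 1 (~2.1e9) steps, which long
# training runs can exceed; an int64 counter does not.
for _ in range(3):
  env_steps = env_steps + steps_per_iteration

print(env_steps, env_steps.dtype)  # 6144000 int64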
