Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574170487
  • Loading branch information
rchen152 authored and learned_optimization authors committed Oct 17, 2023
1 parent 463ab9a commit 1f9823d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion learned_optimization/outer_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def metrics_and_info_from_gradients(
max_stale = current_step - onp.min(steps)
metrics["max_staleness"] = max_stale

return metrics, worker_ids, applied_inner_steps
return metrics, worker_ids, applied_inner_steps # pytype: disable=bad-return-type


def maybe_resample_gradient_estimators(
Expand Down
2 changes: 1 addition & 1 deletion learned_optimization/outer_trainers/full_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def single_vec_batch(theta, state, key_data):

es_grad = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), vec_es_grad)

return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad
return jnp.mean((pos_loss + neg_loss) / 2.0), es_grad # pytype: disable=bad-return-type


@gin.configurable
Expand Down

0 comments on commit 1f9823d

Please sign in to comment.