Skip to content

Commit

Permalink
Ignore incorrect annotations related to chex.PRNGKey
Browse files Browse the repository at this point in the history
Currently chex.PRNGKey is effectively treated as Any, and making this more strict reveals a number of incorrect annotations in existing code.

PiperOrigin-RevId: 565164491
  • Loading branch information
Jake VanderPlas authored and learned_optimization authors committed Sep 20, 2023
1 parent fcded5d commit 9540382
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions learned_optimization/optimizers/optax_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,11 +384,13 @@ def __init__(self,

# SM3 doesn't support scalars, so we have to reshape the params and grads.

def init(self,
params: Any,
model_state: Optional[Any] = None,
num_steps: Optional[int] = None,
key: chex.PRNGKey = None) -> SM3OptState:
def init( # type: ignore
self,
params: Any,
model_state: Optional[Any] = None,
num_steps: Optional[int] = None,
key: chex.PRNGKey = None,
) -> SM3OptState:
should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params) # pylint: disable=g-explicit-length-test
params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
out = super().init(params, model_state, num_steps, key)
Expand Down

0 comments on commit 9540382

Please sign in to comment.