From 9540382d80d96e7e730c4d2ced2a4ed5034d5204 Mon Sep 17 00:00:00 2001
From: Jake VanderPlas
Date: Wed, 13 Sep 2023 14:37:15 -0700
Subject: [PATCH] Ignore incorrect annotations related to chex.PRNGKey

Currently chex.PRNGKey is effectively treated as Any, and making this
more strict reveals a number of incorrect annotations in existing code.

PiperOrigin-RevId: 565164491
---
 learned_optimization/optimizers/optax_opts.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/learned_optimization/optimizers/optax_opts.py b/learned_optimization/optimizers/optax_opts.py
index e8c1edd..9765554 100644
--- a/learned_optimization/optimizers/optax_opts.py
+++ b/learned_optimization/optimizers/optax_opts.py
@@ -384,11 +384,13 @@ def __init__(self,
 
   # SM3 doesn't support scalars, so we have to reshape the params and grads.
-  def init(self,
-           params: Any,
-           model_state: Optional[Any] = None,
-           num_steps: Optional[int] = None,
-           key: chex.PRNGKey = None) -> SM3OptState:
+  def init(  # type: ignore
+      self,
+      params: Any,
+      model_state: Optional[Any] = None,
+      num_steps: Optional[int] = None,
+      key: chex.PRNGKey = None,
+  ) -> SM3OptState:
     should_reshape = jax.tree_util.tree_map(lambda x: len(x.shape) == 0, params)  # pylint: disable=g-explicit-length-test
     params = jax.tree_util.tree_map(_expand_scalar, params, should_reshape)
     out = super().init(params, model_state, num_steps, key)
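
For context, a minimal sketch (not part of the patch) of the kind of annotation this change is working around: once chex.PRNGKey is checked strictly rather than treated as Any, a parameter annotated `chex.PRNGKey` cannot legally default to `None`, so the signature above needs either a `# type: ignore` (as the patch does) or an `Optional[chex.PRNGKey]` annotation. The function names and fallback seed below are hypothetical, chosen only to illustrate the two options.

    from typing import Optional

    import chex
    import jax

    # Incorrect under a strict chex.PRNGKey: the default None is not a
    # PRNGKey, so a strict checker flags this signature. The patch keeps
    # this shape and silences the checker with `# type: ignore` to avoid
    # changing runtime behavior.
    def sample_bad(key: chex.PRNGKey = None):  # type: ignore[assignment]
      ...

    # Idiomatic alternative: declare the parameter Optional and handle
    # the None case explicitly.
    def sample_good(key: Optional[chex.PRNGKey] = None) -> jax.Array:
      if key is None:
        key = jax.random.PRNGKey(0)  # hypothetical fallback seed
      return jax.random.normal(key, shape=())

    print(sample_good())                        # uses the fallback key
    print(sample_good(jax.random.PRNGKey(42)))  # caller-supplied key

The `Optional` form is the more precise fix, but it changes the declared API; `# type: ignore` lets the existing override-incompatible signature stand while stricter PRNGKey checking is rolled out.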