Hello!

When it comes to annotations, optax currently relies heavily on `optax.Updates` and `optax.Params`, which are both aliases for `chex.ArrayTree`.
This makes sense, but for folks who run type checkers it means that a lot of type errors occur when working with pytrees that aren't strictly the nested `Iterable` or `Mapping` types specified in `chex`. For example:
```python
from typing import Tuple

import optax
from jax import numpy as jnp
import flax.struct


@flax.struct.dataclass
class Params:
    weights: jnp.ndarray
    bias: jnp.ndarray


def make_optimizer(
    params: Params,
) -> Tuple[optax.GradientTransformation, optax.OptState]:
    """Make an optimizer."""
    optimizer = optax.sgd(learning_rate=1e-3)
    state = optimizer.init(params)  # Type error.
    return optimizer, state
```
A few questions from this:
1. Is this considered a bug, or something that the optax team would be open to supporting? Are there better solutions for suppressing this error than simply adding a `# type: ignore`?
2. It seems like type safety with optax could benefit immensely from support for generics, which have been available since Python 3.5 (`typing.Generic`, `typing.TypeVar`). Any chance this would be something that optax would be open to supporting?
   - Simple example: with `ArrayTreeT = TypeVar("ArrayTreeT", bound=chex.ArrayTree)`, `optax.apply_updates()` could be annotated as `optax.apply_updates(params: ArrayTreeT, updates: ArrayTreeT) -> ArrayTreeT` to indicate that the argument and return types should all be the same.
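A minimal sketch of what a `TypeVar`-based signature gives a caller, under several assumptions: this is not optax's implementation, `apply_updates` here is a hypothetical stand-in, `_tree_add` is a toy pure-Python replacement for `jax.tree_util.tree_map` (so the sketch runs without JAX), the `TypeVar` is left unbound rather than bound to `chex.ArrayTree`, and `Params` uses plain floats instead of `jnp.ndarray` fields.

```python
import dataclasses
from typing import Any, TypeVar

# Hypothetical TypeVar for illustration. The proposal above bounds it by
# chex.ArrayTree; the bound is omitted here so the sketch has no third-party
# dependencies.
ArrayTreeT = TypeVar("ArrayTreeT")


def _tree_add(p: Any, u: Any) -> Any:
    """Toy stand-in for jax.tree_util.tree_map(lambda p, u: p + u, p, u)."""
    if dataclasses.is_dataclass(p) and not isinstance(p, type):
        # Rebuild the dataclass with each field updated recursively.
        return dataclasses.replace(
            p,
            **{
                f.name: _tree_add(getattr(p, f.name), getattr(u, f.name))
                for f in dataclasses.fields(p)
            },
        )
    if isinstance(p, dict):
        return {k: _tree_add(v, u[k]) for k, v in p.items()}
    if isinstance(p, (list, tuple)):
        return type(p)(_tree_add(a, b) for a, b in zip(p, u))
    return p + u  # Leaf: add the update to the parameter.


def apply_updates(params: ArrayTreeT, updates: ArrayTreeT) -> ArrayTreeT:
    """Hypothetical generic signature: the output type matches the inputs."""
    return _tree_add(params, updates)


@dataclasses.dataclass(frozen=True)
class Params:
    # Hypothetical stand-in for the flax.struct.dataclass in the issue.
    weights: float
    bias: float


new_params = apply_updates(Params(1.0, 2.0), Params(0.5, -1.0))
print(new_params)  # Params(weights=1.5, bias=1.0)
```

Because `params`, `updates`, and the return value share one type variable, a checker such as mypy can infer that `new_params` is a `Params` rather than widening it to a generic array tree.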
Thanks a lot for pointing this out! This is definitely something we should discuss, especially if it would be convenient for flax to have these types supported.
I think we should stick to chex types as the standard, since that makes it easier to ensure safe interoperability with other JAX libraries that use chex (e.g. this bug isn't directly related to typing, but it shows the kinds of bugs that can arise from differences in how libraries treat more complicated pytrees). If possible, we should avoid defining our own versions of common types in optax.
@hbq1: has there been a discussion in chex on extending `ArrayTree` to include some (common) dataclass implementations?