assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}" doesn't work with traced x.
assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
Traceback (most recent call last):
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 782, in __bool__
    return self.aval._bool(self)
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 1538, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function fn at /Users/dongseong/Workspaces/axlearn/axlearn/common/base_layer.py:329 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument kwargs['inputs'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
So check_numerics doesn't work inside jit, pmap, and scan.
def check_numerics(x: Tensor, msg_fmt: str = "", **msg_kwargs):
    """Checks that all elements in `x` are finite."""
    global _enable_numeric_checks  # pylint: disable=global-statement,global-variable-not-assigned
    if _enable_numeric_checks:
        assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
    return x
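For reference, the failure can be reproduced outside axlearn with a few lines. This is a minimal sketch, not the axlearn code: the simplified `check_numerics` below drops the `_enable_numeric_checks` flag and the message formatting, and the input array is made up for illustration.

```python
import jax
import jax.numpy as jnp


def check_numerics(x):
    # Same assertion as in the function above, minus the flag and message formatting.
    assert bool(jnp.isfinite(x).all()), f"Check numerics: {x}"
    return x


x = jnp.array([1.0, 2.0])

# Eager call: x is a concrete array, so bool() works and the assert passes.
check_numerics(x)

# Under jit, x is a tracer, so bool() raises before the assert can be evaluated.
try:
    jax.jit(check_numerics)(x)
    raised = False
except jax.errors.TracerBoolConversionError:
    raised = True

print("raised under jit:", raised)
```

The same conversion error is expected under any transformation that traces the function, since the assertion needs a concrete Python bool.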
There is jax checkify (jax.experimental.checkify), but it requires the entry point to be wrapped, e.g. checkify.checkify(main), so it's not trivial to use in axlearn. See https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html
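To make the checkify option concrete, here is a rough sketch of what a checkify-based check could look like. It is not a proposed axlearn patch: `fn`, the `msg` argument, and the inputs are hypothetical, and it only uses `checkify.check` and `checkify.checkify` from `jax.experimental.checkify`. It also shows why adoption is not trivial: the caller has to wrap the function and handle the returned error value.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_numerics(x, msg="check_numerics"):
    # checkify.check records a functional error instead of calling bool() on a tracer.
    # Note: the message is treated as a format string, so avoid stray braces in msg.
    checkify.check(jnp.isfinite(x).all(), msg + ": non-finite value found")
    return x


def fn(x):
    # Hypothetical computation that can produce NaN for negative inputs.
    return check_numerics(jnp.log(x))


# The wrapping step: the checked function returns (error, output) instead of output.
checked_fn = jax.jit(checkify.checkify(fn))

err, out = checked_fn(jnp.array([1.0, 2.0]))
ok_msg = err.get()  # None when no check failed.

err, out = checked_fn(jnp.array([1.0, -1.0]))  # log(-1.0) is NaN.
bad_msg = err.get()  # Error message string when a check failed.

print(ok_msg)
print(bad_msg)
```

Calling `err.throw()` instead of `err.get()` raises the recorded error on the host, which is closer to the behavior of the original assert. Every jit boundary that contains a check would need this wrapping, which is the integration cost mentioned above.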