check_numerics doesn't work inside repeat.py #815

Open
ds-hwang opened this issue Nov 5, 2024 · 0 comments

Comments

Contributor

ds-hwang commented Nov 5, 2024

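The assertion in check_numerics, assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}", doesn't work when x is traced, because a tracer has no concrete value to convert to a Python bool: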

assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
Traceback (most recent call last):
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 782, in __bool__
    return self.aval._bool(self)
  File "/Users/dongseong/miniforge3/envs/ajax/lib/python3.10/site-packages/jax/_src/core.py", line 1538, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function fn at /Users/dongseong/Workspaces/axlearn/axlearn/common/base_layer.py:329 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument kwargs['inputs'].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

So check_numerics doesn't work inside any JAX transformation that traces x, such as jit, pmap, scan, or checkpoint (the case in the traceback above).

def check_numerics(x: Tensor, msg_fmt: str = "", **msg_kwargs):
    """Checks that all elements in `x` are finite."""
    global _enable_numeric_checks  # pylint: disable=global-statement,global-variable-not-assigned
    if _enable_numeric_checks:
        assert bool(jnp.isfinite(x).all()), f"Check numerics {msg_fmt.format(**msg_kwargs)}: {x}"
    return x
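
A minimal reproduction may help here. The sketch below uses check_finite, a hypothetical simplified stand-in for check_numerics (no message formatting, no enable flag), and shows that the same assertion passes eagerly but raises once the function is traced by jit:

```python
import jax
import jax.numpy as jnp


def check_finite(x):
    # Hypothetical, simplified stand-in for check_numerics.
    assert bool(jnp.isfinite(x).all()), f"Check numerics: {x}"
    return x


# Eager call: x is a concrete array, so bool() succeeds.
eager_result = check_finite(jnp.ones(3))

# Under jit, x is a tracer, so bool() cannot produce a Python value.
try:
    jax.jit(check_finite)(jnp.ones(3))
    traced_error = None
except jax.errors.TracerBoolConversionError as e:
    traced_error = e

print(type(traced_error).__name__)
```

The same failure should appear under scan, pmap, or checkpoint, since each of them traces the function it wraps.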

JAX has checkify, but it requires the top-level function to be wrapped, as in checkify.checkify(main). That makes it non-trivial to use in axlearn.
https://jax.readthedocs.io/en/latest/debugging/checkify_guide.html
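
For reference, here is a rough sketch of what the checkify route looks like, to show why the wrapping is the awkward part. The names check_numerics_traced, body, and main are hypothetical and not from axlearn; only checkify.check, checkify.checkify, and err.get() are taken from the JAX checkify API:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def check_numerics_traced(x, msg="non-finite value"):
    # Hypothetical traced-safe variant: records a failure in the
    # checkify error state instead of using a Python assert.
    checkify.check(jnp.isfinite(x).all(), msg)
    return x


def body(carry, x):
    y = check_numerics_traced(jnp.log(x), "log produced non-finite value")
    return carry + y, y


def main(xs):
    init = jnp.zeros((), dtype=xs.dtype)
    return jax.lax.scan(body, init, xs)


# The entry point has to be wrapped; the check inside scan is then
# functionalized into an extra error output.
checked_main = jax.jit(checkify.checkify(main))

err, _ = checked_main(jnp.array([1.0, 2.0]))
ok_msg = err.get()   # None when no check failed

err, _ = checked_main(jnp.array([1.0, 0.0]))
bad_msg = err.get()  # message string, since log(0) is -inf
```

Every caller then has to handle the returned err (for example with err.throw()), which is the part that would need plumbing through axlearn's entry points.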
