Bounded while loop #923

Open
lockwo opened this issue Jan 4, 2025 · 2 comments · May be fixed by patrick-kidger/diffrax#559


@lockwo
Contributor

lockwo commented Jan 4, 2025

While working on some DirectAdjoint modifications, I kept hitting empty assertion errors (a bare `AssertionError` with no message) originating from bounded while loops:

```
test/test_adjoint.py:333: in _run_inexact
    return _run(eqx.combine(inexact, static), saveat, adjoint)
test/test_adjoint.py:291: in _run
    ys = diffrax.diffeqsolve(
diffrax/_integrate.py:1462: in diffeqsolve
    final_state, aux_stats = adjoint.loop(
diffrax/_adjoint.py:405: in loop
    final_state = self._loop(
diffrax/_integrate.py:641: in loop
    final_state = outer_while_loop(
../../miniforge3/envs/dev_diffrax/lib/python3.10/contextlib.py:79: in inner
    return func(*args, **kwds)
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/loop.py:119: in while_loop
    return bounded_while_loop(
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/bounded.py:59: in bounded_while_loop
    _, _, _, val = _while_loop(cond_fun_, body_fun_, init_val_, rounded_max_steps, base)
../../miniforge3/envs/dev_diffrax/lib/python3.10/site-packages/equinox/internal/_loop/bounded.py:78: in _while_loop
    return lax.scan(scan_fn, val, xs=None, length=base)[0]
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   return lax.cond(cond_fun(val), call, lambda x: x, val), None
E   jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError
```

Have you seen this before? I'll work on getting an MVC together in the meantime.
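For context, the last two frames of the traceback suggest how the bounded while loop is assembled: a fixed-length `lax.scan` whose body wraps a (nested) sub-loop in `lax.cond`. Below is a simplified sketch under that assumption. It is not Equinox's actual `bounded_while_loop`, and the names `bounded_while_loop_sketch` and `_nested_loop` are hypothetical. The sketch rounds `max_steps` up to a power of `base` (guessing from the `rounded_max_steps` and `base` arguments in the traceback), nests scans to that depth, and keeps a step counter in the carry so that at most `max_steps` iterations actually execute. Because it only uses `scan` and `cond`, it supports reverse-mode differentiation, which `lax.while_loop` does not.

```python
import jax
import jax.numpy as jnp
from jax import lax


def _nested_loop(cond_fun, body_fun, carry, max_steps, base):
    """Hypothetical helper: up to `max_steps` guarded iterations as nested scans."""
    if max_steps == 1:
        # Leaf: a single iteration, skipped once cond_fun is False.
        return lax.cond(cond_fun(carry), body_fun, lambda c: c, carry)

    def call(c):
        # One outer step runs a whole sub-loop of max_steps // base iterations.
        return _nested_loop(cond_fun, body_fun, c, max_steps // base, base)

    def scan_fn(c, _):
        # Same shape as the failing line: a cond wrapped around the sub-loop.
        return lax.cond(cond_fun(c), call, lambda x: x, c), None

    return lax.scan(scan_fn, carry, xs=None, length=base)[0]


def bounded_while_loop_sketch(cond_fun, body_fun, init_val, max_steps, base=16):
    """Hypothetical sketch of a reverse-mode-differentiable bounded while loop."""
    # Round max_steps up to a power of `base` so the scans nest evenly.
    rounded = 1
    while rounded < max_steps:
        rounded *= base

    # A step counter in the carry stops the loop after exactly max_steps
    # iterations, even though `rounded` may be larger.
    def cond_(carry):
        step, val = carry
        return jnp.logical_and(step < max_steps, cond_fun(val))

    def body_(carry):
        step, val = carry
        return step + 1, body_fun(val)

    init = (jnp.asarray(0, dtype=jnp.int32), init_val)
    _, val = _nested_loop(cond_, body_, init, rounded, base)
    return val
```

Every nesting level goes through `lax.cond`, which fits the failing frame being the `lax.cond(...)` line inside `scan_fn`.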

@lockwo
Contributor Author

lockwo commented Jan 4, 2025

This may belong in diffrax, but since the core code is Equinox's bounded while loop (and DirectAdjoint is a pretty thin layer over it), I filed it here.

lockwo linked a pull request on Jan 5, 2025 that will close this issue
@patrick-kidger
Owner

Hmm, nope, this one isn't familiar. Since the traceback points at a totally innocuous-looking line, it might be coming from JAX internals: probably an `assert` statement inside one of the `cond_p` rules? I'll take a look at the MWE once you have it. :)
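One way to see where the bare `AssertionError` is actually raised, assuming it does come from a JAX-internal rule: JAX filters its own frames out of tracebacks by default, and setting the `jax_traceback_filtering` option to `"off"` keeps them, so the frame containing the failing `assert` becomes visible. The sketch below also prints the jaxpr of the reverse-mode derivative of a small scan-of-cond (the same pattern as the failing line) to show which `scan` and `cond` primitives the transformation produces. `scan_of_cond` is a hypothetical stand-in and does not reproduce the error.

```python
import jax
from jax import lax

# Keep JAX-internal frames in tracebacks (they are filtered out by default),
# so an AssertionError raised inside a primitive rule shows the rule itself.
jax.config.update("jax_traceback_filtering", "off")


def scan_of_cond(x, n_steps=4):
    # Hypothetical stand-in: a lax.cond inside a lax.scan body, as in the traceback.
    def scan_fn(v, _):
        return lax.cond(v < 10.0, lambda u: 2.0 * u, lambda u: u, v), None

    return lax.scan(scan_fn, x, xs=None, length=n_steps)[0]


# Print the jaxpr of the reverse-mode derivative to see the transformed
# scan/cond primitives that a failing rule would be operating on.
grad_jaxpr = jax.make_jaxpr(jax.grad(scan_of_cond))(1.0)
print(grad_jaxpr)
```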
