
Can't backprop over vmapped diffeqsolve (NotImplementedError: Differentiation rule for 'reduce_or' not implemented) #568

Open
LuggiStruggi opened this issue Jan 14, 2025 · 2 comments


LuggiStruggi commented Jan 14, 2025

I can backpropagate through diffeqsolve without any issue; however, when I vmap over a function that includes diffeqsolve, I get the following error:

Traceback (most recent call last):
  File "/home/luggistruggi/Documents/work/test_issue.py", line 45, in <module>
    losses = batched_loss_fn(batch_params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 36, in batched_loss_fn
    return jax.vmap(single_loss_fn)(params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 20, in single_loss_fn
    sol = diffrax.diffeqsolve(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 1401, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_adjoint.py", line 294, in loop
    final_state = self._loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 640, in loop
    event_happened = jnp.any(jnp.stack(flat_mask))
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/jax/_src/numpy/reductions.py", line 681, in any
    return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Differentiation rule for 'reduce_or' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
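
For context, the jnp.any call in the last diffrax frame above is what lowers to the reduce_or primitive named in the error. A minimal illustration of that lowering, independent of diffrax:

import jax
import jax.numpy as jnp

# jnp.any on a boolean array lowers to the reduce_or primitive,
# which is the op the error says has no differentiation rule
print(jax.make_jaxpr(jnp.any)(jnp.array([True, False])))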

Here is the code which produced the error:

import jax
import jax.numpy as jnp
import diffrax
import optimistix as optx

def dynamics(t, y, args):
    # simple linear ODE: y relaxes towards the (scalar) parameter
    param = args
    return param - y


def event_fn(t, y, args, **kwargs):
    # event condition: triggers when y crosses 1.5
    return y - 1.5

def single_loss_fn(param):
    solver = diffrax.Euler()
    # terminating event with a root finder to locate the crossing time
    root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
    event = diffrax.Event(event_fn, root_finder)
    term = diffrax.ODETerm(dynamics)

    sol = diffrax.diffeqsolve(
        term,
        solver=solver,
        t0=0.0,
        t1=2.0,
        dt0=0.1,
        y0=0.0,
        args=param,
        event=event,
        max_steps=1000,
    )

    final_y = sol.ys[-1]
    return param**2 + final_y**2

def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
    # vmap over a function containing diffeqsolve: this is the call that fails
    return jax.vmap(single_loss_fn)(params)

def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
    # sum the batched losses and differentiate w.r.t. the batch of parameters
    return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)


if __name__ == "__main__":
    batch_params = jnp.array([1.0, 2.0, 3.0])

    losses = batched_loss_fn(batch_params)
    print("batched_loss_fn =", losses)

    grads = grad_fn(batch_params)
    print("grad_fn =", grads)

Any suggestions on how to avoid this or implement this myself? Thank you so much :)

@LuggiStruggi (Author)

It seems to work with these changes: #569

@patrick-kidger (Owner)

I reckon the chances are good that this is due to this upstream JAX bug:

jax-ml/jax#25724

Thanks for identifying a fix; it's super useful that we can work around this.
