I can backpropagate through `diffeqsolve` without any issue. However, when I use `vmap` over a function that includes `diffeqsolve`, I get the following error:
```
Traceback (most recent call last):
  File "/home/luggistruggi/Documents/work/test_issue.py", line 45, in <module>
    losses = batched_loss_fn(batch_params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 36, in batched_loss_fn
    return jax.vmap(single_loss_fn)(params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 20, in single_loss_fn
    sol = diffrax.diffeqsolve(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 1401, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_adjoint.py", line 294, in loop
    final_state = self._loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 640, in loop
    event_happened = jnp.any(jnp.stack(flat_mask))
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/jax/_src/numpy/reductions.py", line 681, in any
    return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Differentiation rule for 'reduce_or' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
```
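For readers reproducing the setup, here is a minimal sketch of the two call patterns contrasted above, reusing the function names visible in the traceback (`single_loss_fn`, `batched_loss_fn`, `batch_params`). It is an assumption-laden toy, not the original code: `diffrax.diffeqsolve` is replaced by a hypothetical fixed-step Euler integrator `euler_solve` for dy/dt = -param * y, so it runs with JAX alone and does not reproduce the `reduce_or` error.

```python
import jax
import jax.numpy as jnp


def euler_solve(param, t1=1.0, num_steps=100):
    """Hypothetical stand-in for diffrax.diffeqsolve.

    Integrates dy/dt = -param * y from y(0) = 1 to t1 with fixed-step Euler.
    """
    dt = t1 / num_steps
    y0 = jnp.ones_like(param)  # same dtype as param, so the loop carry type is stable

    def step(_, y):
        return y + dt * (-param * y)

    # Concrete integer bounds let fori_loop lower to a scan, which supports reverse-mode AD.
    return jax.lax.fori_loop(0, num_steps, step, y0)


def single_loss_fn(param):
    # Squared final state of the (stand-in) ODE solve.
    return euler_solve(param) ** 2


def batched_loss_fn(params):
    # Same structure as in the traceback: vmap over the per-sample loss.
    return jax.vmap(single_loss_fn)(params)


# Pattern 1: backpropagating through a single solve.
grad_single = jax.grad(single_loss_fn)(0.5)

# Pattern 2: vmapping the per-sample loss over a batch of parameters.
batch_params = jnp.array([0.1, 0.5, 1.0])
losses = batched_loss_fn(batch_params)
```

In plain JAX the two transformations compose freely, so gradients through the batched loss (for example `jax.grad(lambda p: batched_loss_fn(p).mean())`) also work for this toy. That suggests the failure comes from something inside `diffeqsolve` (the traceback points at `jnp.any` in its loop) rather than from the `vmap`/`grad` structure itself.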
Here is the code that produced the error:
Any suggestions on how to avoid this or implement this myself? Thank you so much :)