JaxStackTraceBeforeTransformation error with parametrized ODE #513
This is a known issue that arose in JAX 0.4.34. The tangent types of integers in custom autodiff were changed from matching the primal to instead being a
I've updated Equinox to be compatible in patrick-kidger/equinox#871, and I'll do a new release soon. In the meantime, you can either install Equinox directly from HEAD or downgrade to JAX 0.4.33. I hope that helps! :)
(I can see that you said you already tried downgrading. I have just double-checked, and Equinox v0.11.7 + JAX 0.4.33 works for me, so I think something else has probably gone wrong for you there. :) )
Don't know what I'm doing wrong here. I just tried Equinox 0.11.7 and JAX 0.4.33 and still get the same issue. Maybe the new release will help... Fortunately it's not urgent :)
Just want to note that I get a similar problem with jax==0.4.34, jaxlib==0.4.34, diffrax==0.6.0, equinox==0.11.8 on Python 3.13.
The issue above is fixed with optimistix 0.0.9 :)
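The version boundaries mentioned in the comments can be turned into a quick local check. This is a rough heuristic sketch, not an official compatibility table: it assumes only what this thread reports (the change arrived in JAX 0.4.34, and the follow-up error was said to be fixed by optimistix 0.0.9). The helper names `parse_version` and `matches_reported_failure` are hypothetical, and the check ignores Equinox because the thread does not name the Equinox release that shipped patrick-kidger/equinox#871.

```python
import re

# Boundaries as reported in this thread only (assumption, not an official table):
# the integer-tangent change arrived in JAX 0.4.34, and the follow-up error seen with
# diffrax 0.6.0 / equinox 0.11.8 was reported fixed by optimistix 0.0.9.
JAX_CHANGE = (0, 4, 34)
OPTIMISTIX_FIX = (0, 0, 9)


def parse_version(version: str) -> tuple:
    """Parse the leading numeric release segment, e.g. '0.4.34.dev1' -> (0, 4, 34)."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at a purely non-numeric segment such as 'dev1'
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop after a mixed segment such as '9rc1'
    return tuple(parts)


def matches_reported_failure(jax_version: str, optimistix_version: str) -> bool:
    """Hypothetical heuristic: True if the pair falls in the range reported as broken.

    Equinox is deliberately not checked, since the thread gives no release number
    for the Equinox fix.
    """
    if parse_version(jax_version) < JAX_CHANGE:
        return False  # predates the JAX 0.4.34 change discussed above
    return parse_version(optimistix_version) < OPTIMISTIX_FIX


print(matches_reported_failure("0.4.34", "0.0.8"))  # True
print(matches_reported_failure("0.4.33", "0.0.8"))  # False
print(matches_reported_failure("0.4.34", "0.0.9"))  # False
```

A `False` result only means the pair is outside the range reported here; it does not guarantee a working setup, since the original report used diffrax 0.4.1, which may not involve Optimistix at all.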
Hi,
Based on this tutorial, I tried to get started with JAX and neural ODEs: https://colab.research.google.com/drive/1ZlK36VgWy1vBjBNXjSUg6Cb-7zeoa3jh
However, I get a JaxStackTraceBeforeTransformation error (detailed error message below). I boiled the code down to a minimal working example (also provided below) and noticed that the error only occurs when the equation in test_func contains an argument.
Since this issue seemed similar to one raised in an earlier post (jax-ml/jax#13629), I tried downgrading JAX to version 0.4.23. I also tried setting up a fresh Python environment with only the necessary packages installed. Nothing has helped so far. I'd appreciate your help :)
(Even though it's labeled as a JaxStack... error, @dfm pointed out that it might actually be a problem with diffrax: "The error reported here is actually a TypeError being raised because of an issue with the return types in a jax.custom_jvp. It's hard to see from this error report exactly which custom_jvp is the culprit, but it seems like it must be something within diffrax or equinox, so I'd recommend opening the issue on the https://github.com/patrick-kidger/diffrax issue tracker." jax-ml/jax#24253)
Working example:
Error message:
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.34
jaxlib: 0.4.34
numpy: 1.26.4
python: 3.11.1 (tags/v3.11.1:a7a450f, Dec 6 2022, 19:58:39) [MSC v.1934 64 bit (AMD64)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Windows', release='10', version='10.0.19044', machine='AMD64')
jupyterlab: 4.2.2
diffrax: 0.4.1
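The system info above covers jax, jaxlib, and diffrax but not Equinox or Optimistix, which the comments identify as relevant. As a supplement to JAX's own `jax.print_environment_info()`, here is a small stdlib-only sketch that reports the installed versions of all of these packages; the helper name `package_versions` and the package list are illustrative assumptions.

```python
from importlib import metadata

# Distributions that come up in this thread (illustrative list; extend as needed).
PACKAGES = ("jax", "jaxlib", "equinox", "diffrax", "optimistix")


def package_versions(names=PACKAGES):
    """Map each distribution name to its installed version string, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Pasting this output alongside the system info makes it easier to tell which of the combinations discussed in the comments a given environment matches.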