[Question] Adjoint of a Symplectic solver #541
I think right now Diffrax doesn't support reverse integration through symplectic solves. This kind of behaviour may improve in the future though... #528 :)
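For context, here is a minimal sketch of the setup being discussed, assuming Diffrax's `SemiImplicitEuler` solver (which takes a pair of terms) together with `BacksolveAdjoint`; per the comment above, differentiating this combination currently errors rather than integrating backwards:

```python
import diffrax
import jax
import jax.numpy as jnp

# Harmonic oscillator dq/dt = p, dp/dt = -q, using Diffrax's symplectic
# semi-implicit Euler solver, which takes a 2-tuple of terms.
f = diffrax.ODETerm(lambda t, p, args: p)
g = diffrax.ODETerm(lambda t, q, args: -q)

def loss(q0):
    sol = diffrax.diffeqsolve(
        (f, g),
        diffrax.SemiImplicitEuler(),
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=(q0, jnp.array(0.0)),
        adjoint=diffrax.BacksolveAdjoint(),
    )
    q1, p1 = sol.ys
    return jnp.sum(q1)

# Per the comment above, this raises an error rather than running the
# adjoint system backwards through the symplectic solve.
jax.grad(loss)(jnp.array(1.0))
```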
Hello,

Thank you for your answer. I am extremely interested in a backsolve that works for symplectic solvers. I created my own leapfrog solver that is specific to Particle-Mesh simulations, and I need to do the backsolve process, which is trivial (on paper) since I will be using only a constant step size. The way to do it is to run the adjoint terms in reverse (so perhaps have an adjoint step in the solver?).

If this is part of #528 then great!! If not, I would be more than happy to help.

1 - Either remove the error

I would find the first solution better for me (and easier), but the second is perhaps cleaner.
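A minimal sketch of why running the steps in reverse works for leapfrog, with a hypothetical `grad_potential` standing in for the Particle-Mesh force computation: each kick-drift-kick step is exactly invertible at constant step size, so the backward pass can retrace the forward trajectory without storing it.

```python
import jax.numpy as jnp

def grad_potential(q):
    # Hypothetical stand-in for the Particle-Mesh force computation:
    # harmonic potential V(q) = q**2 / 2, so grad V = q.
    return q

def leapfrog_step(q, p, dt):
    # Kick-drift-kick leapfrog: symplectic and exactly time-reversible.
    p = p - 0.5 * dt * grad_potential(q)
    q = q + dt * p
    p = p - 0.5 * dt * grad_potential(q)
    return q, p

q0, p0 = jnp.array(1.0), jnp.array(0.0)
q1, p1 = leapfrog_step(q0, p0, dt=0.1)
# Running the same step with -dt undoes it exactly (up to float rounding),
# which is what makes a constant-step backsolve follow the same path.
q_back, p_back = leapfrog_step(q1, p1, dt=-0.1)
```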
So I think in the short term the simplest/best thing to do would be for you to define a custom adjoint method (subclass `diffrax.AbstractAdjoint`). In the long term I think we'll keep this use-case in mind as part of the reversible backprop work, in which case hopefully it will 'just work' at some point in the future.
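In the meantime, here is a rough sketch of what such a reverse pass can look like outside of Diffrax (all names illustrative, not Diffrax API), using plain `jax.custom_vjp` and reusing `leapfrog_step` from the sketch above. The backward pass reconstructs earlier states by stepping the dynamics backwards, so only the final state is stored.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(2, 3))
def integrate(q, p, dt, n_steps):
    def body(carry, _):
        q, p = carry
        return leapfrog_step(q, p, dt), None
    (q, p), _ = jax.lax.scan(body, (q, p), None, length=n_steps)
    return q, p

def integrate_fwd(q, p, dt, n_steps):
    out = integrate(q, p, dt, n_steps)
    return out, out  # residual: only the final state, no trajectory

def integrate_bwd(dt, n_steps, final_state, cotangents):
    def body(carry, _):
        state, cts = carry
        # Reconstruct the previous state by running the dynamics backwards...
        q_prev, p_prev = leapfrog_step(*state, -dt)
        # ...then pull the cotangents back through that forward step.
        _, step_vjp = jax.vjp(lambda q, p: leapfrog_step(q, p, dt), q_prev, p_prev)
        return ((q_prev, p_prev), step_vjp(cts)), None
    (_, (q_bar, p_bar)), _ = jax.lax.scan(
        body, (final_state, cotangents), None, length=n_steps
    )
    return q_bar, p_bar

integrate.defvjp(integrate_fwd, integrate_bwd)

# Usage: gradients flow through the reverse reconstruction, O(1) memory.
loss = lambda q0: integrate(q0, jnp.array(0.0), 0.01, 100)[0]
print(jax.grad(loss)(jnp.array(1.0)))
```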
Thank you very much for your answer 🙏 Before closing the issue, I have a related question: I am trying to use the `custom_vjp`/`custom_jvp` machinery on a custom primitive, so I wrote this MWE that simulates the issue:

```python
from jax import core
import jax
import jax.numpy as jnp
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import mlir
from jax import custom_vjp, custom_jvp
square_prim_p = core.Primitive("square")  # Create the primitive
def square_prim(x):
return square_prim_p.bind(x)
def square_impl(x):
return jnp.power(x, 2)
def square_abstract_eval(x):
return core.ShapedArray(x.shape, x.dtype)
def square_lowering(ctx, xc):
return hlo.MulOp(xc, xc).results
@custom_vjp
def square_vjp(x):
return square_prim(x)
def square_vjp_fwd(x):
return square_prim(x), x * 2
def square_vjp_bwd(res, g):
jax.debug.print("res: {res}, g: {g}",res=res, g=g)
    return (g * res,)
@custom_jvp
def square_jvp(x):
return square_prim(x)
@square_jvp.defjvp
def square_jvp_impl(primals, tangents):
x, = primals
x_dot, = tangents
primals_out = square_jvp(x)
tangents_out = 2 * x_dot * x
return primals_out, tangents_out
square_vjp.defvjp(square_vjp_fwd, square_vjp_bwd)
square_prim_p.def_impl(square_impl)
square_prim_p.def_abstract_eval(square_abstract_eval)
mlir.register_lowering(square_prim_p, square_lowering)
vjp_grad = jax.jit(jax.grad(square_vjp))(jnp.array(3.0)) # works
jvp_grad = jax.jit(jax.grad(square_jvp))(jnp.array(3.0)) # works
print(f"vjp_grad: {vjp_grad}")
print(f"jvp_grad: {jvp_grad}")
def _fn_vjp(x):
    dy, vjp = jax.vjp(square_vjp, x)
    return vjp(dy)[0]

def _fn_jvp(x):
    dy, vjp = jax.vjp(square_jvp, x)
    return vjp(dy)[0]
jax.grad(_fn_jvp)(jnp.array(3.0)) # works
jax.grad(_fn_vjp)(jnp.array(3.0))  # fails
```

So I can safely say that the `custom_jvp` version supports being differentiated a second time, while the `custom_vjp` version does not.
For your example here, you can fix this via:

```diff
@@ -25,7 +25,7 @@ def square_vjp(x):
     return square_prim(x)
 def square_vjp_fwd(x):
-    return square_prim(x), x * 2
+    return square_vjp(x), x * 2
 def square_vjp_bwd(res, g):
     jax.debug.print("res: {res}, g: {g}",res=res, g=g)
```

Note that as you are performing second-order autodifferentiation, the internals of your first-order autodiff rule must themselves be autodifferentiable.
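And indeed, applying that one-line change to the MWE above makes the previously failing call go through (value shown is for the example's x = 3.0):

```python
# With square_vjp_fwd returning square_vjp(x) instead of square_prim(x):
print(jax.grad(_fn_vjp)(jnp.array(3.0)))  # 54.0 == d/dx (2 * x**3) at x = 3
```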
Hello Patrick,

Quick question on the adjoint methods: if I use a symplectic solver, is there a way to just do a reverse integration? With symplectic solvers, running the steps backwards adheres to the discretize-then-optimize strategy, because we can follow back the same path (if I understood correctly). Is using `diffrax.BacksolveAdjoint` enough?

Thanks