Adjoints question #558
Comments
Exploring more with
Also, if I make the noise additive via

    class EmptyField(eqx.Module):
        force: eqx.Module

        def __call__(self, t, y, args):
            return lx.DiagonalLinearOperator(jnp.zeros_like(y))

    vbt_terms = diffrax.MultiTerm(
        diffrax.ODETerm(Field(drift_mlp)),
        diffrax.ControlTerm(EmptyField(diffusion_mlp), vbt),
    )

I get the same errors (indicating the silent non-commutativity wasn't an issue).
I can't comment on the others, but this one is expected and (still) a general limitation in JAX.
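For anyone landing here from the error message, a minimal sketch of the JAX limitation, independent of diffrax. The function `f` and its rules are made up for illustration; the point is only that reverse mode works through a `jax.custom_vjp` function while `jax.jvp` raises the same `TypeError` quoted in this issue.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy function with a hand-written reverse-mode rule.
@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    # Forward pass: return the output plus residuals for the backward pass.
    return jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, g):
    # Backward pass: one cotangent per input of f.
    return (cos_x * g,)


f.defvjp(f_fwd, f_bwd)

# Reverse mode uses the custom rule and works.
reverse_grad = jax.grad(f)(1.0)

# Forward mode has no rule to use, so JAX refuses.
jvp_error = None
try:
    jax.jvp(f, (1.0,), (1.0,))
except TypeError as e:
    jvp_error = str(e)

print(reverse_grad)
print(jvp_error)
```

So any solver or adjoint that goes through a `custom_vjp` internally will hit this as soon as it is differentiated in forward mode.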
Yeah, I get that the error is valid. Maybe my question is more: are the solvers that have custom_vjps (empirically, I determined it's the SRK ones) expected failure cases for forward mode? If so, should we document that these solvers won't work with it?

Further testing also indicates the SRK methods are not working with backsolve. Switching to Euler solves all problems (tbf, it was my mistake for solving an SDE with anything but Euler to begin with 😆).

The backsolve issue seems to be that SRK has some requirements on the terms (e.g. being multi terms), but backsolve gives adjoint terms here:

    adjoint_terms = jtu.tree_map(
        AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
    )

Not sure if this is fundamentally necessary though, or whether we could get SRK working.
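To make the suspected mechanism concrete, here is a rough sketch using stand-in classes rather than diffrax itself. `Term`, `MultiTerm`, `ODETerm`, `ControlTerm` and `Wrapped` below are all hypothetical toy classes, and whether this is exactly what trips the real solver check is an assumption. It only shows how `jtu.tree_map` behaves when `is_leaf` matches the outermost node: the whole multi-term gets wrapped once, so the result is no longer a multi-term.

```python
import jax.tree_util as jtu


# Stand-in classes, NOT the diffrax ones; names are illustrative only.
class Term:
    pass


class ODETerm(Term):
    pass


class ControlTerm(Term):
    pass


class MultiTerm(Term):
    def __init__(self, *terms):
        self.terms = terms


class Wrapped(Term):
    # Plays the role of a wrapper applied to each "leaf" term.
    def __init__(self, term):
        self.term = term


terms = MultiTerm(ODETerm(), ControlTerm())

# is_leaf is checked from the root down, and the MultiTerm is itself a Term,
# so the traversal stops there and the wrapper is applied to the whole thing.
wrapped = jtu.tree_map(Wrapped, terms, is_leaf=lambda x: isinstance(x, Term))

print(type(wrapped).__name__)       # outer type seen by any structure check
print(type(wrapped.term).__name__)  # the original multi-term is inside
```

A solver that checks for a multi-term at the top level would then reject the wrapped object, even though the original structure is still there underneath.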
It is expected that forward-mode autodiff through anything that has a The private functions of |
Okay, so! First of all, for FWIW I regard this as a huge hack -- it's special-casing the combination of built-in adjoints with built-in solvers, whereas ideally these should be orthogonally separate pieces. We need JAX to support forward mode autodiff through (You might also wonder -- why don't we always just use the
Now on to At least for commutativity, I have actually derived the necessary condition on the forward process for the adjoint process to be commutative. It's some kind of 'higher order commutativity' condition. No idea if it's novel or not but at the very least neither me nor James had seen it elsewhere as of a couple of years ago (when I derived this). My assumption is that approximately no-one is sitting around checking that their forward processes satisfy obscure conditions like these, just so that they can use the truly worst adjoint method ever invented. (And which I sometimes consider removing because of its utter uselessness. It survives only so that I don't get inundated with messages asking where it is, and so that I have a spot to put up big warning signs teaching people to stop using the blasted thing.) And so, for this reason, it does not support these kinds of solvers. FWIW we could maybe try to catch this inside Does that cover everything? |
Smh, why does Google not use a tiny bit of their 93.23 billion cash on hand to really advance JAX/XLA and fix all these bugs in a package presumably used extensively internally and the backbone of a multi-billion-dollar industry (since many LLM companies use it)?
Backsolve adjoint being mid seems to be a somewhat common sentiment. What dissuaded you from adding other (non-bad) continuous adjoints, e.g. those of table 1/fig 7 of https://arxiv.org/abs/2406.09699, or those in https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#sensitivity_diffeq, when supposedly there are advantages to these fancier adjoints (https://arxiv.org/abs/2001.04385, https://arxiv.org/abs/1812.01892)?
I think this is a good idea (there are already a lot of flags/catches in backsolve), just to make it clear.
Mostly they just don't add very much. That aside we do intentionally have an
Happy to take a PR on this one! ;)
As a strict adherent to software best practices I will |
I was putting together some tests when I realized that I'm not sure I fully understand the error bounds on the non-recursive adjoints (recursive is what I almost always use). Specifically, for a simple diagonal-noise SDE, I actually encounter errors with the other adjoints. Are these expected, or am I doing something wrong?
Forward errors with
TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
and
Backsolve errors with
ValueError: Terms are not compatible with solver!
(this is from a fork of main).
I know there's this remark on closure with Backsolve (https://docs.kidger.site/diffrax/further_details/faq/#im-getting-a-customvjpexception), but backsolve here has a different failure, and for forward I should already be passing args, y0, terms directly into the function, in case it was the same issue.