
Adjoints question #558

Open
lockwo opened this issue Jan 4, 2025 · 8 comments
Labels
question User queries

Comments

@lockwo
Contributor

lockwo commented Jan 4, 2025

I was putting together some tests when I realized that I'm not sure I fully understand the error bounds on the non-recursive adjoints (which are what I almost always use). Specifically, for a simple diagonal-noise SDE I actually encounter errors with the other adjoints. Are these expected, or am I doing something wrong?

import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import diffrax
import lineax as lx

key = jr.PRNGKey(42)
key, subkey = jr.split(key)
driftkey, diffusionkey, ykey = jr.split(subkey, 3)
drift_mlp = eqx.nn.MLP(
    in_size=3,
    out_size=3,
    width_size=8,
    depth=2,
    activation=jax.nn.swish,
    final_activation=jnp.tanh,
    key=driftkey,
)
diffusion_mlp = eqx.nn.MLP(
    in_size=3,
    out_size=3,
    width_size=8,
    depth=2,
    activation=jax.nn.swish,
    final_activation=jnp.tanh,
    key=diffusionkey,
)

class Field(eqx.Module):
    force: eqx.nn.MLP

    def __call__(self, t, y, args):
        return self.force(y)

class DiffusionField(eqx.Module):
    force: eqx.nn.MLP

    def __call__(self, t, y, args):
        return lx.DiagonalLinearOperator(self.force(y))

y0 = jr.normal(ykey, (3,))

k1, k2, k3 = jr.split(key, 3)
vbt = diffrax.VirtualBrownianTree(0.3, 9.5, 1e-4, (3,), k1, levy_area=diffrax.SpaceTimeLevyArea)
vbt_terms = diffrax.MultiTerm(
    diffrax.ODETerm(Field(drift_mlp)),
    diffrax.ControlTerm(DiffusionField(diffusion_mlp), vbt),
)
solver = diffrax.GeneralShARK()
y0_args_term0 = (y0, None, vbt_terms)

def _run(y0__args__term, saveat, adjoint):
    y0_, args, term = y0__args__term
    ys = diffrax.diffeqsolve(
        term,
        solver,
        0.3,
        9.5,
        0.1,
        y0_,
        args,
        saveat=saveat,
        adjoint=adjoint,
    ).ys
    return jnp.sum(ys)


t0 = True
t1 = True
ts = None
y0__args__term = y0_args_term0

saveat = diffrax.SaveAt(t0=t0, t1=t1, ts=ts)

inexact, static = eqx.partition(y0__args__term, eqx.is_inexact_array)

def _run_inexact(inexact, saveat_, adjoint_):
    return _run(eqx.combine(inexact, static), saveat_, adjoint_)

_run_grad = eqx.filter_jit(jax.grad(_run_inexact))
_run_fwd_grad = eqx.filter_jit(jax.jacfwd(_run_inexact))
recursive_grads = _run_grad(inexact, saveat, diffrax.RecursiveCheckpointAdjoint())
forward_grads = _run_fwd_grad(inexact, saveat, diffrax.ForwardMode())
direct_grads = _run_grad(inexact, saveat, diffrax.DirectAdjoint())
backsolve_grads = _run_grad(
    inexact, saveat, diffrax.BacksolveAdjoint()
)

Forward errors with TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

and

Backsolve errors with ValueError: Terms are not compatible with solver!

(this is from a fork of main).

I know there's the remark about closures with BacksolveAdjoint (https://docs.kidger.site/diffrax/further_details/faq/#im-getting-a-customvjpexception), but backsolve fails differently here, and for ForwardMode I'm already passing y0, args, and terms directly into the function in case it was the same issue.

@lockwo
Contributor Author

lockwo commented Jan 4, 2025

Exploring more with ForwardMode, I'm seeing different errors depending on the solver (all of this feels unexpected to me, since ForwardMode seems like a very simple modification of the reverse-mode requirements, but maybe I am missing something in how it should work):

  • Euler: no error
  • Heun: no error
  • Midpoint: no error
  • Ralston: no error
  • EulerHeun: no error
  • ItoMilstein: ValueError: jvp called with different primal and tangent shapes;Got primal shape (3,) and tangent shape as () (I realize my noise is diagonal, but not commutative, so maybe these are expected to fail)
  • StratonovichMilstein: ValueError: jvp called with different primal and tangent shapes;Got primal shape (3,) and tangent shape as ()
  • GeneralShARK: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
  • SPaRK: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
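(For what it's worth, the shape-mismatch error class above is easy to reproduce in isolation. This is a generic repro of JAX's primal/tangent consistency check, not of the Milstein internals, which I'm assuming hit the same check:)

```python
import jax
import jax.numpy as jnp

# jax.jvp requires each tangent to match the shape of its primal; a scalar
# tangent against a (3,)-shaped primal trips the same check as above.
try:
    jax.jvp(jnp.sin, (jnp.ones(3),), (jnp.asarray(0.0),))
    mismatch_error = None
except ValueError as e:
    mismatch_error = e

print(mismatch_error)
```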

Also if I make the noise additive via

class EmptyField(eqx.Module):
    force: eqx.Module

    def __call__(self, t, y, args):
        return lx.DiagonalLinearOperator(jnp.zeros_like(y))

vbt_terms = diffrax.MultiTerm(
    diffrax.ODETerm(Field(drift_mlp)),
    diffrax.ControlTerm(EmptyField(diffusion_mlp), vbt),
)

I get the same errors (indicating the silent non-commutativity wasn't the issue).

@johannahaffner
Contributor

Forward errors with TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

I can't comment on the others, but this one is expected and (still) a general limitation in JAX.
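(A minimal standalone illustration of the limitation — nothing diffrax-specific; any custom_vjp function behaves this way:)

```python
import jax

@jax.custom_vjp
def square(x):
    return x * x

def square_fwd(x):
    # Return the primal output plus residuals for the backward pass.
    return square(x), x

def square_bwd(x, g):
    return (2.0 * x * g,)

square.defvjp(square_fwd, square_bwd)

print(jax.grad(square)(3.0))  # reverse mode uses the custom rule: 6.0

try:
    # Forward mode needs a jvp rule, which custom_vjp does not define.
    jax.jacfwd(square)(3.0)
    fwd_error = None
except TypeError as e:
    fwd_error = e

print(fwd_error)
```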

@lockwo
Contributor Author

lockwo commented Jan 4, 2025

I can't comment on the others, but this one is expected and (still) a general limitation in JAX.

Yeah, I get that that error is a valid error. Maybe my question is more: are the solvers with custom_vjps (empirically, I determined it's the SRK ones) expected failure cases for forward mode? If so, should we document that these solvers won't work with it?

Further testing also indicates the SRK methods are not working with backsolve. Switching to Euler solves all problems (tbf, it was my mistake for solving an SDE with anything but Euler to begin with 😆).

The backsolve issue seems to be that the SRK solvers have requirements on their terms (e.g. being a MultiTerm), but BacksolveAdjoint hands them adjoint terms here:

    adjoint_terms = jtu.tree_map(
        AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
    )

Not sure if this is fundamentally necessary, though, or whether we could get SRK working.
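(To illustrate the structural issue, here's a toy sketch with stand-in classes, not the real diffrax types: because MultiTerm is itself an AbstractTerm, the is_leaf predicate treats the whole MultiTerm as a single leaf, so the solver receives one wrapped term instead of a MultiTerm of wrapped terms:)

```python
import jax.tree_util as jtu

# Stand-ins for the diffrax classes (assumption: the real MultiTerm is a
# pytree whose children are its sub-terms).
class AbstractTerm:
    pass

class ODETerm(AbstractTerm):
    pass

class ControlTerm(AbstractTerm):
    pass

@jtu.register_pytree_node_class
class MultiTerm(AbstractTerm):
    def __init__(self, *terms):
        self.terms = terms
    def tree_flatten(self):
        return self.terms, None
    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

class AdjointTerm(AbstractTerm):
    def __init__(self, term):
        self.term = term

terms = MultiTerm(ODETerm(), ControlTerm())
adjoint_terms = jtu.tree_map(
    AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
)
# The root MultiTerm matches is_leaf, so it is wrapped whole:
print(type(adjoint_terms).__name__)  # AdjointTerm
```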

@johannahaffner
Contributor

It is expected that forward-mode autodiff through anything that has a custom_vjp will fail in JAX. I'm guessing we're not doing specific warnings since this is a JAX thing, not a diffrax thing.

The private functions of BacksolveAdjoint are the only place where I see custom_vjps being defined (and def_fwds). Does that mean there are custom_vjps in diffrax's dependencies?

@patrick-kidger
Owner

Okay, so!

First of all, for ForwardMode: indeed these were failing for the SRK solvers. That was just a bug and is now fixed :) (#561) These solvers internally use a custom_vjp as they have an Equinox checkpointed scan over stages. The fix is just to substitute it out with the regular lax.scan when doing forward mode.

FWIW I regard this as a huge hack -- it's special-casing the combination of built-in adjoints with built-in solvers, whereas ideally these should be orthogonally separate pieces. We need JAX to support forward mode autodiff through custom_vjps for that to go away, sadly. (And if it ever does then we can also delete both DirectAdjoint and ForwardMode. Having that in JAX really would make our lives a lot simpler.)

(You might also wonder -- why don't we always just use the lax version? The reason is that during the step, we need to incrementally fill in a buffer of vector field evaluations, stage-by-stage. And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass! So we find ourselves wanting:

  • the lax version during forward-mode, as the eqxi version has a custom_vjp
  • and the eqxi version during reverse mode, as the lax version hits the above XLA bug.)
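(A toy sketch of the dispatch idea — assumption: jax.checkpoint stands in for the eqxi checkpointed scan here; the real eqxi version additionally carries the custom_vjp that blocks forward mode:)

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

def cumsum_scan(xs, *, forward_mode: bool):
    """Running sum via scan, picking the implementation by autodiff mode."""
    def body(carry, x):
        carry = carry + x
        return carry, carry
    if forward_mode:
        # Plain lax.scan: has a jvp rule, so forward mode works.
        _, ys = lax.scan(body, 0.0, xs)
    else:
        # Rematerialized body stands in for the memory-efficient
        # checkpointed scan used on the reverse pass.
        _, ys = lax.scan(jax.checkpoint(body), 0.0, xs)
    return ys

xs = jnp.arange(4.0)
print(cumsum_scan(xs, forward_mode=True))   # [0. 1. 3. 6.]
print(jax.jacfwd(lambda x: cumsum_scan(x, forward_mode=True))(xs))
```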

Now on to BacksolveAdjoint: I think this should only be valid with the basic SDE solvers (Euler, Heun, ... anything that doesn't make assumptions about the noise). The reason is that on the backward pass we need to consider the noise type of the adjoint system, and it is that system that needs to satisfy the conditions of the solver: e.g. to be commutative for ItoMilstein to be a valid solver.
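(For concreteness, a sketch of why, following the standard continuous-adjoint SDE from the "scalable gradients for SDEs" literature — not necessarily diffrax's exact formulation. For a forward Stratonovich SDE

$$\mathrm{d}y = f(y)\,\mathrm{d}t + g(y) \circ \mathrm{d}W,$$

the adjoint $a(t) = \partial L / \partial y(t)$ solves, backward in time,

$$\mathrm{d}a = -a^\top \frac{\partial f}{\partial y}\,\mathrm{d}t - a^\top \frac{\partial g}{\partial y} \circ \mathrm{d}W.$$

The adjoint system's diffusion involves $\partial g / \partial y$, so even if $g$ is diagonal or commutative, the adjoint noise generally is not, and solvers whose convergence relies on such noise assumptions are not valid on the backward pass.)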

At least for commutativity, I have actually derived the necessary condition on the forward process for the adjoint process to be commutative. It's some kind of 'higher-order commutativity' condition. No idea if it's novel or not, but at the very least neither James nor I had seen it elsewhere as of a couple of years ago (when I derived this).

My assumption is that approximately no-one is sitting around checking that their forward processes satisfy obscure conditions like these, just so that they can use the truly worst adjoint method ever invented. (And which I sometimes consider removing because of its utter uselessness. It survives only so that I don't get inundated with messages asking where it is, and so that I have a spot to put up big warning signs teaching people to stop using the blasted thing.)

And so, for this reason, it does not support these kinds of solvers. FWIW we could maybe try to catch this inside BacksolveAdjoint.__call__ instead of deferring it to the nested diffeqsolve, as it's not a super friendly error message.


Does that cover everything?

@patrick-kidger patrick-kidger added the question User queries label Jan 5, 2025
@lockwo
Contributor Author

lockwo commented Jan 6, 2025

And sadly, XLA has a longstanding bug in which grad-of-loop-of-inplace will make copies of that buffer during the backward pass!

Smh, why does Google not use a tiny bit of its $93.23 billion cash on hand to really advance JAX/XLA and fix all these bugs, in a package presumably used extensively internally and the backbone of a multi-billion-dollar industry (since many LLM companies use it)?

And which I sometimes consider removing because of its utter uselessness. It survives only so that I don't get inundated with messages asking where it is, and so that I have a spot to put up big warning signs teaching people to stop using the blasted thing.

Backsolve adjoint being mid seems to be a somewhat common sentiment. What dissuaded you from adding other (non-bad) continuous adjoints, e.g. those of Table 1/Fig. 7 of https://arxiv.org/abs/2406.09699, or those in https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#sensitivity_diffeq, when supposedly there are advantages to these fancier adjoints (https://arxiv.org/abs/2001.04385, https://arxiv.org/abs/1812.01892)?

And so, for this reason, it does not support these kinds of solvers. FWIW we could maybe try to catch this inside BacksolveAdjoint.__call__ instead of deferring it to the nested diffeqsolve, as it's not a super friendly error message.

I think this is a good idea (there are already a lot of flags/catches in backsolve), just to make it clear.

@patrick-kidger
Owner

What dissuaded you from adding other (non-bad) continuous adjoints

Mostly they just don't add very much. RecursiveCheckpointAdjoint is already fast+accurate+low-memory, so it handles the vast majority of use-cases already. (when supposedly there are advantages for these fancier adjoints -> I don't think so.)

That aside we do intentionally have an AbstractAdjoint interface so that a user can always write their own adjoint method without relying on internal changes within Diffrax. So I imagine a sufficiently motivated user could already implement any of these themselves if they wanted!

I think this is a good idea (there are already a lot of flags/catches in back solve), just to make it clear.

Happy to take a PR on this one! ;)

@lockwo
Contributor Author

lockwo commented Jan 28, 2025

Happy to take a PR on this one! ;)

As a strict adherent to software best practices I will ~~make a new PR for this~~ roll this into my big PR that already addresses 5 other issues
