Skip to content

Commit

Permalink
SRKs now support forward-mode autodiff.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 5, 2025
1 parent 9eb1bff commit c9a214d
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
19 changes: 15 additions & 4 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

from ._heuristics import is_sde, is_unsafe_sde
from ._saveat import save_y, SaveAt, SubSaveAt
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
from ._solver import (
AbstractItoSolver,
AbstractRungeKutta,
AbstractSRK,
AbstractStratonovichSolver,
)
from ._term import AbstractTerm, AdjointTerm


Expand Down Expand Up @@ -272,7 +277,7 @@ def loop(
if is_unsafe_sde(terms):
raise ValueError(
"`adjoint=RecursiveCheckpointAdjoint()` does not support "
"`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` "
"`UnsafeBrownianPath`. Consider using `adjoint=ForwardMode()` "
"instead."
)
if self.checkpoints is None and max_steps is None:
Expand Down Expand Up @@ -376,7 +381,10 @@ def loop(
msg = None
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
if (
isinstance(solver, (AbstractRungeKutta, AbstractSRK))
and solver.scan_kind is None
):
solver = eqx.tree_at(
lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none
)
Expand Down Expand Up @@ -888,7 +896,10 @@ def loop(
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
if (
isinstance(solver, (AbstractRungeKutta, AbstractSRK))
and solver.scan_kind is None
):
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
final_state = self._loop(
solver=solver,
Expand Down
6 changes: 4 additions & 2 deletions diffrax/_solver/srk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Generic, Literal, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import TypeAlias

import equinox as eqx
Expand Down Expand Up @@ -255,6 +255,8 @@ class AbstractSRK(AbstractSolver[_SolverState]):
as well as $b^H$, $a^H$, $b^K$, and $a^K$ if needed.
"""

scan_kind: Union[None, Literal["lax", "checkpointed"]] = None

interpolation_cls = LocalLinearInterpolation
term_compatible_contr_kwargs = (dict(), dict(use_levy=True))
tableau: AbstractClassVar[StochasticButcherTableau]
Expand Down Expand Up @@ -583,7 +585,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in):
scan_inputs,
len(b_sol),
buffers=lambda x: x,
kind="checkpointed",
kind="checkpointed" if self.scan_kind is None else self.scan_kind,
checkpoints="all",
)

Expand Down
33 changes: 29 additions & 4 deletions test/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,7 @@ def run(model):
run(mlp)


@pytest.mark.parametrize(
"diffusion_fn",
["weak", "lineax"],
)
@pytest.mark.parametrize("diffusion_fn", ["weak", "lineax"])
def test_sde_against(diffusion_fn, getkey):
def f(t, y, args):
del t
Expand Down Expand Up @@ -427,3 +424,31 @@ def test_implicit_runge_kutta_direct_adjoint():
adjoint=diffrax.DirectAdjoint(),
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
)


@pytest.mark.parametrize("solver", (diffrax.Tsit5(), diffrax.GeneralShARK()))
def test_forward_mode_runge_kutta(solver, getkey):
    """Check that a JVP can be taken through a solve using `ForwardMode`.

    The solve is run under `jax.jit` for each parametrised solver, and the
    resulting primal and tangent outputs are checked to be finite, so that a
    forward-mode path silently producing NaN/inf values does not pass.
    """
    # Using Tsit5 on an SDE is fine here: it should converge to the
    # Stratonovich solution.
    bm = diffrax.UnsafeBrownianPath((), getkey(), levy_area=diffrax.SpaceTimeLevyArea)
    drift = diffrax.ODETerm(lambda t, y, args: -y)
    diffusion = diffrax.ControlTerm(lambda t, y, args: 0.1 * y, bm)
    terms = diffrax.MultiTerm(drift, diffusion)

    def run(y0):
        sol = diffrax.diffeqsolve(
            terms,
            solver,
            0,
            1,
            0.01,
            y0,
            adjoint=diffrax.ForwardMode(),
        )
        return sol.ys

    @jax.jit
    def run_jvp(y0):
        # `jax.jvp` returns `(primals_out, tangents_out)`.
        return jax.jvp(run, (y0,), (jnp.ones_like(y0),))

    primals_out, tangents_out = run_jvp(jnp.array(1.0))
    # Previously the result was discarded, so only "does not raise" was tested.
    assert jnp.all(jnp.isfinite(primals_out))
    assert jnp.all(jnp.isfinite(tangents_out))

0 comments on commit c9a214d

Please sign in to comment.