Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateful Controls #559

Draft
wants to merge 19 commits into
base: dev
Choose a base branch
from
Prev Previous commit
Next Next commit
solver work
  • Loading branch information
lockwo committed Dec 31, 2024
commit 382d171b4d096ef1b34903d50d746a0ba5846016
16 changes: 11 additions & 5 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
@@ -56,9 +56,7 @@ def _nondiff_solver_controller_state(
else:
controller_fn = lax.stop_gradient
if passed_path_state:
name = (
f"When using `adjoint={adjoint.__class__.__name__}()`, then `path_state`"
)
name = f"When using `adjoint={adjoint.__class__.__name__}()`, then `path_state`"
path_fn = ft.partial(
eqxi.nondifferentiable,
name=name,
@@ -509,7 +507,11 @@ def loop(
"`saveat=SaveAt(t1=True)`."
)
init_state = _nondiff_solver_controller_state(
self, init_state, passed_solver_state, passed_controller_state, passed_path_state
self,
init_state,
passed_solver_state,
passed_controller_state,
passed_path_state,
)
inputs = (args, terms, self, kwargs, solver, saveat, init_state)
ys, residual = optxi.implicit_jvp(
@@ -860,7 +862,11 @@ def loop(
y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
init_state = _nondiff_solver_controller_state(
self, init_state, passed_solver_state, passed_controller_state, passed_path_state
self,
init_state,
passed_solver_state,
passed_controller_state,
passed_path_state,
)

final_state, aux_stats = _loop_backsolve(
9 changes: 6 additions & 3 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
@@ -86,7 +86,7 @@ def __init__(
levy_area: type[
Union[BrownianIncrement, SpaceTimeLevyArea, SpaceTimeTimeLevyArea]
] = BrownianIncrement,
precompute: bool = True,
precompute: bool = False,
):
self.shape = (
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
@@ -142,7 +142,7 @@ def init(
args: Args,
max_steps: Optional[int],
) -> _BrownianState:
if max_steps is not None:
if max_steps is not None and self.precompute:
subkey = split_by_tree(self.key, self.shape)
noise = jtu.tree_map(
lambda subkey, shape: self._generate_noise(subkey, shape),
@@ -181,7 +181,7 @@ def __call__(
t1 = cast(RealScalarLike, t1)

key, noises, counter = brownian_state
if key is None: # precomputed noise
if self.precompute: # precomputed noise
out = jtu.tree_map(
lambda shape, noise: self._evaluate_leaf_precomputed(
t0, t1, shape, self.levy_area, use_levy, noise
@@ -338,6 +338,9 @@ def _evaluate_leaf(
solvers.
- `precompute`: Whether or not to precompute the Brownian motion (if possible). Precomputing
requires additional memory at initialization time, but can result in faster integrations.
Some thought may be required before enabling this: the size of the precomputed Brownian
motion is derived from the maximum number of steps, so solvers which require multiple
Brownian increments may index out of bounds, causing silent errors.
"""

UnsafeBrownianPath = DirectBrownianPath
27 changes: 24 additions & 3 deletions diffrax/_global_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools as ft
from collections.abc import Callable
from typing import cast, Optional, TYPE_CHECKING
from typing_extensions import TypeAlias

import equinox as eqx
import equinox.internal as eqxi
@@ -18,16 +19,17 @@
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, PyTree, Real, Shaped

from ._custom_types import DenseInfos, IntScalarLike, RealScalarLike, Y
from ._custom_types import DenseInfos, IntScalarLike, RealScalarLike, Y, Args
from ._local_interpolation import AbstractLocalInterpolation
from ._misc import fill_forward, left_broadcast_to
from ._path import AbstractPath
from ._path import AbstractPath, _Control


ω = cast(Callable, ω)
_PathState: TypeAlias = None


class AbstractGlobalInterpolation(AbstractPath):
class AbstractGlobalInterpolation(AbstractPath[_Control, _PathState]):
ts: AbstractVar[Real[Array, " times"]]
ts_size: AbstractVar[IntScalarLike]

@@ -55,6 +57,25 @@ def t1(self):
"""The end of the interval over which the interpolation is defined."""
return self.ts[-1]

def init(
self,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
max_steps: Optional[int],
) -> _PathState:
return None

def __call__(
self,
t0: RealScalarLike,
path_state: _PathState,
t1: Optional[RealScalarLike] = None,
left: bool = True,
) -> tuple[_Control, _PathState]:
return self.evaluate(t0, t1, left), path_state


class LinearInterpolation(AbstractGlobalInterpolation):
"""Linearly interpolates some data `ys` over the interval $[t_0, t_1]$ with knots
71 changes: 56 additions & 15 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
@@ -65,7 +65,14 @@
PIDController,
StepTo,
)
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm, _AbstractControlTerm
from ._term import (
_AbstractControlTerm,
AbstractTerm,
MultiTerm,
ODETerm,
UnderdampedLangevinDiffusionTerm,
WrapTerm,
)
from ._typing import better_isinstance, get_args_of, get_origin_no_specials


@@ -158,14 +165,14 @@ def _check(term_cls, term, term_contr_kwargs, yi):
# `term_cls` | `term_args`
# --------------------------|--------------
# AbstractTerm | ()
# AbstractTerm[VF, Control] | (VF, Control)
# AbstractTerm[VF, Control, Path] | (VF, Control, Path)
# -----------------------------------------
term_args = get_args_of(AbstractTerm, term_cls, error_msg)
n_term_args = len(term_args)
if n_term_args == 0:
pass
elif n_term_args == 2:
vf_type_expected, control_type_expected = term_args
elif n_term_args == 3:
vf_type_expected, control_type_expected, path_type_expected = term_args
try:
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
except Exception as e:
@@ -179,14 +186,19 @@ def _check(term_cls, term, term_contr_kwargs, yi):
contr = ft.partial(term.contr, **term_contr_kwargs)
# Work around https://github.com/google/jax/issues/21825
try:
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
control_type, path_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
except Exception as e:
raise ValueError(f"Error while tracing {term}.contr: " + str(e))
control_type_compatible = eqx.filter_eval_shape(
better_isinstance, control_type, control_type_expected
)
if not control_type_compatible:
raise ValueError(f"Control term {term} is incompatible.")
path_type_compatible = eqx.filter_eval_shape(
better_isinstance, path_type, path_type_expected
)
if not path_type_compatible:
raise ValueError(f"Control term {term} path state is incompatible.")
else:
assert False, "Malformed term structure"
# If we've got to this point then the term is compatible
@@ -343,8 +355,8 @@ def body_fun_aux(state):
state.y,
args,
state.solver_state,
state.path_state,
state.made_jump,
state.path_state,
)

# e.g. if someone has a sqrt(y) in the vector field, and dt0 is so large that
@@ -853,7 +865,7 @@ class SaveAt(eqx.Module): # noqa: F811
t1: bool


@eqx.filter_jit
# @eqx.filter_jit
@eqxi.doc_remove_args("discrete_terminating_event")
def diffeqsolve(
terms: PyTree[AbstractTerm],
@@ -957,8 +969,8 @@ def diffeqsolve(

- `controller_state`: Some initial state for the step size controller. Generally
obtained by `SaveAt(controller_state=True)` from a previous solve.
- `path_state`: Some initial state for the path. Generally obtained by

- `path_state`: Some initial state for the path. Generally obtained by
`SaveAt(path_state=True)` from a previous solve.

- `made_jump`: Whether a jump has just been made at `t0`. Used to update
@@ -1094,13 +1106,27 @@ def _promote(yi):
)
terms = MultiTerm(*terms)

def _path_init(term):
if isinstance(term, _AbstractControlTerm) or isinstance(
term, UnderdampedLangevinDiffusionTerm
):
return term.control.init(t0, t1, y0, args, max_steps)
elif isinstance(term, MultiTerm):
return jax.tree.map(_path_init, term.terms, is_leaf=lambda x: isinstance(x, AbstractTerm))
return None

if path_state is None:
path_state = jtu.tree_map(
_path_init, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
)

# Error checking for term compatibility
_assert_term_compatible(
y0,
args,
terms,
solver.term_structure,
solver.term_compatible_contr_kwargs,
jtu.tree_map(lambda x, y: x | {"control_state": y}, solver.term_compatible_contr_kwargs, path_state, is_leaf=lambda x: isinstance(x, dict)),
)

if is_sde(terms):
@@ -1231,20 +1257,27 @@ def _subsaveat_direction_fn(x):
tnext = t0 + dt0
tnext = jnp.minimum(tnext, t1)

# reinit for tnext
def _path_init(term):
if isinstance(term, _AbstractControlTerm):
if isinstance(term, _AbstractControlTerm) or isinstance(
term, UnderdampedLangevinDiffusionTerm
):
return term.control.init(t0, tnext, y0, args, max_steps)
elif isinstance(term, MultiTerm):
return jax.tree.map(_path_init, term.terms, is_leaf=lambda x: isinstance(x, AbstractTerm))
return None

if path_state is None:
passed_path_state = False
path_state = jtu.tree_map(_path_init, terms)
path_state = jtu.tree_map(
_path_init, terms, is_leaf=lambda x: isinstance(x, AbstractTerm)
)
else:
passed_path_state = True

if solver_state is None:
passed_solver_state = False
solver_state = solver.init(terms, t0, tnext, y0, args)
solver_state = solver.init(terms, t0, tnext, y0, args, path_state)
else:
passed_solver_state = True

@@ -1285,7 +1318,15 @@ def _allocate_output(subsaveat: SubSaveAt) -> SaveState:
result = RESULTS.successful
if saveat.dense or event is not None:
_, _, dense_info_struct, _, _ = eqx.filter_eval_shape(
solver.step, terms, tprev, tnext, y0, args, solver_state, made_jump, path_state
solver.step,
terms,
tprev,
tnext,
y0,
args,
solver_state,
made_jump,
path_state,
)
if saveat.dense:
if max_steps is None:
28 changes: 24 additions & 4 deletions diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Callable
from typing import cast, Optional, TYPE_CHECKING
from typing_extensions import TypeAlias

import jax
import jax.numpy as jnp
@@ -14,17 +15,36 @@
from equinox.internal import ω
from jaxtyping import Array, ArrayLike, PyTree, Shaped

from ._custom_types import RealScalarLike, Y
from ._custom_types import RealScalarLike, Y, Args
from ._misc import linear_rescale
from ._path import AbstractPath
from ._path import AbstractPath, _Control


_PathState: TypeAlias = None

ω = cast(Callable, ω)


class AbstractLocalInterpolation(AbstractPath):
pass
class AbstractLocalInterpolation(AbstractPath[_Control, _PathState]):

def init(
self,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
max_steps: Optional[int],
) -> _PathState:
return None

def __call__(
self,
t0: RealScalarLike,
path_state: _PathState,
t1: Optional[RealScalarLike] = None,
left: bool = True,
) -> tuple[_Control, _PathState]:
return self.evaluate(t0, t1, left), path_state

class LocalLinearInterpolation(AbstractLocalInterpolation):
t0: RealScalarLike
1 change: 0 additions & 1 deletion diffrax/_path.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import PyTree


if TYPE_CHECKING:
21 changes: 20 additions & 1 deletion diffrax/_solution.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import optimistix as optx
from jaxtyping import Array, Bool, PyTree, Real, Shaped

from ._custom_types import BoolScalarLike, RealScalarLike
from ._custom_types import BoolScalarLike, RealScalarLike, Args, Y
from ._global_interpolation import DenseInterpolation
from ._path import AbstractPath

@@ -124,6 +124,25 @@ class Solution(AbstractPath):
made_jump: Optional[BoolScalarLike]
event_mask: Optional[PyTree[BoolScalarLike]]

def init(
self,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
max_steps: Optional[int],
) -> None:
return None

def __call__(
self,
t0: RealScalarLike,
path_state: None,
t1: Optional[RealScalarLike] = None,
left: bool = True,
) -> tuple[PyTree[Shaped[Array, "?*shape"], " Y"], None]:
return self.evaluate(t0, t1, left), path_state

def evaluate(
self, t0: RealScalarLike, t1: Optional[RealScalarLike] = None, left: bool = True
) -> PyTree[Shaped[Array, "?*shape"], " Y"]:
3 changes: 2 additions & 1 deletion diffrax/_solver/align.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
UnderdampedLangevinTuple,
UnderdampedLangevinX,
)
from .base import _PathState
from .foster_langevin_srk import (
AbstractCoeffs,
AbstractFosterLangevinSRK,
@@ -43,7 +44,7 @@ def __init__(self, beta, a1, b1, aa, chh):
_ErrorEstimate = UnderdampedLangevinTuple


class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate, _PathState]):
r"""The Adaptive Langevin via Interpolated Gradients and Noise method
designed by James Foster. This is a second order solver for the
Underdamped Langevin Diffusion, and accepts terms of the form
Loading