Added JumpStepWrapper
andyElking committed Aug 12, 2024
1 parent 7384bfa commit 78b122a
Showing 6 changed files with 261 additions and 147 deletions.
1 change: 1 addition & 0 deletions diffrax/__init__.py
@@ -117,6 +117,7 @@
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
AbstractStepSizeController as AbstractStepSizeController,
ConstantStepSize as ConstantStepSize,
JumpStepWrapper as JumpStepWrapper,
PIDController as PIDController,
StepTo as StepTo,
)
1 change: 1 addition & 0 deletions diffrax/_step_size_controller/__init__.py
@@ -4,3 +4,4 @@
)
from .base import AbstractStepSizeController as AbstractStepSizeController
from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo
from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper
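Together with the re-export in diffrax/__init__.py above, these two lines make the new wrapper part of the public API. A minimal usage sketch follows; the constructor of JumpStepWrapper lives in jump_step_wrapper.py, which is not loaded on this page, so the exact signature (wrapping an existing controller and taking the step_ts/jump_ts arguments that this commit removes from PIDController) is an assumption here:

```python
import jax.numpy as jnp
import diffrax

# Assumed interface: the wrapper takes the controller it wraps, plus the
# step_ts/jump_ts arrays that this commit removes from PIDController itself.
pid = diffrax.PIDController(rtol=1e-4, atol=1e-6)
controller = diffrax.JumpStepWrapper(
    pid,
    step_ts=jnp.array([1.0, 2.0]),  # extra times the solver must step to
    jump_ts=jnp.array([1.5]),       # known discontinuities of the vector field
)
```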
160 changes: 18 additions & 142 deletions diffrax/_step_size_controller/adaptive.py
@@ -2,24 +2,22 @@
from collections.abc import Callable
from typing import cast, Optional, TYPE_CHECKING, TypeVar

import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax.internal as lxi
import optimistix as optx
from jaxtyping import Real


if TYPE_CHECKING:
from typing import ClassVar as AbstractVar
else:
from equinox import AbstractVar
from equinox.internal import ω
from jaxtyping import Array, PyTree
from lineax.internal import complex_to_real_dtype
from jaxtyping import PyTree
from lineax.internal import complex_to_real_dtype # pyright: ignore

from .._custom_types import (
Args,
@@ -29,7 +27,6 @@
VF,
Y,
)
from .._misc import static_select, upcast_or_raise
from .._solution import RESULTS
from .._term import AbstractTerm, ODETerm
from .base import AbstractStepSizeController
@@ -123,9 +120,8 @@ def __check_init__(self):
)


_PidState = tuple[
BoolScalarLike, BoolScalarLike, RealScalarLike, RealScalarLike, RealScalarLike
]
# _PidState = (at_dtmin, prev_inv_scaled_error, prev_prev_inv_scaled_error)
_PidState = tuple[BoolScalarLike, RealScalarLike, RealScalarLike]


def _none_or_array(x):
@@ -353,35 +349,14 @@ def dynamics(t, y, args):
dtmin: Optional[RealScalarLike] = None
dtmax: Optional[RealScalarLike] = None
force_dtmin: bool = True
step_ts: Optional[Real[Array, " steps"]] = eqx.field(
default=None, converter=_none_or_array
)
jump_ts: Optional[Real[Array, " jumps"]] = eqx.field(
default=None, converter=_none_or_array
)
factormin: RealScalarLike = 0.2
factormax: RealScalarLike = 10.0
norm: Callable[[PyTree], RealScalarLike] = rms_norm
safety: RealScalarLike = 0.9
error_order: Optional[RealScalarLike] = None

def __check_init__(self):
if self.jump_ts is not None and not jnp.issubdtype(
self.jump_ts.dtype, jnp.inexact
):
raise ValueError(
f"jump_ts must be floating point, not {self.jump_ts.dtype}"
)

def wrap(self, direction: IntScalarLike):
step_ts = None if self.step_ts is None else self.step_ts * direction
jump_ts = None if self.jump_ts is None else self.jump_ts * direction
return eqx.tree_at(
lambda s: (s.step_ts, s.jump_ts),
self,
(step_ts, jump_ts),
is_leaf=lambda x: x is None,
)
return self

def init(
self,
@@ -450,20 +425,18 @@ def init(
at_dtmin = dt0 <= self.dtmin
dt0 = jnp.maximum(dt0, self.dtmin)

t1 = self._clip_step_ts(t0, t0 + dt0)
t1, jump_next_step = self._clip_jump_ts(t0, t1)
t1 = t0 + dt0

y_leaves = jtu.tree_leaves(y0)
if len(y_leaves) == 0:
y_dtype = lxi.default_floating_dtype()
else:
y_dtype = jnp.result_type(*y_leaves)
real_dtype = complex_to_real_dtype(y_dtype)
return t1, (
jump_next_step,
at_dtmin,
dt0,
jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)),
jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)),
jnp.array(1.0, dtype=real_dtype),
jnp.array(1.0, dtype=real_dtype),
)

def adapt_step_size(
@@ -543,22 +516,12 @@ def adapt_step_size(
"error estimates."
)
(
made_jump,
at_dtmin,
prev_dt,
prev_inv_scaled_error,
prev_prev_inv_scaled_error,
) = controller_state
error_order = self._get_error_order(error_order)
# t1 - t0 is the step we actually took, so that's usually what we mean by the
# "previous dt".
# However if we made a jump then this t1 was clipped relatively to what it
# could have been, so for guessing the next step size it's probably better to
# use the size the step would have been, had there been no jump.
# There are cases in which something besides the step size controller modifies
# the step locations t0, t1; most notably the main integration routine clipping
# steps when we're right at the end of the interval.
prev_dt = jnp.where(made_jump, prev_dt, t1 - t0)
prev_dt = t1 - t0

#
# Figure out how things went on the last step: error, and whether to
@@ -576,7 +539,9 @@ def _scale(_y0, _y1_candidate, _y_error):

scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error))
keep_step = scaled_error < 1
# Automatically keep the step if we're at dtmin.
if self.dtmin is not None:
at_dtmin = at_dtmin | (prev_dt <= self.dtmin)
keep_step = keep_step | at_dtmin
# Make sure it's not a Python scalar and thus getting a ZeroDivisionError.
inv_scaled_error = 1 / jnp.asarray(scaled_error)
@@ -602,8 +567,8 @@ def _scale(_y0, _y1_candidate, _y_error):
factormin = jnp.where(keep_step, 1, self.factormin)
factor = jnp.clip(
self.safety * factor1 * factor2 * factor3,
min=factormin,
max=self.factormax,
min=factormin, # pyright: ignore
max=self.factormax, # pyright: ignore
)
# Once again, see above. In case we have gradients on {i,p,d}coeff.
# (Probably quite common for them to have zero tangents if passed across
@@ -634,35 +599,20 @@ def _scale(_y0, _y1_candidate, _y_error):
at_dtmin = dt <= self.dtmin
dt = jnp.maximum(dt, self.dtmin)

#
# Clip next step size based on step_ts/jump_ts
#

if jnp.issubdtype(jnp.result_type(t1), jnp.inexact):
# Two nextafters. If made_jump then t1 = prevbefore(jump location)
# so now _t1 = nextafter(jump location)
# This is important because we don't know whether or not the jump is as a
# result of a left- or right-discontinuity, so we have to skip the jump
# location altogether.
_t1 = static_select(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1)
else:
_t1 = t1
next_t0 = jnp.where(keep_step, _t1, t0)
next_t1 = self._clip_step_ts(next_t0, next_t0 + dt)
next_t1, next_made_jump = self._clip_jump_ts(next_t0, next_t1)
next_t0 = jnp.where(keep_step, t1, t0)
next_t1 = next_t0 + dt

inv_scaled_error = jnp.where(keep_step, inv_scaled_error, prev_inv_scaled_error)
prev_inv_scaled_error = jnp.where(
keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error
)
controller_state = (
next_made_jump,
at_dtmin,
dt,
inv_scaled_error,
prev_inv_scaled_error,
)
return keep_step, next_t0, next_t1, made_jump, controller_state, result
# made_jump is handled by JumpStepWrapper, so we automatically set it to False
return keep_step, next_t0, next_t1, False, controller_state, result

def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarLike:
# Attribute takes priority, if the user knows the correct error order better
@@ -677,76 +627,6 @@ def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarLike:
)
return error_order

def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike:
if self.step_ts is None:
return t1

step_ts0 = upcast_or_raise(
self.step_ts,
t0,
"`PIDController.step_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
step_ts1 = upcast_or_raise(
self.step_ts,
t1,
"`PIDController.step_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
# TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping
# track of where we were last, and using that as a hint for the next search.
t0_index = jnp.searchsorted(step_ts0, t0, side="right")
t1_index = jnp.searchsorted(step_ts1, t1, side="right")
# This minimum may or may not actually be necessary. The left branch is taken
# iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must
# already satisfy the minimum.
# However, that branch is actually executed unconditionally and then where'd,
# so we clamp it just to be sure we're not hitting undefined behaviour.
t1 = jnp.where(
t0_index < t1_index,
step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)],
t1,
)
return t1

def _clip_jump_ts(
self, t0: RealScalarLike, t1: RealScalarLike
) -> tuple[RealScalarLike, BoolScalarLike]:
if self.jump_ts is None:
return t1, False
assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact)
if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact):
raise ValueError(
"`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
f"Got {jnp.result_type(t0)}."
)
if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact):
raise ValueError(
"`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
f"Got {jnp.result_type(t1)}."
)
jump_ts0 = upcast_or_raise(
self.jump_ts,
t0,
"`PIDController.jump_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
jump_ts1 = upcast_or_raise(
self.jump_ts,
t1,
"`PIDController.jump_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
t0_index = jnp.searchsorted(jump_ts0, t0, side="right")
t1_index = jnp.searchsorted(jump_ts1, t1, side="right")
next_made_jump = t0_index < t1_index
t1 = jnp.where(
next_made_jump,
eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
t1,
)
return t1, next_made_jump


PIDController.__init__.__doc__ = """**Arguments:**
@@ -761,10 +641,6 @@
- `force_dtmin`: How to handle the step size hitting the minimum. If `True` then the
step size is clipped to `dtmin`. If `False` then the differential equation solve
halts with an error.
- `step_ts`: Denotes extra times that must be stepped to.
- `jump_ts`: Denotes extra times that must be stepped to, and at which the vector field
has a known discontinuity. (This is used to force FSAL solvers to re-evaluate the
vector field.)
- `factormin`: Minimum amount a step size can be decreased relative to the previous
step.
- `factormax`: Maximum amount a step size can be increased relative to the previous
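For reference, the behaviour that moves out of PIDController is summarised by the _clip_step_ts/_clip_jump_ts helpers deleted above. Below is a minimal, self-contained sketch of that clipping logic; the function name clip_proposed_step is illustrative only, and whether JumpStepWrapper does exactly this cannot be verified from this page, since jump_step_wrapper.py is not shown.

```python
import equinox.internal as eqxi
import jax.numpy as jnp


def clip_proposed_step(step_ts, jump_ts, t0, t1):
    """Clip a proposed step over [t0, t1] to the next step/jump time, if one falls inside it.

    Mirrors the `_clip_step_ts`/`_clip_jump_ts` helpers removed from `PIDController`
    above (the dtype promotion done by `upcast_or_raise` is omitted for brevity).
    """
    made_jump = jnp.array(False)
    if step_ts is not None:
        # Pull t1 back to the first entry of step_ts lying in (t0, t1], if any.
        t0_index = jnp.searchsorted(step_ts, t0, side="right")
        t1_index = jnp.searchsorted(step_ts, t1, side="right")
        t1 = jnp.where(
            t0_index < t1_index,
            step_ts[jnp.minimum(t0_index, len(step_ts) - 1)],
            t1,
        )
    if jump_ts is not None:
        # As above, but stop just *before* the jump (eqxi.prevbefore) so the
        # discontinuity itself is skipped, and report made_jump so that FSAL
        # solvers re-evaluate the vector field on the next step.
        t0_index = jnp.searchsorted(jump_ts, t0, side="right")
        t1_index = jnp.searchsorted(jump_ts, t1, side="right")
        made_jump = t0_index < t1_index
        t1 = jnp.where(
            made_jump,
            eqxi.prevbefore(jump_ts[jnp.minimum(t0_index, len(jump_ts) - 1)]),
            t1,
        )
    return t1, made_jump


# Example: a jump at t=1.5 clips a proposed step from t0=1.0, t1=2.0 to just before 1.5.
t1, made_jump = clip_proposed_step(None, jnp.array([1.5]), 1.0, 2.0)
```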
(The remaining three changed files, including the new diffrax/_step_size_controller/jump_step_wrapper.py, were not loaded on this page.)
