diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 67b4ca50..dc93c879 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -117,6 +117,7 @@ AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController, AbstractStepSizeController as AbstractStepSizeController, ConstantStepSize as ConstantStepSize, + JumpStepWrapper as JumpStepWrapper, PIDController as PIDController, StepTo as StepTo, ) diff --git a/diffrax/_step_size_controller/__init__.py b/diffrax/_step_size_controller/__init__.py index 18d19c00..719a1d23 100644 --- a/diffrax/_step_size_controller/__init__.py +++ b/diffrax/_step_size_controller/__init__.py @@ -4,3 +4,4 @@ ) from .base import AbstractStepSizeController as AbstractStepSizeController from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo +from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 9d181c95..a751d50f 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -2,7 +2,6 @@ from collections.abc import Callable from typing import cast, Optional, TYPE_CHECKING, TypeVar -import equinox as eqx import equinox.internal as eqxi import jax import jax.lax as lax @@ -10,7 +9,6 @@ import jax.tree_util as jtu import lineax.internal as lxi import optimistix as optx -from jaxtyping import Real if TYPE_CHECKING: @@ -18,8 +16,8 @@ else: from equinox import AbstractVar from equinox.internal import ω -from jaxtyping import Array, PyTree -from lineax.internal import complex_to_real_dtype +from jaxtyping import PyTree +from lineax.internal import complex_to_real_dtype # pyright: ignore from .._custom_types import ( Args, @@ -29,7 +27,6 @@ VF, Y, ) -from .._misc import static_select, upcast_or_raise from .._solution import RESULTS from .._term import AbstractTerm, ODETerm from .base import AbstractStepSizeController @@ -123,9 +120,8 @@ def 
__check_init__(self): ) -_PidState = tuple[ - BoolScalarLike, BoolScalarLike, RealScalarLike, RealScalarLike, RealScalarLike -] +# _PidState = (at_dtmin, prev_inv_scaled_error, prev_prev_inv_scaled_error) +_PidState = tuple[BoolScalarLike, RealScalarLike, RealScalarLike] def _none_or_array(x): @@ -353,35 +349,14 @@ def dynamics(t, y, args): dtmin: Optional[RealScalarLike] = None dtmax: Optional[RealScalarLike] = None force_dtmin: bool = True - step_ts: Optional[Real[Array, " steps"]] = eqx.field( - default=None, converter=_none_or_array - ) - jump_ts: Optional[Real[Array, " jumps"]] = eqx.field( - default=None, converter=_none_or_array - ) factormin: RealScalarLike = 0.2 factormax: RealScalarLike = 10.0 norm: Callable[[PyTree], RealScalarLike] = rms_norm safety: RealScalarLike = 0.9 error_order: Optional[RealScalarLike] = None - def __check_init__(self): - if self.jump_ts is not None and not jnp.issubdtype( - self.jump_ts.dtype, jnp.inexact - ): - raise ValueError( - f"jump_ts must be floating point, not {self.jump_ts.dtype}" - ) - def wrap(self, direction: IntScalarLike): - step_ts = None if self.step_ts is None else self.step_ts * direction - jump_ts = None if self.jump_ts is None else self.jump_ts * direction - return eqx.tree_at( - lambda s: (s.step_ts, s.jump_ts), - self, - (step_ts, jump_ts), - is_leaf=lambda x: x is None, - ) + return self def init( self, @@ -450,20 +425,18 @@ def init( at_dtmin = dt0 <= self.dtmin dt0 = jnp.maximum(dt0, self.dtmin) - t1 = self._clip_step_ts(t0, t0 + dt0) - t1, jump_next_step = self._clip_jump_ts(t0, t1) + t1 = t0 + dt0 y_leaves = jtu.tree_leaves(y0) if len(y_leaves) == 0: y_dtype = lxi.default_floating_dtype() else: y_dtype = jnp.result_type(*y_leaves) + real_dtype = complex_to_real_dtype(y_dtype) return t1, ( - jump_next_step, at_dtmin, - dt0, - jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), - jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), + jnp.array(1.0, dtype=real_dtype), + jnp.array(1.0, 
dtype=real_dtype), ) def adapt_step_size( @@ -543,22 +516,12 @@ def adapt_step_size( "error estimates." ) ( - made_jump, at_dtmin, - prev_dt, prev_inv_scaled_error, prev_prev_inv_scaled_error, ) = controller_state error_order = self._get_error_order(error_order) - # t1 - t0 is the step we actually took, so that's usually what we mean by the - # "previous dt". - # However if we made a jump then this t1 was clipped relatively to what it - # could have been, so for guessing the next step size it's probably better to - # use the size the step would have been, had there been no jump. - # There are cases in which something besides the step size controller modifies - # the step locations t0, t1; most notably the main integration routine clipping - # steps when we're right at the end of the interval. - prev_dt = jnp.where(made_jump, prev_dt, t1 - t0) + prev_dt = t1 - t0 # # Figure out how things went on the last step: error, and whether to @@ -576,7 +539,9 @@ def _scale(_y0, _y1_candidate, _y_error): scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 + # Automatically keep the step if we're at dtmin. if self.dtmin is not None: + at_dtmin = at_dtmin | (prev_dt <= self.dtmin) keep_step = keep_step | at_dtmin # Make sure it's not a Python scalar and thus getting a ZeroDivisionError. inv_scaled_error = 1 / jnp.asarray(scaled_error) @@ -602,8 +567,8 @@ def _scale(_y0, _y1_candidate, _y_error): factormin = jnp.where(keep_step, 1, self.factormin) factor = jnp.clip( self.safety * factor1 * factor2 * factor3, - min=factormin, - max=self.factormax, + min=factormin, # pyright: ignore + max=self.factormax, # pyright: ignore ) # Once again, see above. In case we have gradients on {i,p,d}coeff. 
# (Probably quite common for them to have zero tangents if passed across @@ -634,35 +599,20 @@ def _scale(_y0, _y1_candidate, _y_error): at_dtmin = dt <= self.dtmin dt = jnp.maximum(dt, self.dtmin) - # - # Clip next step size based on step_ts/jump_ts - # - - if jnp.issubdtype(jnp.result_type(t1), jnp.inexact): - # Two nextafters. If made_jump then t1 = prevbefore(jump location) - # so now _t1 = nextafter(jump location) - # This is important because we don't know whether or not the jump is as a - # result of a left- or right-discontinuity, so we have to skip the jump - # location altogether. - _t1 = static_select(made_jump, eqxi.nextafter(eqxi.nextafter(t1)), t1) - else: - _t1 = t1 - next_t0 = jnp.where(keep_step, _t1, t0) - next_t1 = self._clip_step_ts(next_t0, next_t0 + dt) - next_t1, next_made_jump = self._clip_jump_ts(next_t0, next_t1) + next_t0 = jnp.where(keep_step, t1, t0) + next_t1 = next_t0 + dt inv_scaled_error = jnp.where(keep_step, inv_scaled_error, prev_inv_scaled_error) prev_inv_scaled_error = jnp.where( keep_step, prev_inv_scaled_error, prev_prev_inv_scaled_error ) controller_state = ( - next_made_jump, at_dtmin, - dt, inv_scaled_error, prev_inv_scaled_error, ) - return keep_step, next_t0, next_t1, made_jump, controller_state, result + # made_jump is handled by JumpStepWrapper, so we automatically set it to False + return keep_step, next_t0, next_t1, False, controller_state, result def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarLike: # Attribute takes priority, if the user knows the correct error order better @@ -677,76 +627,6 @@ def _get_error_order(self, error_order: Optional[RealScalarLike]) -> RealScalarL ) return error_order - def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: - if self.step_ts is None: - return t1 - - step_ts0 = upcast_or_raise( - self.step_ts, - t0, - "`PIDController.step_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - step_ts1 = 
upcast_or_raise( - self.step_ts, - t1, - "`PIDController.step_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - # TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping - # track of where we were last, and using that as a hint for the next search. - t0_index = jnp.searchsorted(step_ts0, t0, side="right") - t1_index = jnp.searchsorted(step_ts1, t1, side="right") - # This minimum may or may not actually be necessary. The left branch is taken - # iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must - # already satisfy the minimum. - # However, that branch is actually executed unconditionally and then where'd, - # so we clamp it just to be sure we're not hitting undefined behaviour. - t1 = jnp.where( - t0_index < t1_index, - step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)], - t1, - ) - return t1 - - def _clip_jump_ts( - self, t0: RealScalarLike, t1: RealScalarLike - ) -> tuple[RealScalarLike, BoolScalarLike]: - if self.jump_ts is None: - return t1, False - assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact) - if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact): - raise ValueError( - "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " - f"Got {jnp.result_type(t0)}." - ) - if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): - raise ValueError( - "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " - f"Got {jnp.result_type(t1)}." 
- ) - jump_ts0 = upcast_or_raise( - self.jump_ts, - t0, - "`PIDController.jump_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - jump_ts1 = upcast_or_raise( - self.jump_ts, - t1, - "`PIDController.jump_ts`", - "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", - ) - t0_index = jnp.searchsorted(jump_ts0, t0, side="right") - t1_index = jnp.searchsorted(jump_ts1, t1, side="right") - next_made_jump = t0_index < t1_index - t1 = jnp.where( - next_made_jump, - eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), - t1, - ) - return t1, next_made_jump - PIDController.__init__.__doc__ = """**Arguments:** @@ -761,10 +641,6 @@ def _clip_jump_ts( - `force_dtmin`: How to handle the step size hitting the minimum. If `True` then the step size is clipped to `dtmin`. If `False` then the differential equation solve halts with an error. -- `step_ts`: Denotes extra times that must be stepped to. -- `jump_ts`: Denotes extra times that must be stepped to, and at which the vector field - has a known discontinuity. (This is used to force FSAL solvers so re-evaluate the - vector field.) - `factormin`: Minimum amount a step size can be decreased relative to the previous step. 
- `factormax`: Maximum amount a step size can be increased relative to the previous diff --git a/diffrax/_step_size_controller/jump_step_wrapper.py b/diffrax/_step_size_controller/jump_step_wrapper.py new file mode 100644 index 00000000..9db17802 --- /dev/null +++ b/diffrax/_step_size_controller/jump_step_wrapper.py @@ -0,0 +1,232 @@ +from collections.abc import Callable +from typing import Optional, TYPE_CHECKING, TypeAlias, TypeVar + +import equinox as eqx +import equinox.internal as eqxi +import jax.numpy as jnp +from jaxtyping import Array, PyTree, Real + +from .._custom_types import ( + Args, + BoolScalarLike, + IntScalarLike, + RealScalarLike, + VF, + Y, +) +from .._misc import static_select, upcast_or_raise +from .._solution import RESULTS +from .._term import AbstractTerm +from .adaptive import _none_or_array +from .base import AbstractStepSizeController + + +_ControllerState = TypeVar("_ControllerState") +_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) +_JumpStepState: TypeAlias = tuple[BoolScalarLike, RealScalarLike, _ControllerState] + + +class JumpStepWrapper( + AbstractStepSizeController[ + tuple[BoolScalarLike, RealScalarLike, _ControllerState], _Dt0 + ] +): + """Wraps an existing step controller and adds the ability to specify `step_ts` + and `jump_ts`. 
The former are times to which the controller should step and the + latter are times at which the vector field has a discontinuity (jump).""" + + controller: AbstractStepSizeController[_ControllerState, _Dt0] + step_ts: Optional[Real[Array, " steps"]] = eqx.field( + default=None, converter=_none_or_array + ) + jump_ts: Optional[Real[Array, " jumps"]] = eqx.field( + default=None, converter=_none_or_array + ) + + def __check_init__(self): + if self.jump_ts is not None and not jnp.issubdtype( + self.jump_ts.dtype, jnp.inexact + ): + raise ValueError( + f"jump_ts must be floating point, not {self.jump_ts.dtype}" + ) + + def wrap(self, direction: IntScalarLike): + step_ts = None if self.step_ts is None else self.step_ts * direction + jump_ts = None if self.jump_ts is None else self.jump_ts * direction + return eqx.tree_at( + lambda s: (s.step_ts, s.jump_ts), + self, + (step_ts, jump_ts), + is_leaf=lambda x: x is None, + ) + + def init( + self, + terms: PyTree[AbstractTerm], + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + dt0: _Dt0, + args: Args, + func: Callable[[PyTree[AbstractTerm], RealScalarLike, Y, Args], VF], + error_order: Optional[RealScalarLike], + ) -> tuple[RealScalarLike, _JumpStepState]: + t1, inner_state = self.controller.init( + terms, t0, t1, y0, dt0, args, func, error_order + ) + dt_proposal = t1 - t0 + + t1 = self._clip_step_ts(t0, t1) + t1, jump_next_step = self._clip_jump_ts(t0, t1) + + state = (jump_next_step, dt_proposal, inner_state) + + return t1, state + + def adapt_step_size( + self, + t0: RealScalarLike, + t1: RealScalarLike, + y0: Y, + y1_candidate: Y, + args: Args, + y_error: Optional[Y], + error_order: RealScalarLike, + controller_state: _JumpStepState, + ) -> tuple[ + BoolScalarLike, + RealScalarLike, + RealScalarLike, + BoolScalarLike, + _JumpStepState, + RESULTS, + ]: + made_jump, prev_dt, inner_state = controller_state + eqx.error_if(prev_dt, prev_dt < t1 - t0, "prev_dt must be >= t1-t0") + + ( + keep_step, + next_t0, + next_t1, + 
_, + inner_state, + result, + ) = self.controller.adapt_step_size( + t0, t1, y0, y1_candidate, args, y_error, error_order, inner_state + ) + + dt_proposal = next_t1 - next_t0 + dt_proposal = jnp.where( + keep_step, jnp.maximum(dt_proposal, prev_dt), dt_proposal + ) + new_prev_dt = dt_proposal + + # If the step was kept and a jump was made, then we need to set + # `next_t0 = nextafter(nextafter(next_t0))` to ensure that we really skip + # over the jump and don't evaluate the vector field at the discontinuity. + if jnp.issubdtype(jnp.result_type(t1), jnp.inexact): + # Two nextafters. If made_jump then t1 = prevbefore(jump location) + # so now _t1 = nextafter(jump location) + # This is important because we don't know whether or not the jump is as a + # result of a left- or right-discontinuity, so we have to skip the jump + # location altogether. + jump_keep = made_jump & keep_step + next_t0 = static_select( + jump_keep, eqxi.nextafter(eqxi.nextafter(next_t0)), next_t0 + ) + + if TYPE_CHECKING: + assert isinstance( + next_t0, RealScalarLike + ), f"type(next_t0) = {type(next_t0)}" + next_t1 = next_t0 + dt_proposal + + # Clip the step to the next element of jump_ts or step_ts. 
+ next_t1 = self._clip_step_ts(next_t0, next_t1) + next_t1, jump_next_step = self._clip_jump_ts(next_t0, next_t1) + + state = (jump_next_step, new_prev_dt, inner_state) + + return keep_step, next_t0, next_t1, made_jump, state, result + + def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: + if self.step_ts is None: + return t1 + + step_ts0 = upcast_or_raise( + self.step_ts, + t0, + "`JumpStepWrapper.step_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + step_ts1 = upcast_or_raise( + self.step_ts, + t1, + "`JumpStepWrapper.step_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + # TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping + # track of where we were last, and using that as a hint for the next search. + t0_index = jnp.searchsorted(step_ts0, t0, side="right") + t1_index = jnp.searchsorted(step_ts1, t1, side="right") + # This minimum may or may not actually be necessary. The left branch is taken + # iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must + # already satisfy the minimum. + # However, that branch is actually executed unconditionally and then where'd, + # so we clamp it just to be sure we're not hitting undefined behaviour. + t1 = jnp.where( + t0_index < t1_index, + step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)], + t1, + ) + return t1 + + def _clip_jump_ts( + self, t0: RealScalarLike, t1: RealScalarLike + ) -> tuple[RealScalarLike, BoolScalarLike]: + if self.jump_ts is None: + return t1, False + assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact) + if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact): + raise ValueError( + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t0)}." + ) + if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): + raise ValueError( + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. 
" + f"Got {jnp.result_type(t1)}." + ) + jump_ts0 = upcast_or_raise( + self.jump_ts, + t0, + "`JumpStepWrapper.jump_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + jump_ts1 = upcast_or_raise( + self.jump_ts, + t1, + "`JumpStepWrapper.jump_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + t0_index = jnp.searchsorted(jump_ts0, t0, side="right") + t1_index = jnp.searchsorted(jump_ts1, t1, side="right") + next_made_jump = t0_index < t1_index + t1 = jnp.where( + next_made_jump, + eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), + t1, + ) + return t1, next_made_jump + + +JumpStepWrapper.__init__.__doc__ = r"""**Arguments**: + +- `controller`: The controller to wrap. +- `step_ts`: Denotes extra times that must be stepped to. +- `jump_ts`: Denotes extra times that must be stepped to, and at which the vector field + has a known discontinuity. (This is used to force FSAL solvers to re-evaluate the + vector field.) 
+ +""" diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 4cc996c8..a22603f1 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -17,7 +17,8 @@ def test_step_ts(): t1 = 5 dt0 = None y0 = 1.0 - stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=[3, 4]) + pid_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6) + stepsize_controller = diffrax.JumpStepWrapper(pid_controller, step_ts=[3, 4]) saveat = diffrax.SaveAt(steps=True) sol = diffrax.diffeqsolve( term, @@ -50,7 +51,8 @@ def vector_field(t, y, args): saveat = diffrax.SaveAt(steps=True) def run(**kwargs): - stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6, **kwargs) + pid_controller = diffrax.PIDController(rtol=1e-4, atol=1e-6) + stepsize_controller = diffrax.JumpStepWrapper(pid_controller, **kwargs) return diffrax.diffeqsolve( term, solver, @@ -113,9 +115,11 @@ def run(t): t1 = 1 dt0 = None y0 = 1.0 - stepsize_controller = diffrax.PIDController( - rtol=1e-8, atol=1e-8, step_ts=t[None] + pid_controller = diffrax.PIDController( + rtol=1e-8, + atol=1e-8, ) + stepsize_controller = diffrax.JumpStepWrapper(pid_controller, step_ts=t[None]) def forcing(s): return jnp.where(s < t, 0, 1) diff --git a/test/test_progress_meter.py b/test/test_progress_meter.py index 76028169..b7d5c4c8 100644 --- a/test/test_progress_meter.py +++ b/test/test_progress_meter.py @@ -39,7 +39,7 @@ def solve(t0): err = captured.err.strip() assert re.match("0.00%|[ ]+|", err.split("\r", 1)[0]) assert re.match("100.00%|█+|", err.rsplit("\r", 1)[1]) - assert captured.err.count("\r") == num_lines + assert captured.err.count("\r") - num_lines in [0, 1] assert captured.err.count("\n") == 1