diff --git a/benchmarks/jump_step_timing.py b/benchmarks/jump_step_timing.py index 2803ba8e..9250de4f 100644 --- a/benchmarks/jump_step_timing.py +++ b/benchmarks/jump_step_timing.py @@ -39,7 +39,7 @@ def get_terms(key): new_controller = diffrax.JumpStepWrapper( pid_controller, step_ts=step_ts, - rejected_step_buffer_len=0, + rejected_step_buffer_len=None, ) old_controller = OldPIDController( rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7, step_ts=step_ts @@ -66,63 +66,51 @@ def solve(key, controller): keys = jr.split(jr.PRNGKey(0), num_samples) -# NEW CONTROLLER -@jax.jit -@eqx.debug.assert_max_traces(max_traces=1) -def time_new_controller_fun(): - sols = solve(keys, new_controller) - assert sols.ys is not None - assert sols.ys.shape == (num_samples, len(step_ts)) - return sols.ys +def do_timing(controller): + @jax.jit + @eqx.debug.assert_max_traces(max_traces=1) + def time_controller_fun(): + sols = solve(keys, controller) + assert sols.ys is not None + assert sols.ys.shape == (num_samples, len(step_ts)) + return sols.ys + def time_controller(): + jax.block_until_ready(time_controller_fun()) -def time_new_controller(): - jax.block_until_ready(time_new_controller_fun()) + return min(timeit.repeat(time_controller, number=3, repeat=20)) -# OLD CONTROLLER -@jax.jit -@eqx.debug.assert_max_traces(max_traces=1) -def time_old_controller_fun(): - sols = solve(keys, old_controller) - assert sols.ys is not None - assert sols.ys.shape == (num_samples, len(step_ts)) - return sols.ys +time_new = do_timing(new_controller) - -def time_old_controller(): - jax.block_until_ready(time_old_controller_fun()) - - -time_new = min(timeit.repeat(time_new_controller, number=3, repeat=20)) - -time_old = min(timeit.repeat(time_old_controller, number=3, repeat=20)) +time_old = do_timing(old_controller) print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s") # How expensive is revisiting rejected steps? -new_revisiting_controller = diffrax.JumpStepWrapper( +revisiting_controller_short = diffrax.JumpStepWrapper( pid_controller, step_ts=step_ts, rejected_step_buffer_len=10, ) +revisiting_controller_long = diffrax.JumpStepWrapper( + pid_controller, + step_ts=step_ts, + rejected_step_buffer_len=4096, +) -def time_revisiting_controller_fun(): - sols = solve(keys, new_revisiting_controller) - assert sols.ys is not None - assert sols.ys.shape == (num_samples, len(step_ts)) - return sols.ys - - -def time_revisiting_controller(): - jax.block_until_ready(time_revisiting_controller_fun()) - - -time_revisiting = min(timeit.repeat(time_revisiting_controller, number=3, repeat=20)) +time_revisiting_short = do_timing(revisiting_controller_short) +time_revisiting_long = do_timing(revisiting_controller_long) -print(f"Revisiting controller: {time_revisiting:.5} s") +print( + f"Revisiting controller\n" + f"with buffer len 10: {time_revisiting_short:.5} s\n" + f"with buffer len 4096: {time_revisiting_long:.5} s" +) # ======= RESULTS ======= -# New controller: 0.22829 s, Old controller: 0.31039 s -# Revisiting controller: 0.23212 s +# New controller: 0.23506 s, Old controller: 0.30735 s +# Revisiting controller +# with buffer len 10: 0.23636 s +# with buffer len 4096: 0.23965 s diff --git a/benchmarks/old_pid_controller.py b/benchmarks/old_pid_controller.py index ed0a6e71..f6d78098 100644 --- a/benchmarks/old_pid_controller.py +++ b/benchmarks/old_pid_controller.py @@ -9,7 +9,7 @@ import jax.tree_util as jtu import lineax.internal as lxi import optimistix as optx -from diffrax import AbstractAdaptiveStepSizeController, AbstractTerm, ODETerm, RESULTS +from diffrax import AbstractTerm, ODETerm, RESULTS from diffrax._custom_types import ( Args, BoolScalarLike, @@ -19,6 +19,7 @@ Y, ) from diffrax._misc import static_select, upcast_or_raise +from diffrax._step_size_controller import AbstractAdaptiveStepSizeController from equinox.internal import ω from jaxtyping import Array, PyTree, Real from lineax.internal import complex_to_real_dtype diff --git a/diffrax/_step_size_controller/__init__.py b/diffrax/_step_size_controller/__init__.py index 719a1d23..7aa2b0b3 100644 --- a/diffrax/_step_size_controller/__init__.py +++ b/diffrax/_step_size_controller/__init__.py @@ -1,7 +1,9 @@ -from .adaptive import ( +from .adaptive_base import ( AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController, - PIDController as PIDController, ) from .base import AbstractStepSizeController as AbstractStepSizeController from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper +from .pid import ( + PIDController as PIDController, +) diff --git a/diffrax/_step_size_controller/adaptive_base.py b/diffrax/_step_size_controller/adaptive_base.py new file mode 100644 index 00000000..7f984e05 --- /dev/null +++ b/diffrax/_step_size_controller/adaptive_base.py @@ -0,0 +1,42 @@ +from collections.abc import Callable +from typing import Optional, TypeVar + +from equinox import AbstractVar +from jaxtyping import PyTree + +from .._custom_types import RealScalarLike +from .base import AbstractStepSizeController + + +_ControllerState = TypeVar("_ControllerState") +_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) + + +class AbstractAdaptiveStepSizeController( + AbstractStepSizeController[_ControllerState, _Dt0] +): + """Indicates an adaptive step size controller. + + Accepts tolerances `rtol` and `atol`. When used in conjunction with an implicit + solver ([`diffrax.AbstractImplicitSolver`][]), then these tolerances will + automatically be used as the tolerances for the nonlinear solver passed to the + implicit solver, if they are not specified manually. + """ + + rtol: AbstractVar[RealScalarLike] + atol: AbstractVar[RealScalarLike] + norm: AbstractVar[Callable[[PyTree], RealScalarLike]] + + def __check_init__(self): + if self.rtol is None or self.atol is None: + raise ValueError( + "The default values for `rtol` and `atol` were removed in Diffrax " + "version 0.1.0. (As the choice of tolerance is nearly always " + "something that you, as an end user, should make an explicit choice " + "about.)\n" + "If you want to match the previous defaults then specify " + "`rtol=1e-3`, `atol=1e-6`. For example:\n" + "```\n" + "diffrax.PIDController(rtol=1e-3, atol=1e-6)\n" + "```\n" + ) diff --git a/diffrax/_step_size_controller/jump_step_wrapper.py b/diffrax/_step_size_controller/jump_step_wrapper.py index 58ef38da..2f3cb4f7 100644 --- a/diffrax/_step_size_controller/jump_step_wrapper.py +++ b/diffrax/_step_size_controller/jump_step_wrapper.py @@ -18,7 +18,7 @@ from .._misc import static_select, upcast_or_raise from .._solution import RESULTS from .._term import AbstractTerm -from .adaptive import _none_or_array +from .adaptive_base import AbstractAdaptiveStepSizeController from .base import AbstractStepSizeController @@ -51,6 +51,13 @@ def get(self): ) +def _none_or_array(x): + if x is None: + return None + else: + return jnp.asarray(x) + + def _get_t(i: IntScalarLike, ts: Array) -> RealScalarLike: i_min_len = jnp.minimum(i, len(ts) - 1) return jnp.where(i == len(ts), jnp.inf, ts[i_min_len]) @@ -92,7 +99,7 @@ def _find_index(t: RealScalarLike, ts: Optional[Array]) -> IntScalarLike: ts = upcast_or_raise( ts, t, - "`PIDController.step_ts`", + "`JumpStepWrapper.step_ts`", "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", ) return jnp.searchsorted(ts, t, side="right") @@ -101,12 +108,12 @@ def _find_index(t: RealScalarLike, ts: Optional[Array]) -> IntScalarLike: def _revisit_rejected( t0: RealScalarLike, t1: RealScalarLike, - i_rjct: IntScalarLike, - rjct_buff: Optional[Array], + i_reject: IntScalarLike, + rejected_buffer: Optional[Array], ) -> RealScalarLike: - if rjct_buff is None: + if rejected_buffer is None: return t1 - _t1 = _get_t(i_rjct, rjct_buff) + _t1 = _get_t(i_reject, rejected_buffer) _t1 = jnp.minimum(_t1, t1) return _t1 @@ -134,16 +141,17 @@ def _revisit_rejected( # EXPLANATION OF REVISITING REJECTED STEPS # ---------------------------------------- -# We use a "stack" of rejected steps, composed of a buffer `rjct_buff` of length -# `rejected_step_buffer_len` and a counter `i_rjct`. The "stack" are all the items -# in `rjct_buff[i_rjct:]` with `rjct_buff[i_rjct]` being the top of the stack. -# When `i_rjct == rejected_step_buffer_len`, the stack is empty. -# At the start of the run, `i_rjct = rejected_step_buffer_len`. Each time a step is -# rejected `i_rjct -=1` and `rjct_buff[i_rjct] = t1`. Each time a step ends at -# `t1 == rjct_buff[i_rjct]`, we increment `i_rjct` by 1 (even if the step was +# We use a "stack" of rejected steps, composed of a buffer `rejected_buffer` of length +# `rejected_step_buffer_len` and a counter `i_reject`. The "stack" are all the items +# in `rejected_buffer[i_reject:]` with `rejected_buffer[i_reject]` being the top of +# the stack. +# When `i_reject == rejected_step_buffer_len`, the stack is empty. +# At the start of the run, `i_reject = rejected_step_buffer_len`. Each time a step is +# rejected `i_reject -=1` and `rejected_buffer[i_reject] = t1`. Each time a step ends at +# `t1 == rejected_buffer[i_reject]`, we increment `i_reject` by 1 (even if the step was # rejected, in which case we will re-add `t1` to the stack immediately). -# We clip the next step to `t1_next = min(t1_next, rjct_buff[i_rjct])`. -# If `i_rjct < 0` then an error is raised. +# We clip the next step to `t1_next = min(t1_next, rejected_buffer[i_reject])`. +# If `i_reject < 0` then an error is raised. class JumpStepWrapper( @@ -153,10 +161,10 @@ class JumpStepWrapper( and `jump_ts`. The former are times to which the controller should step and the latter are times at which the vector field has a discontinuity (jump).""" - controller: AbstractStepSizeController[_ControllerState, _Dt0] + controller: AbstractAdaptiveStepSizeController[_ControllerState, _Dt0] step_ts: Optional[Real[Array, " steps"]] jump_ts: Optional[Real[Array, " jumps"]] - rejected_step_buffer_len: int = eqx.field(static=True) + rejected_step_buffer_len: Optional[int] = eqx.field(static=True) callback_on_reject: Optional[Callable] = eqx.field(static=True) @eqxi.doc_remove_args("_callback_on_reject") @@ -165,21 +173,22 @@ def __init__( controller, step_ts=None, jump_ts=None, - rejected_step_buffer_len=0, + rejected_step_buffer_len=None, _callback_on_reject=None, ): r""" **Arguments**: - `controller`: The controller to wrap. - Can be any diffrax.AbstractStepSizeController. + Can be any diffrax.AbstractAdaptiveStepSizeController. - `step_ts`: Denotes extra times that must be stepped to. - `jump_ts`: Denotes extra times that must be stepped to, and at which the vector field has a known discontinuity. (This is used to force FSAL solvers so re-evaluate the vector field.) - `rejected_step_buffer_len`: The length of the buffer storing rejected steps. - If this is > 0, then the controller will revisit rejected steps. This is - useful for SDEs, where the solution is guaranteed to be correct only if the + Can either be None or a positive integer. + If it is > 0, then the controller will revisit rejected steps. This is + useful for SDEs, where the solution is guaranteed to be correct if the SDE is evaluated at all times at which the Brownian motion (BM) is evaluated. Since the BM is also evaluated at rejected steps, we must later evaluate the SDE at these times as well. @@ -187,9 +196,10 @@ def __init__( self.controller = controller self.step_ts = _none_or_array(step_ts) self.jump_ts = _none_or_array(jump_ts) + if rejected_step_buffer_len is not None: + assert rejected_step_buffer_len > 0 self.rejected_step_buffer_len = rejected_step_buffer_len self.callback_on_reject = _callback_on_reject - self.__check_init__() def __check_init__(self): if self.jump_ts is not None and not jnp.issubdtype( @@ -202,10 +212,11 @@ def __check_init__(self): def wrap(self, direction: IntScalarLike): step_ts = None if self.step_ts is None else self.step_ts * direction jump_ts = None if self.jump_ts is None else self.jump_ts * direction + controller = self.controller.wrap(direction) return eqx.tree_at( - lambda s: (s.step_ts, s.jump_ts), + lambda s: (s.step_ts, s.jump_ts, s.controller), self, - (step_ts, jump_ts), + (step_ts, jump_ts, controller), is_leaf=lambda x: x is None, ) @@ -226,36 +237,37 @@ def init( dt_proposal = t1 - t0 tdtype = jnp.result_type(t0, t1) - if self.step_ts is not None: + if self.step_ts is None: + step_ts = None + else: # Upcast step_ts to the same dtype as t0, t1 step_ts = upcast_or_raise( self.step_ts, jnp.zeros((), tdtype), - "`PIDController.step_ts`", + "`JumpStepWrapper.step_ts`", "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", ) - else: - step_ts = None - if self.jump_ts is not None: + if self.jump_ts is None: + jump_ts = None + else: # Upcast jump_ts to the same dtype as t0, t1 jump_ts = upcast_or_raise( self.jump_ts, jnp.zeros((), tdtype), - "`PIDController.jump_ts`", + "`JumpStepWrapper.jump_ts`", "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", ) - else: - jump_ts = None - if self.rejected_step_buffer_len > 0: - rjct_buff = jnp.zeros( + if self.rejected_step_buffer_len is None: + rejected_buffer = None + i_reject = jnp.asarray(0) + else: + rejected_buffer = jnp.zeros( (self.rejected_step_buffer_len,) + jnp.shape(t1), dtype=tdtype ) - else: - rjct_buff = None - # rjct_buff[len(rjct_buff)] = jnp.inf (see def of _get_t) - i_rjct = jnp.asarray(self.rejected_step_buffer_len) + # rejected_buffer[len(rejected_buffer)] = jnp.inf (see def of _get_t) + i_reject = jnp.asarray(self.rejected_step_buffer_len) # Find index of first element of step_ts/jump_ts greater than t0 i_step = _find_index(t0, step_ts) @@ -269,8 +281,8 @@ def init( dt_proposal, i_step, i_jump, - i_rjct, - rjct_buff, + i_reject, + rejected_buffer, step_ts, jump_ts, inner_state, @@ -301,8 +313,8 @@ def adapt_step_size( prev_dt, i_step, i_jump, - i_rjct, - rjct_buff, + i_reject, + rejected_buffer, step_ts, jump_ts, inner_state, @@ -324,7 +336,7 @@ def adapt_step_size( if self.callback_on_reject is not None: jax.debug.callback(self.callback_on_reject, keep_step, t1) - # Check whether we stepped over an element of step_ts or jump_ts or rjct_buff + # Check whether we stepped over an element of step_ts/jump_ts/rejected_buffer # This is all still bookkeeping for the PREVIOUS STEP. if step_ts is not None: # If we stepped to `t1 == step_ts[i_step]` and kept the step, then we @@ -337,25 +349,27 @@ def adapt_step_size( jump_inc_cond = keep_step & (t1 >= eqxi.prevbefore(next_jump_t)) i_jump = jnp.where(jump_inc_cond, i_jump + 1, i_jump) - if self.rejected_step_buffer_len > 0: - assert rjct_buff is not None - # If the step ended at t1==rjct_buff[i_rjct], then we have successfully - # stepped to this time and we increment i_rjct. - # We increment i_rjct even if the step was rejected, because we will + if self.rejected_step_buffer_len is not None: + assert rejected_buffer is not None + # If the step ended at t1==rejected_buffer[i_reject], then we have + # successfully stepped to this time and we increment i_reject. + # We increment i_reject even if the step was rejected, because we will # re-add the rejected time to the buffer immediately. - rjct_inc_cond = t1 == _get_t(i_rjct, rjct_buff) - i_rjct = jnp.where(rjct_inc_cond, i_rjct + 1, i_rjct) + rjct_inc_cond = t1 == _get_t(i_reject, rejected_buffer) + i_reject = jnp.where(rjct_inc_cond, i_reject + 1, i_reject) # If the step was rejected, then we need to store the rejected time in the # rejected buffer and decrement the rejected index. - i_rjct = jnp.where(keep_step, i_rjct, i_rjct - 1) - i_rjct = eqx.error_if( - i_rjct, - i_rjct < 0, + i_reject = jnp.where(keep_step, i_reject, i_reject - 1) + i_reject = eqx.error_if( + i_reject, + i_reject < 0, "Maximum number of rejected steps reached. " "Consider increasing JumpStepWrapper.rejected_step_buffer_len.", ) - rjct_buff = jnp.where(keep_step, rjct_buff, rjct_buff.at[i_rjct].set(t1)) + clipped_i = jnp.clip(i_reject, 0, self.rejected_step_buffer_len - 1) + update_rejected_t = jnp.where(keep_step, rejected_buffer[clipped_i], t1) + rejected_buffer = rejected_buffer.at[clipped_i].set(update_rejected_t) # Now move on to the NEXT STEP dt_proposal = next_t1 - next_t0 @@ -391,7 +405,7 @@ def adapt_step_size( # Clip the step to the next element of jump_ts or step_ts or # rejected_buffer. Important to do jump_ts last because otherwise # jump_next_step could be a false positive. - next_t1 = _revisit_rejected(next_t0, next_t1, i_rjct, rjct_buff) + next_t1 = _revisit_rejected(next_t0, next_t1, i_reject, rejected_buffer) next_t1, _ = _clip_ts(next_t0, next_t1, i_step, step_ts, False) next_t1, jump_next_step = _clip_ts(next_t0, next_t1, i_jump, jump_ts, True) @@ -400,8 +414,8 @@ def adapt_step_size( new_prev_dt, i_step, i_jump, - i_rjct, - rjct_buff, + i_reject, + rejected_buffer, step_ts, jump_ts, inner_state, diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/pid.py similarity index 93% rename from diffrax/_step_size_controller/adaptive.py rename to diffrax/_step_size_controller/pid.py index 4af5353e..58949719 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/pid.py @@ -1,7 +1,8 @@ import typing from collections.abc import Callable -from typing import cast, Optional, TYPE_CHECKING, TypeVar +from typing import cast, Optional, TYPE_CHECKING +import equinox as eqx import equinox.internal as eqxi import jax import jax.lax as lax @@ -9,13 +10,13 @@ import jax.tree_util as jtu import lineax.internal as lxi import optimistix as optx +from equinox.internal import ω if TYPE_CHECKING: - from typing import ClassVar as AbstractVar + pass else: - from equinox import AbstractVar -from equinox.internal import ω + pass from jaxtyping import PyTree from lineax.internal import complex_to_real_dtype @@ -29,12 +30,28 @@ ) from .._solution import RESULTS from .._term import AbstractTerm, ODETerm -from .base import AbstractStepSizeController +from .adaptive_base import AbstractAdaptiveStepSizeController +from .jump_step_wrapper import JumpStepWrapper ω = cast(Callable, ω) +# We use a metaclass for backwards compatibility. When a user calls +# PIDController(... step_ts=s, jump_ts=j) this should return a +# JumpStepWrapper(PIDController(...), s, j). +module_meta = type(eqx.Module) + + +class PIDMeta(module_meta): + def __call__(cls, *args, **kwargs): + step_ts = kwargs.pop("step_ts", None) + jump_ts = kwargs.pop("jump_ts", None) + if step_ts is not None or jump_ts is not None: + return JumpStepWrapper(cls(*args, **kwargs), step_ts, jump_ts) + return super().__call__(*args, **kwargs) + + def _select_initial_step( terms: PyTree[AbstractTerm], t0: RealScalarLike, @@ -86,51 +103,10 @@ def intermediate(carry): return jnp.minimum(100 * h0, h1) -_ControllerState = TypeVar("_ControllerState") -_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike]) - - -class AbstractAdaptiveStepSizeController( - AbstractStepSizeController[_ControllerState, _Dt0] -): - """Indicates an adaptive step size controller. - - Accepts tolerances `rtol` and `atol`. When used in conjunction with an implicit - solver ([`diffrax.AbstractImplicitSolver`][]), then these tolerances will - automatically be used as the tolerances for the nonlinear solver passed to the - implicit solver, if they are not specified manually. - """ - - rtol: AbstractVar[RealScalarLike] - atol: AbstractVar[RealScalarLike] - norm: AbstractVar[Callable[[PyTree], RealScalarLike]] - - def __check_init__(self): - if self.rtol is None or self.atol is None: - raise ValueError( - "The default values for `rtol` and `atol` were removed in Diffrax " - "version 0.1.0. (As the choice of tolerance is nearly always " - "something that you, as an end user, should make an explicit choice " - "about.)\n" - "If you want to match the previous defaults then specify " - "`rtol=1e-3`, `atol=1e-6`. For example:\n" - "```\n" - "diffrax.PIDController(rtol=1e-3, atol=1e-6)\n" - "```\n" - ) - - # _PidState = (at_dtmin, prev_inv_scaled_error, prev_prev_inv_scaled_error) _PidState = tuple[BoolScalarLike, RealScalarLike, RealScalarLike] -def _none_or_array(x): - if x is None: - return None - else: - return jnp.asarray(x) - - if TYPE_CHECKING: rms_norm = optx.rms_norm else: @@ -153,7 +129,8 @@ def __repr__(self): # TODO: we don't currently offer a limiter, or a variant accept/reject scheme, as given # in Soderlind and Wang 2006. class PIDController( - AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]] + AbstractAdaptiveStepSizeController[_PidState, Optional[RealScalarLike]], + metaclass=PIDMeta, ): r"""Adapts the step size to produce a solution accurate to a given tolerance. The tolerance is calculated as `atol + rtol * y` for the evolving solution `y`. diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 5fe73d0d..1a3c9b33 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -6,6 +6,7 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import pytest from jaxtyping import Array from .helpers import tree_allclose @@ -127,13 +128,17 @@ def callback_fun(keep_step, t1): assert 4 in cast(Array, sol.ts) -def test_backprop(): +@pytest.mark.parametrize("use_jump_step", [True, False]) +def test_backprop(use_jump_step): + t0 = jnp.asarray(0, dtype=jnp.float64) + t1 = jnp.asarray(1, dtype=jnp.float64) + @eqx.filter_jit @eqx.filter_grad def run(ys, controller, state): y0, y1_candidate, y_error = ys _, tprev, tnext, _, state, _ = controller.adapt_step_size( - 0, 1, y0, y1_candidate, None, y_error, 5, state + t0, t1, y0, y1_candidate, None, y_error, 5, state ) with jax.numpy_dtype_promotion("standard"): return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state)) @@ -142,12 +147,16 @@ def run(ys, controller, state): y1_candidate = jnp.array(2.0) term = diffrax.ODETerm(lambda t, y, args: -y) solver = diffrax.Tsit5() - stepsize_controller = diffrax.PIDController(rtol=1e-4, atol=1e-4) - _, state = stepsize_controller.init(term, 0, 1, y0, 0.1, None, solver.func, 5) + controller = diffrax.PIDController(rtol=1e-4, atol=1e-4) + if use_jump_step: + controller = diffrax.JumpStepWrapper( + controller, step_ts=[0.5], rejected_step_buffer_len=20 + ) + _, state = controller.init(term, t0, t1, y0, 0.1, None, solver.func, 5) for y_error in (jnp.array(0.0), jnp.array(3.0), jnp.array(jnp.inf)): ys = (y0, y1_candidate, y_error) - grads = run(ys, stepsize_controller, state) + grads = run(ys, controller, state) assert not any(jnp.isnan(grad).any() for grad in grads) @@ -193,3 +202,17 @@ def forcing(s): finite_diff = (r(0.5) - r(0.5 - eps)) / eps autodiff = jax.jit(jax.grad(run))(0.5) assert tree_allclose(finite_diff, autodiff) + + +def test_pid_meta(): + ts = jnp.array([3, 4], dtype=jnp.float64) + pid1 = diffrax.PIDController(rtol=1e-4, atol=1e-6) + pid2 = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=ts) + pid3 = diffrax.PIDController(rtol=1e-4, atol=1e-6, step_ts=ts, jump_ts=ts) + assert not isinstance(pid1, diffrax.JumpStepWrapper) + assert isinstance(pid1, diffrax.PIDController) + assert isinstance(pid2, diffrax.JumpStepWrapper) + assert isinstance(pid3, diffrax.JumpStepWrapper) + assert all(pid2.step_ts == ts) + assert all(pid3.step_ts == ts) + assert all(pid3.jump_ts == ts)