Skip to content

Commit

Permalink
Made JumpStepWrapper backwards compatible with PIDController and some…
Browse files Browse the repository at this point in the history
… nits
  • Loading branch information
andyElking committed Aug 18, 2024
1 parent 345e23a commit c3c4dcf
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 156 deletions.
74 changes: 31 additions & 43 deletions benchmarks/jump_step_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def get_terms(key):
new_controller = diffrax.JumpStepWrapper(
pid_controller,
step_ts=step_ts,
rejected_step_buffer_len=0,
rejected_step_buffer_len=None,
)
old_controller = OldPIDController(
rtol=0, atol=1e-3, dtmin=2**-9, dtmax=1.0, pcoeff=0.3, icoeff=0.7, step_ts=step_ts
Expand All @@ -66,63 +66,51 @@ def solve(key, controller):
keys = jr.split(jr.PRNGKey(0), num_samples)


# NEW CONTROLLER
@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
def time_new_controller_fun():
sols = solve(keys, new_controller)
assert sols.ys is not None
assert sols.ys.shape == (num_samples, len(step_ts))
return sols.ys
def do_timing(controller):
@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
def time_controller_fun():
sols = solve(keys, controller)
assert sols.ys is not None
assert sols.ys.shape == (num_samples, len(step_ts))
return sols.ys

def time_controller():
jax.block_until_ready(time_controller_fun())

def time_new_controller():
jax.block_until_ready(time_new_controller_fun())
return min(timeit.repeat(time_controller, number=3, repeat=20))


# OLD CONTROLLER
@jax.jit
@eqx.debug.assert_max_traces(max_traces=1)
def time_old_controller_fun():
sols = solve(keys, old_controller)
assert sols.ys is not None
assert sols.ys.shape == (num_samples, len(step_ts))
return sols.ys
time_new = do_timing(new_controller)


def time_old_controller():
jax.block_until_ready(time_old_controller_fun())


time_new = min(timeit.repeat(time_new_controller, number=3, repeat=20))

time_old = min(timeit.repeat(time_old_controller, number=3, repeat=20))
time_old = do_timing(old_controller)

print(f"New controller: {time_new:.5} s, Old controller: {time_old:.5} s")

# How expensive is revisiting rejected steps?
new_revisiting_controller = diffrax.JumpStepWrapper(
revisiting_controller_short = diffrax.JumpStepWrapper(
pid_controller,
step_ts=step_ts,
rejected_step_buffer_len=10,
)

revisiting_controller_long = diffrax.JumpStepWrapper(
pid_controller,
step_ts=step_ts,
rejected_step_buffer_len=4096,
)

def time_revisiting_controller_fun():
sols = solve(keys, new_revisiting_controller)
assert sols.ys is not None
assert sols.ys.shape == (num_samples, len(step_ts))
return sols.ys


def time_revisiting_controller():
jax.block_until_ready(time_revisiting_controller_fun())


time_revisiting = min(timeit.repeat(time_revisiting_controller, number=3, repeat=20))
time_revisiting_short = do_timing(revisiting_controller_short)
time_revisiting_long = do_timing(revisiting_controller_long)

print(f"Revisiting controller: {time_revisiting:.5} s")
print(
f"Revisiting controller\n"
f"with buffer len 10: {time_revisiting_short:.5} s\n"
f"with buffer len 4096: {time_revisiting_long:.5} s"
)

# ======= RESULTS =======
# New controller: 0.22829 s, Old controller: 0.31039 s
# Revisiting controller: 0.23212 s
# New controller: 0.23506 s, Old controller: 0.30735 s
# Revisiting controller
# with buffer len 10: 0.23636 s
# with buffer len 4096: 0.23965 s
3 changes: 2 additions & 1 deletion benchmarks/old_pid_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax.tree_util as jtu
import lineax.internal as lxi
import optimistix as optx
from diffrax import AbstractAdaptiveStepSizeController, AbstractTerm, ODETerm, RESULTS
from diffrax import AbstractTerm, ODETerm, RESULTS
from diffrax._custom_types import (
Args,
BoolScalarLike,
Expand All @@ -19,6 +19,7 @@
Y,
)
from diffrax._misc import static_select, upcast_or_raise
from diffrax._step_size_controller import AbstractAdaptiveStepSizeController
from equinox.internal import ω
from jaxtyping import Array, PyTree, Real
from lineax.internal import complex_to_real_dtype
Expand Down
6 changes: 4 additions & 2 deletions diffrax/_step_size_controller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .adaptive import (
from .adaptive_base import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
PIDController as PIDController,
)
from .base import AbstractStepSizeController as AbstractStepSizeController
from .constant import ConstantStepSize as ConstantStepSize, StepTo as StepTo
from .jump_step_wrapper import JumpStepWrapper as JumpStepWrapper
from .pid import (
PIDController as PIDController,
)
42 changes: 42 additions & 0 deletions diffrax/_step_size_controller/adaptive_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from collections.abc import Callable
from typing import Optional, TypeVar

from equinox import AbstractVar
from jaxtyping import PyTree

from .._custom_types import RealScalarLike
from .base import AbstractStepSizeController


_ControllerState = TypeVar("_ControllerState")
_Dt0 = TypeVar("_Dt0", None, RealScalarLike, Optional[RealScalarLike])


class AbstractAdaptiveStepSizeController(
AbstractStepSizeController[_ControllerState, _Dt0]
):
"""Indicates an adaptive step size controller.
Accepts tolerances `rtol` and `atol`. When used in conjunction with an implicit
solver ([`diffrax.AbstractImplicitSolver`][]), then these tolerances will
automatically be used as the tolerances for the nonlinear solver passed to the
implicit solver, if they are not specified manually.
"""

rtol: AbstractVar[RealScalarLike]
atol: AbstractVar[RealScalarLike]
norm: AbstractVar[Callable[[PyTree], RealScalarLike]]

def __check_init__(self):
if self.rtol is None or self.atol is None:
raise ValueError(
"The default values for `rtol` and `atol` were removed in Diffrax "
"version 0.1.0. (As the choice of tolerance is nearly always "
"something that you, as an end user, should make an explicit choice "
"about.)\n"
"If you want to match the previous defaults then specify "
"`rtol=1e-3`, `atol=1e-6`. For example:\n"
"```\n"
"diffrax.PIDController(rtol=1e-3, atol=1e-6)\n"
"```\n"
)
Loading

0 comments on commit c3c4dcf

Please sign in to comment.