diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index b682fcbe..dd6bf9ed 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -149,7 +149,7 @@ def _compute_step( levy: AbstractSpaceTimeLevyArea, x0: UnderdampedLangevinX, v0: UnderdampedLangevinX, - uld_args: UnderdampedLangevinArgs, + underdamped_langevin_args: UnderdampedLangevinArgs, coeffs: _ALIGNCoeffs, rho: UnderdampedLangevinX, prev_f: UnderdampedLangevinX, @@ -163,7 +163,7 @@ def _compute_step( w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) - gamma, u, f = uld_args + gamma, u, f = underdamped_langevin_args uh = (u**ω * h).ω f0 = prev_f diff --git a/diffrax/_solver/foster_langevin_srk.py b/diffrax/_solver/foster_langevin_srk.py index a20b50ce..24627fdb 100644 --- a/diffrax/_solver/foster_langevin_srk.py +++ b/diffrax/_solver/foster_langevin_srk.py @@ -326,7 +326,7 @@ def _compute_step( levy, x0: UnderdampedLangevinX, v0: UnderdampedLangevinX, - uld_args: UnderdampedLangevinArgs, + underdamped_langevin_args: UnderdampedLangevinArgs, coeffs: _Coeffs, rho: UnderdampedLangevinX, prev_f: Optional[UnderdampedLangevinX], diff --git a/diffrax/_solver/quicsort.py b/diffrax/_solver/quicsort.py index 303d2af2..2e6ca897 100644 --- a/diffrax/_solver/quicsort.py +++ b/diffrax/_solver/quicsort.py @@ -5,7 +5,7 @@ import jax import jax.numpy as jnp import jax.tree_util as jtu -from equinox.internal import ω +from equinox.internal import scan_trick, ω from jaxtyping import ArrayLike, PyTree from .._custom_types import ( @@ -193,7 +193,7 @@ def _compute_step( levy: AbstractSpaceTimeTimeLevyArea, x0: UnderdampedLangevinX, v0: UnderdampedLangevinX, - uld_args: UnderdampedLangevinArgs, + underdamped_langevin_args: UnderdampedLangevinArgs, coeffs: _QUICSORTCoeffs, rho: UnderdampedLangevinX, prev_f: Optional[UnderdampedLangevinX], @@ -204,7 +204,7 @@ def _compute_step( hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) kk: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes) - gamma, u, f = uld_args + gamma, u, f = underdamped_langevin_args def _extract_coeffs(coeff, index): return jtu.tree_map(lambda arr: arr[..., index], coeff) @@ -226,12 +226,24 @@ def _extract_coeffs(coeff, index): v_tilde = (v0**ω + rho**ω * (hh**ω + 6 * kk**ω)).ω x1 = (x0**ω + a_l**ω * v_tilde**ω + b_l**ω * rho_w_k**ω).ω - f1uh = (f(x1) ** ω * uh**ω).ω - x2 = ( - x0**ω + a_r**ω * v_tilde**ω + b_r**ω * rho_w_k**ω - a_third**ω * f1uh**ω - ).ω - f2uh = (f(x2) ** ω * uh**ω).ω + # Use eqinox.internal.scan_trick to compute f1, x2 and f2 in one go + # carry = x, f1, f2. We use x0 as the initial value for f1 and f2 + init = x1, x0, x0 + + def fn(carry): + x, _f, _ = carry + fx_uh = (f(x) ** ω * uh**ω).ω + return x, _f, fx_uh + + def compute_x2(carry): + _, _, f1 = carry + x = ( + x0**ω + a_r**ω * v_tilde**ω + b_r**ω * rho_w_k**ω - a_third**ω * f1**ω + ).ω + return x, f1, f1 + + x2, f1uh, f2uh = scan_trick(fn, [compute_x2], init) x_out = ( x0**ω diff --git a/diffrax/_solver/should.py b/diffrax/_solver/should.py index 9955baf6..6d8b0cb0 100644 --- a/diffrax/_solver/should.py +++ b/diffrax/_solver/should.py @@ -1,7 +1,7 @@ import equinox as eqx import jax.numpy as jnp import jax.tree_util as jtu -from equinox.internal import ω +from equinox.internal import scan_trick, ω from jaxtyping import ArrayLike, PyTree from .._custom_types import ( @@ -193,7 +193,7 @@ def _compute_step( levy: AbstractSpaceTimeTimeLevyArea, x0: UnderdampedLangevinX, v0: UnderdampedLangevinX, - uld_args: UnderdampedLangevinArgs, + underdamped_langevin_args: UnderdampedLangevinArgs, coeffs: _ShOULDCoeffs, rho: UnderdampedLangevinX, prev_f: UnderdampedLangevinX, @@ -203,7 +203,9 @@ def _compute_step( hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) kk: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes) - gamma, u, f = uld_args + chh_hh_plus_ckk_kk = (coeffs.chh**ω * hh**ω + coeffs.ckk**ω * kk**ω).ω + + gamma, u, f = underdamped_langevin_args rho_w_k = (rho**ω * (w**ω - 12 * kk**ω)).ω uh = (u**ω * h).ω @@ -215,17 +217,28 @@ def _compute_step( + coeffs.a_half**ω * v1**ω + coeffs.b_half**ω * (-(uh**ω) * f0**ω + rho_w_k**ω) ).ω - f1 = f(x1) - chh_hh_plus_ckk_kk = (coeffs.chh**ω * hh**ω + coeffs.ckk**ω * kk**ω).ω + # Use equinox.internal.scan_trick to compute f1, x_out and f_out in one go + # carry = x, f1, f2. We use x0 as the initial value for f1 and f2 + init = x1, x0, x0 + + def fn(carry): + x, _f, _ = carry + fx = f(x) + return x, _f, fx + + def compute_x2(carry): + _, _, _f1 = carry + x = ( + x0**ω + + coeffs.a1**ω * v0**ω + - uh**ω * coeffs.b1**ω * (1 / 3 * f0**ω + 2 / 3 * _f1**ω) + + rho**ω * (coeffs.b1**ω * w**ω + chh_hh_plus_ckk_kk**ω) + ).ω + return x, _f1, _f1 + + x_out, f1, f_out = scan_trick(fn, [compute_x2], init) - x_out = ( - x0**ω - + coeffs.a1**ω * v0**ω - - uh**ω * coeffs.b1**ω * (1 / 3 * f0**ω + 2 / 3 * f1**ω) - + rho**ω * (coeffs.b1**ω * w**ω + chh_hh_plus_ckk_kk**ω) - ).ω - f_out = f(x_out) v_out = ( coeffs.beta1**ω * v0**ω - uh**ω diff --git a/test/test_underdamped_langevin.py b/test/test_underdamped_langevin.py index 9c2f8cb3..9ed12283 100644 --- a/test/test_underdamped_langevin.py +++ b/test/test_underdamped_langevin.py @@ -35,16 +35,7 @@ def _solvers_and_orders(): def get_pytree_uld(t0=0.3, t1=1.0, dtype=jnp.float32): def make_pytree(array_factory): return { - "rr": ( - array_factory((1, 3, 2), dtype), - array_factory( - ( - 3, - 2, - ), - dtype, - ), - ), + "rr": (array_factory((1, 3, 2), dtype), array_factory((3, 2), dtype)), "qq": ( array_factory((1, 2), dtype), array_factory((3,), dtype),