diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c2a693f5..cae0b4ac 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,7 @@ # runs-on: ubuntu-latest # steps: # - name: Release -# uses: patrick-kidger/action_update_python_project@v2 +# uses: patrick-kidger/action_update_python_project@v6 # with: # python-version: "3.11" # test-script: | @@ -21,7 +21,3 @@ # pypi-token: ${{ secrets.pypi_token }} # github-user: lockwo # github-token: ${{ github.token }} -# email-user: ${{ secrets.email_user }} -# email-token: ${{ secrets.email_token }} -# email-server: ${{ secrets.email_server }} -# email-target: ${{ secrets.email_target }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 967cb054..f12e8509 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,4 +11,4 @@ repos: rev: v1.1.350 hooks: - id: pyright - additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions, typeguard] \ No newline at end of file + additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions] diff --git a/diffrax_extensions/_integrate.py b/diffrax_extensions/_integrate.py index 8f8966c5..c84f24c4 100644 --- a/diffrax_extensions/_integrate.py +++ b/diffrax_extensions/_integrate.py @@ -775,9 +775,60 @@ def _save_t1(subsaveat, save_state): save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state) return save_state + def _save_if_t0_equals_t1(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: + if subsaveat.ts is not None: + out_size = 1 if subsaveat.t0 else 0 + out_size += 1 if subsaveat.t1 and not subsaveat.steps else 0 + out_size += len(subsaveat.ts) + + def _make_ys(out, old_outs): + outs = jnp.stack([out] * out_size) + if subsaveat.steps: + outs = jnp.concatenate( + [ + outs, + jnp.full( + (max_steps,) + out.shape, jnp.inf, dtype=out.dtype + ), + ] + ) + assert outs.shape == old_outs.shape + return outs + + ts = jnp.full(out_size, t0) + if subsaveat.steps: + ts = jnp.concatenate((ts, jnp.full(max_steps, jnp.inf, dtype=ts.dtype))) + assert ts.shape == save_state.ts.shape + ys = jtu.tree_map(_make_ys, subsaveat.fn(t0, yfinal, args), save_state.ys) + save_state = SaveState( + saveat_ts_index=out_size, + ts=ts, + ys=ys, + save_index=out_size, + ) + return save_state + save_state = jtu.tree_map( _save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat ) + + # if t0 == t1 then we don't enter the integration loop. In this case we have to + # manually update the saved ts and ys if we want to save at "intermediate" + # times specified by saveat.subs.ts + save_state = jax.lax.cond( + eqxi.unvmap_any(t0 == t1), + lambda __save_state: jax.lax.cond( + t0 == t1, + lambda _save_state: jtu.tree_map( + _save_if_t0_equals_t1, saveat.subs, _save_state, is_leaf=_is_subsaveat + ), + lambda _save_state: _save_state, + __save_state, + ), + lambda __save_state: __save_state, + save_state, + ) + final_state = eqx.tree_at( lambda s: s.save_state, final_state, save_state, is_leaf=_is_none ) diff --git a/diffrax_extensions/_local_interpolation.py b/diffrax_extensions/_local_interpolation.py index 0098a059..29a8eb9e 100644 --- a/diffrax_extensions/_local_interpolation.py +++ b/diffrax_extensions/_local_interpolation.py @@ -110,7 +110,7 @@ def __init__( ): def _calculate(_y0, _y1, _k): with jax.numpy_dtype_promotion("standard"): - _ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1) + _ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1).astype(_y0.dtype) _f0 = _k[0] _f1 = _k[-1] # TODO: rewrite as matrix-vector product? diff --git a/diffrax_extensions/_misc.py b/diffrax_extensions/_misc.py index d38b3d7c..ff92e0cd 100644 --- a/diffrax_extensions/_misc.py +++ b/diffrax_extensions/_misc.py @@ -151,7 +151,7 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike # This in turn allows us to perform some trace-time optimisations that XLA isn't # smart enough to do on its own. if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == (): - pred = pred.item() + pred = pred.item() # pyright: ignore if pred is True: return a elif pred is False: diff --git a/diffrax_extensions/_progress_meter.py b/diffrax_extensions/_progress_meter.py index 15485ffb..a4bcd528 100644 --- a/diffrax_extensions/_progress_meter.py +++ b/diffrax_extensions/_progress_meter.py @@ -123,7 +123,7 @@ def _step_bar(bar: list[float], progress: FloatScalarLike) -> None: if eqx.is_array(progress): # May not be an array when called with `JAX_DISABLE_JIT=1` progress = cast(Union[Array, np.ndarray], progress) - progress = progress.item() + progress = progress.item() # pyright: ignore progress = cast(float, progress) bar[0] = progress print(f"{100 * progress:.2f}%") diff --git a/diffrax_extensions/_solver/runge_kutta.py b/diffrax_extensions/_solver/runge_kutta.py index 10b3152e..695e3310 100644 --- a/diffrax_extensions/_solver/runge_kutta.py +++ b/diffrax_extensions/_solver/runge_kutta.py @@ -964,7 +964,7 @@ def eval_k_jac(): assert implicit_tableau.a_diagonal[0] == 0 # pyright: ignore assert len(set(implicit_tableau.a_diagonal[1:])) == 1 # pyright: ignore jac_stage_index = 1 - stage_index = eqxi.nonbatchable(stage_index) + stage_index = eqxi.nonbatchable(stage_index) # These `stop_gradients` are needed to work around the lack of # symbolic zeros in `custom_vjp`s. if eval_fs: diff --git a/docs/api/terms.md b/docs/api/terms.md index 80a14696..ee43a5db 100644 --- a/docs/api/terms.md +++ b/docs/api/terms.md @@ -127,3 +127,38 @@ where `bm` is an [`diffrax_extensions.AbstractBrownianPath`][] and the same valu selection: members: - __init__ + + +--- + +#### Underdamped Langevin terms + +These are special terms which describe the Underdamped Langevin diffusion (ULD), +which takes the form + +\begin{align*} + \mathrm{d} x(t) &= v(t) \, \mathrm{d}t \\ + \mathrm{d} v(t) &= - \gamma \, v(t) \, \mathrm{d}t - u \, + \nabla \! f( x(t) ) \, \mathrm{d}t + \sqrt{2 \gamma u} \, \mathrm{d} w(t), +\end{align*} + +where $x(t), v(t) \in \mathbb{R}^d$ represent the position +and velocity, $w$ is a Brownian motion in $\mathbb{R}^d$, +$f: \mathbb{R}^d \rightarrow \mathbb{R}$ is a potential function, and +$\gamma , u \in \mathbb{R}^{d \times d}$ are diagonal matrices governing +the friction and the damping of the system. + +These terms enable the use of ULD-specific solvers which can be found +[here](./solvers/sde_solvers.md#underdamped-langevin-solvers). Note that these ULD solvers will only work if given +terms with structure `MultiTerm(UnderdampedLangevinDriftTerm(gamma, u, grad_f), UnderdampedLangevinDiffusionTerm(gamma, u, bm))`, +where `bm` is an [`diffrax_extensions.AbstractBrownianPath`][] and the same values of `gammma` and `u` are passed to both terms. + +::: diffrax_extensions.UnderdampedLangevinDriftTerm + selection: + members: + - __init__ + +::: diffrax_extensions.UnderdampedLangevinDiffusionTerm + selection: + members: + - __init__ \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index e0fe8e61..b033b3e3 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -10,6 +10,8 @@ jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible w nbconvert==6.5.0 # | Older verson to avoid error nbformat==5.4.0 # | pygments==2.14.0 +mkdocs-autorefs==1.0.1 +mkdocs-material-extensions==1.3.1 # Install latest version of our dependencies jax[cpu] diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index fb9c88f3..1a51eff0 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -96,6 +96,11 @@ In this case the Itô solution and the Stratonovich solution coincide, and mathe The Underdamped Langevin Diffusion is a special case of an SDE with additive noise. For details on the form of this SDE and appropriate solvers, please refer to the section on [Underdamped Langevin solvers](../api/solvers/sde_solvers.md#underdamped-langevin-solvers). +### Underdamped Langevin Diffusion + +The Underdamped Langevin Diffusion is a special case of an SDE with additive noise. +For details on the form of this SDE and appropriate solvers, please refer to the section on [Underdamped Langevin solvers](../api/solvers/sde_solvers.md#underdamped-langevin-solvers). + --- ## Controlled differential equations diff --git a/pyproject.toml b/pyproject.toml index 0d1e2d21..3030afa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "diffrax_extensions" -version = "0.6.0" +version = "0.6.1" description = "GPU+autodiff-capable ODE/SDE/CDE solvers written in JAX." readme = "README.md" requires-python ="~=3.9" diff --git a/test/test_adjoint.py b/test/test_adjoint.py index e9a06f9a..48a0958b 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -4,7 +4,7 @@ import diffrax_extensions as diffrax import equinox as eqx import jax -import jax.interpreters.ad +import jax._src.interpreters.ad import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu @@ -21,6 +21,7 @@ class _VectorField(eqx.Module): diff_arg: float def __call__(self, t, y, args): + del t assert y.shape == (2,) diff_arg, nondiff_arg = args dya = diff_arg * y[0] + nondiff_arg * y[1] @@ -29,7 +30,7 @@ def __call__(self, t, y, args): @pytest.mark.slow -def test_against(getkey): +def test_against(): y0 = jnp.array([0.9, 5.4]) args = (0.1, -1) term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)) @@ -215,6 +216,7 @@ def test_closure_errors(): @eqx.filter_value_and_grad def run(model): def f(t, y, args): + del t, args return model(y) sol = diffrax.diffeqsolve( @@ -228,7 +230,7 @@ def f(t, y, args): ) return jnp.sum(cast(Array, sol.ys)) - with pytest.raises(jax.interpreters.ad.CustomVJPException): + with pytest.raises(jax._src.interpreters.ad.CustomVJPException): run(mlp) @@ -239,6 +241,7 @@ class VectorField(eqx.Module): model: Callable def __call__(self, t, y, args): + del t, args return self.model(y) @eqx.filter_jit @@ -307,12 +310,12 @@ def make_step(model, opt_state, target_steady_state): model = eqx.apply_updates(model, updates) return model, opt_state - for step in range(100): + for _ in range(100): model, opt_state = make_step(model, opt_state, target_steady_state) assert tree_allclose(model.steady_state, target_steady_state, rtol=1e-2, atol=1e-2) -def test_backprop_ts(getkey): +def test_backprop_ts(): mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0)) @eqx.filter_jit @@ -338,14 +341,17 @@ def run(model): ) def test_sde_against(diffusion_fn, getkey): def f(t, y, args): + del t k0, _ = args return -k0 * y def g(t, y, args): + del t _, k1 = args return k1 * y def g_lx(t, y, args): + del t _, k1 = args return lx.DiagonalLinearOperator(k1 * y) diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index d0dee616..3656d546 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -340,7 +340,7 @@ def _test_dense_interpolation(solver, key, t1): @pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers) -def test_dense_interpolation(solver, getkey): +def test_dense_interpolation(solver): solver = implicit_tol(solver) key = jr.PRNGKey(5678) vals, true_vals, derivs, true_derivs = _test_dense_interpolation(solver, key, 1) diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 09847dc1..cd6315c7 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -147,6 +147,114 @@ def test_saveat_solution(): assert sol.result == diffrax.RESULTS.successful +@pytest.mark.parametrize("subs", [True, False]) +def test_t0_eq_t1(subs): + y0 = jnp.array([2.0]) + ts = jnp.linspace(1.0, 1.0, 3) + max_steps = 256 + if subs: + get0 = diffrax.SubSaveAt( + ts=ts, + t1=True, + ) + get1 = diffrax.SubSaveAt( + t0=True, + ts=ts, + ) + get2 = diffrax.SubSaveAt( + t0=True, + ts=ts, + steps=True, + ) + subs = (get0, get1, get2) + saveat = diffrax.SaveAt(subs=subs) + else: + saveat = diffrax.SaveAt(t0=True, t1=True, ts=ts) + term = diffrax.ODETerm(lambda t, y, args: y) + sol = diffrax.diffeqsolve( + term, + t0=ts[0], + t1=ts[-1], + y0=y0, + dt0=0.1, + solver=diffrax.Dopri5(), + saveat=saveat, + max_steps=max_steps, + ) + if subs: + compare = jnp.full((len(ts) + 1, *y0.shape), y0) + compare_2 = jnp.concatenate( + (compare, jnp.full((max_steps, *y0.shape), jnp.inf)) + ) + ya, yb, yc = sol.ys # pyright: ignore[reportGeneralTypeIssues] + assert tree_allclose(ya, compare) + assert tree_allclose(yb, compare) + assert tree_allclose(yc, compare_2) + else: + compare = jnp.full((len(ts) + 2, *y0.shape), y0) + assert tree_allclose(sol.ys, compare) + + +def test_t0_eq_t1_complicated(): + """This test case also checks: + + - vmap'ing + - non-float32 dtypes + - `SubSaveAt(fn=...)` + """ + ntsave = 4 + dtype = jnp.float16 + y0 = jnp.array([2.0], dtype=dtype) + term = diffrax.ODETerm(lambda t, y, args: y) + + def _solve(tf): + ts = jnp.linspace(0.0, tf, ntsave, dtype=dtype) + get0 = diffrax.SubSaveAt( + ts=ts, + t1=True, + ) + get1 = diffrax.SubSaveAt( + t0=True, + ts=ts, + ) + get2 = diffrax.SubSaveAt( + t0=True, + ts=ts, + steps=True, + fn=lambda t, y, args: jnp.where(jnp.isinf(y), 3.0, 4.0), + ) + subs = (get0, get1, get2) + saveat = diffrax.SaveAt(subs=subs) + sol = diffrax.diffeqsolve( + term, + t0=ts[0], + t1=ts[-1], + y0=y0, + dt0=0.1, + solver=diffrax.Dopri5(), + saveat=saveat, + max_steps=15, + ) + return sol.ys + + compare = jnp.full((ntsave + 1, *y0.shape), y0, dtype=dtype) + compare_c = jnp.concatenate( + [ + jnp.full((ntsave + 1, *y0.shape), 4.0, dtype=jnp.float64), + jnp.full((15, *y0.shape), jnp.inf, dtype=jnp.float64), + ] + ) + true_ya, true_yb, true_yc = _solve(1.0) # pyright: ignore[reportGeneralTypeIssues] + true_ya = jnp.stack([compare, true_ya]) + true_yb = jnp.stack([compare, true_yb]) + true_yc = jnp.stack([compare_c, true_yc]) + + ya, yb, yc = jax.vmap(_solve)(jnp.array([0.0, 1.0])) # pyright: ignore[reportGeneralTypeIssues] + assert tree_allclose(ya, true_ya) + assert tree_allclose(yb, true_yb) + assert tree_allclose(yc, true_yc) + + def test_trivial_dense(): term = diffrax.ODETerm(lambda t, y, args: -0.5 * y) y0 = jnp.array([2.1])