Owen/update #9

Open · wants to merge 20 commits into base: main
6 changes: 1 addition & 5 deletions .github/workflows/release.yml
@@ -10,7 +10,7 @@
# runs-on: ubuntu-latest
# steps:
# - name: Release
# uses: patrick-kidger/action_update_python_project@v2
# uses: patrick-kidger/action_update_python_project@v6
# with:
# python-version: "3.11"
# test-script: |
@@ -21,7 +21,3 @@
# pypi-token: ${{ secrets.pypi_token }}
# github-user: lockwo
# github-token: ${{ github.token }}
# email-user: ${{ secrets.email_user }}
# email-token: ${{ secrets.email_token }}
# email-server: ${{ secrets.email_server }}
# email-target: ${{ secrets.email_target }}
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,4 +11,4 @@ repos:
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions, typeguard]
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions]
51 changes: 51 additions & 0 deletions diffrax_extensions/_integrate.py
@@ -775,9 +775,60 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

def _save_if_t0_equals_t1(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.ts is not None:
out_size = 1 if subsaveat.t0 else 0
out_size += 1 if subsaveat.t1 and not subsaveat.steps else 0
out_size += len(subsaveat.ts)

def _make_ys(out, old_outs):
outs = jnp.stack([out] * out_size)
if subsaveat.steps:
outs = jnp.concatenate(
[
outs,
jnp.full(
(max_steps,) + out.shape, jnp.inf, dtype=out.dtype
),
]
)
assert outs.shape == old_outs.shape
return outs

ts = jnp.full(out_size, t0)
if subsaveat.steps:
ts = jnp.concatenate((ts, jnp.full(max_steps, jnp.inf, dtype=ts.dtype)))
assert ts.shape == save_state.ts.shape
ys = jtu.tree_map(_make_ys, subsaveat.fn(t0, yfinal, args), save_state.ys)
save_state = SaveState(
saveat_ts_index=out_size,
ts=ts,
ys=ys,
save_index=out_size,
)
return save_state

save_state = jtu.tree_map(
_save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
)

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved ts and ys if we want to save at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
eqxi.unvmap_any(t0 == t1),
lambda __save_state: jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_if_t0_equals_t1, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
__save_state,
),
lambda __save_state: __save_state,
save_state,
)

final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
)
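
As context, a minimal sketch of the user-facing behaviour this hunk addresses, mirroring the new `test_t0_eq_t1` test further down in this diff; the solver, step size, and shapes are taken from that test and are illustrative only, not part of the change itself.

```python
import diffrax_extensions as diffrax
import jax.numpy as jnp

# Degenerate solve: t0 == t1, so the integration loop is never entered.
# With this patch, the requested save times are still populated with y0
# rather than being left as placeholder values.
y0 = jnp.array([2.0])
ts = jnp.linspace(1.0, 1.0, 3)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: y),
    t0=ts[0],
    t1=ts[-1],
    y0=y0,
    dt0=0.1,
    solver=diffrax.Dopri5(),
    saveat=diffrax.SaveAt(t0=True, t1=True, ts=ts),
)
# sol.ys has shape (len(ts) + 2, 1) and every entry equals y0.
```
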
2 changes: 1 addition & 1 deletion diffrax_extensions/_local_interpolation.py
@@ -110,7 +110,7 @@ def __init__(
):
def _calculate(_y0, _y1, _k):
with jax.numpy_dtype_promotion("standard"):
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1).astype(_y0.dtype)
_f0 = _k[0]
_f1 = _k[-1]
# TODO: rewrite as matrix-vector product?
2 changes: 1 addition & 1 deletion diffrax_extensions/_misc.py
@@ -151,7 +151,7 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
# This in turn allows us to perform some trace-time optimisations that XLA isn't
# smart enough to do on its own.
if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == ():
pred = pred.item()
pred = pred.item() # pyright: ignore
if pred is True:
return a
elif pred is False:
2 changes: 1 addition & 1 deletion diffrax_extensions/_progress_meter.py
@@ -123,7 +123,7 @@ def _step_bar(bar: list[float], progress: FloatScalarLike) -> None:
if eqx.is_array(progress):
# May not be an array when called with `JAX_DISABLE_JIT=1`
progress = cast(Union[Array, np.ndarray], progress)
progress = progress.item()
progress = progress.item() # pyright: ignore
progress = cast(float, progress)
bar[0] = progress
print(f"{100 * progress:.2f}%")
2 changes: 1 addition & 1 deletion diffrax_extensions/_solver/runge_kutta.py
@@ -964,7 +964,7 @@ def eval_k_jac():
assert implicit_tableau.a_diagonal[0] == 0 # pyright: ignore
assert len(set(implicit_tableau.a_diagonal[1:])) == 1 # pyright: ignore
jac_stage_index = 1
stage_index = eqxi.nonbatchable(stage_index)
stage_index = eqxi.nonbatchable(stage_index)
# These `stop_gradients` are needed to work around the lack of
# symbolic zeros in `custom_vjp`s.
if eval_fs:
35 changes: 35 additions & 0 deletions docs/api/terms.md
@@ -127,3 +127,38 @@ where `bm` is an [`diffrax_extensions.AbstractBrownianPath`][] and the same valu
selection:
members:
- __init__


---

#### Underdamped Langevin terms

These are special terms describing the Underdamped Langevin diffusion (ULD),
which takes the form

\begin{align*}
\mathrm{d} x(t) &= v(t) \, \mathrm{d}t \\
\mathrm{d} v(t) &= - \gamma \, v(t) \, \mathrm{d}t - u \,
\nabla \! f( x(t) ) \, \mathrm{d}t + \sqrt{2 \gamma u} \, \mathrm{d} w(t),
\end{align*}

where $x(t), v(t) \in \mathbb{R}^d$ represent the position
and velocity, $w$ is a Brownian motion in $\mathbb{R}^d$,
$f: \mathbb{R}^d \rightarrow \mathbb{R}$ is a potential function, and
$\gamma , u \in \mathbb{R}^{d \times d}$ are diagonal matrices governing
the friction and the damping of the system.

These terms enable the use of ULD-specific solvers which can be found
[here](./solvers/sde_solvers.md#underdamped-langevin-solvers). Note that these ULD solvers will only work if given
terms with structure `MultiTerm(UnderdampedLangevinDriftTerm(gamma, u, grad_f), UnderdampedLangevinDiffusionTerm(gamma, u, bm))`,
where `bm` is a [`diffrax_extensions.AbstractBrownianPath`][] and the same values of `gamma` and `u` are passed to both terms.
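
For illustration, here is a minimal sketch of constructing such a term. The quadratic potential, the particular `gamma` and `u` values, and the Brownian-motion settings are assumptions made for this example rather than requirements of the API; the Brownian path is built with `VirtualBrownianTree` as one concrete `AbstractBrownianPath`.

```python
import diffrax_extensions as diffrax
import jax.numpy as jnp
import jax.random as jr

# Assumed toy potential f(x) = 0.5 * |x|^2, whose gradient is the identity map.
def grad_f(x):
    return x

# Diagonal friction and damping, specified via their diagonal entries (d = 2).
gamma = jnp.array([1.0, 2.0])
u = jnp.array([0.5, 0.5])

# Brownian motion driving the diffusion term.
bm = diffrax.VirtualBrownianTree(
    t0=0.0, t1=1.0, tol=1e-3, shape=(2,), key=jr.PRNGKey(0)
)

terms = diffrax.MultiTerm(
    diffrax.UnderdampedLangevinDriftTerm(gamma, u, grad_f),
    diffrax.UnderdampedLangevinDiffusionTerm(gamma, u, bm),
)
```

Depending on the chosen ULD solver, the Brownian path may also need to be constructed with a richer Lévy area (for example via the `levy_area` argument of the Brownian path); see the solver documentation linked above.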

::: diffrax_extensions.UnderdampedLangevinDriftTerm
selection:
members:
- __init__

::: diffrax_extensions.UnderdampedLangevinDiffusionTerm
selection:
members:
- __init__
2 changes: 2 additions & 0 deletions docs/requirements.txt
@@ -10,6 +10,8 @@ jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible w
nbconvert==6.5.0 # | Older verson to avoid error
nbformat==5.4.0 # |
pygments==2.14.0
mkdocs-autorefs==1.0.1
mkdocs-material-extensions==1.3.1

# Install latest version of our dependencies
jax[cpu]
5 changes: 5 additions & 0 deletions docs/usage/how-to-choose-a-solver.md
@@ -96,6 +96,11 @@ In this case the Itô solution and the Stratonovich solution coincide, and mathe
The Underdamped Langevin Diffusion is a special case of an SDE with additive noise.
For details on the form of this SDE and appropriate solvers, please refer to the section on [Underdamped Langevin solvers](../api/solvers/sde_solvers.md#underdamped-langevin-solvers).

### Underdamped Langevin Diffusion

The Underdamped Langevin Diffusion is a special case of an SDE with additive noise.
For details on the form of this SDE and appropriate solvers, please refer to the section on [Underdamped Langevin solvers](../api/solvers/sde_solvers.md#underdamped-langevin-solvers).

---

## Controlled differential equations
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "diffrax_extensions"
version = "0.6.0"
version = "0.6.1"
description = "GPU+autodiff-capable ODE/SDE/CDE solvers written in JAX."
readme = "README.md"
requires-python ="~=3.9"
16 changes: 11 additions & 5 deletions test/test_adjoint.py
@@ -4,7 +4,7 @@
import diffrax_extensions as diffrax
import equinox as eqx
import jax
import jax.interpreters.ad
import jax._src.interpreters.ad
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
@@ -21,6 +21,7 @@ class _VectorField(eqx.Module):
diff_arg: float

def __call__(self, t, y, args):
del t
assert y.shape == (2,)
diff_arg, nondiff_arg = args
dya = diff_arg * y[0] + nondiff_arg * y[1]
@@ -29,7 +30,7 @@ def __call__(self, t, y, args):


@pytest.mark.slow
def test_against(getkey):
def test_against():
y0 = jnp.array([0.9, 5.4])
args = (0.1, -1)
term = diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1))
@@ -215,6 +216,7 @@ def test_closure_errors():
@eqx.filter_value_and_grad
def run(model):
def f(t, y, args):
del t, args
return model(y)

sol = diffrax.diffeqsolve(
@@ -228,7 +230,7 @@ def f(t, y, args):
)
return jnp.sum(cast(Array, sol.ys))

with pytest.raises(jax.interpreters.ad.CustomVJPException):
with pytest.raises(jax._src.interpreters.ad.CustomVJPException):
run(mlp)


Expand All @@ -239,6 +241,7 @@ class VectorField(eqx.Module):
model: Callable

def __call__(self, t, y, args):
del t, args
return self.model(y)

@eqx.filter_jit
@@ -307,12 +310,12 @@ def make_step(model, opt_state, target_steady_state):
model = eqx.apply_updates(model, updates)
return model, opt_state

for step in range(100):
for _ in range(100):
model, opt_state = make_step(model, opt_state, target_steady_state)
assert tree_allclose(model.steady_state, target_steady_state, rtol=1e-2, atol=1e-2)


def test_backprop_ts(getkey):
def test_backprop_ts():
mlp = eqx.nn.MLP(1, 1, 8, 2, key=jr.PRNGKey(0))

@eqx.filter_jit
@@ -338,14 +341,17 @@ def run(model):
)
def test_sde_against(diffusion_fn, getkey):
def f(t, y, args):
del t
k0, _ = args
return -k0 * y

def g(t, y, args):
del t
_, k1 = args
return k1 * y

def g_lx(t, y, args):
del t
_, k1 = args
return lx.DiagonalLinearOperator(k1 * y)

2 changes: 1 addition & 1 deletion test/test_global_interpolation.py
@@ -340,7 +340,7 @@ def _test_dense_interpolation(solver, key, t1):


@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers)
def test_dense_interpolation(solver, getkey):
def test_dense_interpolation(solver):
solver = implicit_tol(solver)
key = jr.PRNGKey(5678)
vals, true_vals, derivs, true_derivs = _test_dense_interpolation(solver, key, 1)
108 changes: 108 additions & 0 deletions test/test_saveat_solution.py
@@ -147,6 +147,114 @@ def test_saveat_solution():
assert sol.result == diffrax.RESULTS.successful


@pytest.mark.parametrize("subs", [True, False])
def test_t0_eq_t1(subs):
y0 = jnp.array([2.0])
ts = jnp.linspace(1.0, 1.0, 3)
max_steps = 256
if subs:
get0 = diffrax.SubSaveAt(
ts=ts,
t1=True,
)
get1 = diffrax.SubSaveAt(
t0=True,
ts=ts,
)
get2 = diffrax.SubSaveAt(
t0=True,
ts=ts,
steps=True,
)
subs = (get0, get1, get2)
saveat = diffrax.SaveAt(subs=subs)
else:
saveat = diffrax.SaveAt(t0=True, t1=True, ts=ts)
term = diffrax.ODETerm(lambda t, y, args: y)
sol = diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=y0,
dt0=0.1,
solver=diffrax.Dopri5(),
saveat=saveat,
max_steps=max_steps,
)
if subs:
compare = jnp.full((len(ts) + 1, *y0.shape), y0)
compare_2 = jnp.concatenate(
(compare, jnp.full((max_steps, *y0.shape), jnp.inf))
)
ya, yb, yc = sol.ys # pyright: ignore[reportGeneralTypeIssues]
assert tree_allclose(ya, compare)
assert tree_allclose(yb, compare)
assert tree_allclose(yc, compare_2)
else:
compare = jnp.full((len(ts) + 2, *y0.shape), y0)
assert tree_allclose(sol.ys, compare)


def test_t0_eq_t1_complicated():
"""This test case also checks:

- vmap'ing
- non-float32 dtypes
- `SubSaveAt(fn=...)`
"""
ntsave = 4
dtype = jnp.float16
y0 = jnp.array([2.0], dtype=dtype)
term = diffrax.ODETerm(lambda t, y, args: y)

def _solve(tf):
ts = jnp.linspace(0.0, tf, ntsave, dtype=dtype)
get0 = diffrax.SubSaveAt(
ts=ts,
t1=True,
)
get1 = diffrax.SubSaveAt(
t0=True,
ts=ts,
)
get2 = diffrax.SubSaveAt(
t0=True,
ts=ts,
steps=True,
fn=lambda t, y, args: jnp.where(jnp.isinf(y), 3.0, 4.0),
)
subs = (get0, get1, get2)
saveat = diffrax.SaveAt(subs=subs)
sol = diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=y0,
dt0=0.1,
solver=diffrax.Dopri5(),
saveat=saveat,
max_steps=15,
)
return sol.ys

compare = jnp.full((ntsave + 1, *y0.shape), y0, dtype=dtype)
compare_c = jnp.concatenate(
[
jnp.full((ntsave + 1, *y0.shape), 4.0, dtype=jnp.float64),
jnp.full((15, *y0.shape), jnp.inf, dtype=jnp.float64),
]
)
true_ya, true_yb, true_yc = _solve(1.0) # pyright: ignore[reportGeneralTypeIssues]
true_ya = jnp.stack([compare, true_ya])
true_yb = jnp.stack([compare, true_yb])
true_yc = jnp.stack([compare_c, true_yc])

ya, yb, yc = jax.vmap(_solve)(jnp.array([0.0, 1.0])) # pyright: ignore[reportGeneralTypeIssues]
assert tree_allclose(ya, true_ya)
assert tree_allclose(yb, true_yb)
assert tree_allclose(yc, true_yc)


def test_trivial_dense():
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
y0 = jnp.array([2.1])