Making Dynamical Systems Tests Reparametrizable (#554)
* starting to switch to a more general, hookable parametrization of dynamics tests.

* slight adjustment to the grouping of parametrize args.

* modifies parametrizations to group dependent arguments in tuples, allowing outside test reparametrizations to know what must be reparametrized together (see the sketch after this list).

* fixes simulate_kwargs not being passed to simulate.

* refactors fixtures for more straightforward construction of mock closures on the chirho_diffeqpy side

* separates out the test SIR param prior.

* parametrizes solver in noop interruptions tests.

* parametrizes by solver for static_observation.

* moves the model definition outside of the test so it can be reparametrized.

* parametrizes solver in static interventions test; adjusts tolerance of state match check to match diffeqpy default tolerances.

* parametrizes dynamic intervention tests; makes a more generic array-like gather implementation.

* adds additional preliminary assertions to check for trajectory length mismatches.

* lints

* further linting

* reverts modularization of gather.

* parametrizes test_handler_composition

* reverts tolerance loosening

* lints

* resolves PR change requests.
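
The tuple-grouping idea from the bullets above can be sketched roughly as follows. This is illustrative code, not code from the PR: the import paths and the names MODEL_SOLVER_PAIRS and test_simulate_smoke are assumptions; only the pytest parametrization pattern is the point.

import pytest

from chirho.dynamical.handlers.solver import TorchDiffEq  # assumed import path

from tests.dynamical.dynamical_fixtures import bayes_sir_model  # assumed import path

# Each tuple bundles arguments that only make sense together, so a downstream
# suite (e.g. chirho_diffeqpy) can rebuild this list with its own
# (dynamics factory, solver handler) pairs without editing the test bodies.
MODEL_SOLVER_PAIRS = [
    (bayes_sir_model, TorchDiffEq),
    # a downstream backend would append its own pair here
]


@pytest.mark.parametrize("model_fn, solver", MODEL_SOLVER_PAIRS)
def test_simulate_smoke(model_fn, solver):
    dynamics = model_fn()
    with solver():
        pass  # e.g. call chirho.dynamical.ops.simulate(dynamics, initial_state, start, end)
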
azane authored Sep 11, 2024
1 parent 8f31b31 commit 0f5dae6
Showing 8 changed files with 320 additions and 188 deletions.
89 changes: 67 additions & 22 deletions tests/dynamical/dynamical_fixtures.py
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import Mapping, TypeVar
 
 import pyro
 import torch
@@ -10,30 +10,30 @@
 
 T = TypeVar("T")
 
-
-class UnifiedFixtureDynamics(pyro.nn.PyroModule):
-    def __init__(self, beta=None, gamma=None):
-        super().__init__()
-
-        self.beta = beta
-        if self.beta is None:
-            self.beta = pyro.param("beta", torch.tensor(0.5), constraints.positive)
-
-        self.gamma = gamma
-        if self.gamma is None:
-            self.gamma = pyro.param("gamma", torch.tensor(0.7), constraints.positive)
-
-    def forward(self, X: State[torch.Tensor]):
-        dX: State[torch.Tensor] = dict()
-        beta = self.beta * (
-            1.0 + 0.1 * torch.sin(0.1 * X["t"])
-        )  # beta oscilates slowly in time.
-
-        dX["S"] = -beta * X["S"] * X["I"]
-        dX["I"] = beta * X["S"] * X["I"] - self.gamma * X["I"]  # noqa
-        dX["R"] = self.gamma * X["I"]
-        return dX
-
+ATempParams = Mapping[str, T]
+
+
+# SIR dynamics written as a pure function of state and parameters.
+def pure_sir_dynamics(
+    state: State[torch.Tensor], atemp_params: ATempParams[torch.Tensor]
+) -> State[torch.Tensor]:
+    beta = atemp_params["beta"]
+    gamma = atemp_params["gamma"]
+
+    dX: State[torch.Tensor] = dict()
+
+    beta = beta * (
+        1.0 + 0.1 * torch.sin(0.1 * state["t"])
+    )  # beta oscilates slowly in time.
+
+    dX["S"] = -beta * state["S"] * state["I"]  # noqa
+    dX["I"] = beta * state["S"] * state["I"] - gamma * state["I"]  # noqa
+    dX["R"] = gamma * state["I"]  # noqa
+
+    return dX
+
+
+class SIRObservationMixin:
     def _unit_measurement_error(self, name: str, x: torch.Tensor):
         if x.ndim == 0:
             return pyro.sample(name, Normal(x, 1))
@@ -47,9 +47,46 @@ def observation(self, X: State[torch.Tensor]):
         self._unit_measurement_error("R_obs", X["R"])
 
 
-def bayes_sir_model():
-    beta = pyro.sample("beta", Uniform(0, 1))
-    gamma = pyro.sample("gamma", Uniform(0, 1))
-    sir = UnifiedFixtureDynamics(beta, gamma)
-    return sir
+class SIRReparamObservationMixin(SIRObservationMixin):
+    def observation(self, X: State[torch.Tensor]):
+
+        # A flight arrives in a country that tests all arrivals for a disease. The number of people infected on the
+        # plane is a noisy function of the number of infected people in the country of origin at that time.
+        u_ip = pyro.sample(
+            "u_ip", Normal(7.0, 2.0).expand(X["I"].shape[-1:]).to_event(1)
+        )
+        pyro.deterministic("infected_passengers", X["I"] + u_ip, event_dim=1)
+
+
+class UnifiedFixtureDynamicsBase(pyro.nn.PyroModule):
+    def __init__(self, beta=None, gamma=None):
+        super().__init__()
+
+        self.beta = beta
+        if self.beta is None:
+            self.beta = pyro.param("beta", torch.tensor(0.5), constraints.positive)
+
+        self.gamma = gamma
+        if self.gamma is None:
+            self.gamma = pyro.param("gamma", torch.tensor(0.7), constraints.positive)
+
+    def forward(self, X: State[torch.Tensor]):
+        atemp_params = dict(beta=self.beta, gamma=self.gamma)
+        return pure_sir_dynamics(X, atemp_params)
+
+
+class UnifiedFixtureDynamics(UnifiedFixtureDynamicsBase, SIRObservationMixin):
+    pass
+
+
+def sir_param_prior():
+    beta = pyro.sample("beta", Uniform(0, 1))
+    gamma = pyro.sample("gamma", Uniform(0, 1))
+    return beta, gamma
+
+
+def bayes_sir_model():
+    beta, gamma = sir_param_prior()
+    sir = UnifiedFixtureDynamics(beta, gamma)
+    return sir
 
@@ -64,7 +101,8 @@ def check_states_match(state1: State[torch.Tensor], state2: State[torch.Tensor])
 
     for k in state1.keys():
         assert torch.allclose(
-            state1[k], state2[k]
+            state1[k],
+            state2[k],
         ), f"Trajectories differ in state trajectory of variable {k}, but should be identical."
 
     return True
@@ -77,7 +115,7 @@ def check_trajectories_match_in_all_but_values(
 
     for k in traj1.keys():
         assert not torch.allclose(
-            traj2[k], traj1[k]
+            traj2[k], traj1[k], atol=1e-6, rtol=1e-3
         ), f"Trajectories are identical in state trajectory of variable {k}, but should differ."
 
     return True
@@ -98,3 +136,10 @@ def run_svi_inference_torch_direct(model, n_steps=100, verbose=True, **model_kwa
         if (step % 100 == 0) or (step == 1) & verbose:
             print("[iteration %04d] loss: %.4f" % (step, loss))
     return guide
+
+
+def build_event_fn_zero_after_tt(tt: torch.Tensor):
+    def zero_after_tt(t: torch.Tensor, state: State[torch.Tensor]):
+        return torch.where(t < tt, tt - t, 0.0)
+
+    return zero_after_tt
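
The refactor in this file splits the SIR fixture into a pure dynamics function (pure_sir_dynamics), thin PyroModule wrappers, a separate param prior, and a reusable event-function builder. A rough usage sketch of how those pieces compose, not part of the diff and with an assumed import path:

import torch

from tests.dynamical.dynamical_fixtures import (  # assumed import path
    build_event_fn_zero_after_tt,
    pure_sir_dynamics,
)

# Close fixed parameters over the pure dynamics, mirroring how a backend-specific
# suite (e.g. chirho_diffeqpy) can rebuild the fixture dynamics as a closure.
atemp_params = {"beta": torch.tensor(0.5), "gamma": torch.tensor(0.7)}


def sir_closure(state):
    return pure_sir_dynamics(state, atemp_params)


state = {
    "S": torch.tensor(10.0),
    "I": torch.tensor(1.0),
    "R": torch.tensor(0.0),
    "t": torch.tensor(0.0),
}
print(sir_closure(state))  # dict of dS/dt, dI/dt, dR/dt tensors

# Event function that decreases toward zero and stays at zero after t = 5.0,
# the shape a dynamic interruption needs in order to trigger at that time.
event_fn = build_event_fn_zero_after_tt(torch.tensor(5.0))
print(event_fn(torch.tensor(3.0), state))  # tensor(2.)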