Modify observables computations to account for dependence on parameters that change via intervention (#631)

* fix issue of parameter interventions not making their way to observables

* another pass at the implementation

* fix

* lint

* unsqueeze time to handle shape errors

* adding failing test for observables defined by parameter values (#632)

* set initial_observables before running simulate

* lint

* add reset

* remove print

---------

Co-authored-by: sabinala <[email protected]>
SamWitty and sabinala authored Nov 21, 2024
1 parent 1ce34fd commit dc96430
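
For context (not part of the commit itself): the behavior being fixed is exercised through the sample interface with a static_parameter_interventions argument, as in the new test added below. A minimal sketch under assumptions — the import path pyciemss.interfaces.sample, the model source, and the numeric settings are hypothetical; only the shape of the intervention mapping comes from the test in this commit.

import torch

from pyciemss.interfaces import sample  # assumed import path for the tested interface

# Hypothetical model and settings; the model's observable is assumed to be defined from beta.
model_url = "SIR_param_in_obs.json"
end_time, logging_step_size, num_samples, start_time = 10.0, 1.0, 3, 0.0

result = sample(
    model_url,
    end_time,
    logging_step_size,
    num_samples,
    start_time=start_time,
    # Intervene on the constant parameter beta at t = 2.0, mirroring the new test.
    static_parameter_interventions={torch.tensor(2.0): {"beta": torch.tensor(0.001)}},
)["data"]

# Before this commit, a beta-derived observable ignored the intervened value;
# after it, rows of `result` logged after t = 2.0 reflect beta = 0.001.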
Showing 2 changed files with 95 additions and 5 deletions.
62 changes: 57 additions & 5 deletions pyciemss/compiled_dynamics.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
-from typing import Callable, Dict, Optional, Tuple, TypeVar, Union
+from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union

import mira
import mira.metamodel
@@ -11,6 +11,7 @@
import pyro
import torch
from chirho.dynamical.handlers import LogTrajectory
from chirho.dynamical.internals._utils import _squeeze_time_dim
from chirho.dynamical.ops import State, simulate

S = TypeVar("S")
@@ -121,7 +122,7 @@ def forward(
        self.instantiate_parameters()

        if logging_times is not None:
-            with LogTrajectory(logging_times) as lt:
+            with LogObservables(logging_times, self) as lo:
                try:
                    simulate(self.deriv, self.initial_state(), start_time, end_time)
                except AssertionError as e:
@@ -135,11 +136,11 @@
                    else:
                        raise e

-            state = lt.trajectory
+            state = lo.state
+            observables = lo.observables
        else:
            state = simulate(self.deriv, self.initial_state(), start_time, end_time)

-        observables = self.observables(state)
+            observables = self.observables(state)

        if is_traced:
            # Add the observables to the trace so that they can be accessed later.
@@ -229,3 +230,54 @@ def get_name(obj) -> str:
@get_name.register
def _get_name_str(name: str) -> str:
    return name


class LogObservables(pyro.poutine.messenger.Messenger):
    def __init__(self, times: torch.Tensor, model: CompiledDynamics):
        super().__init__()
        self.model = model
        self.observables: State[torch.Tensor] = {}
        self.state: State[torch.Tensor] = {}

        # This gets around the issue of the LogTrajectory handler blocking `self`
        self.lt = LogTrajectory(times)

        self.reset()

    def reset(self):
        self._observables_names: List[str] = []
        self._initial_observables: State[torch.Tensor] = {}

    def _pyro_simulate_point(self, msg):
        self.lt._pyro_simulate_point(msg)

    def _pyro_post_simulate_trajectory(self, msg):
        observables = self.model.observables(msg["value"])
        self._observables_names = list(observables.keys())
        msg["value"] = {**msg["value"], **observables}

    def _pyro_simulate(self, msg):
        self._initial_observables = _squeeze_time_dim(
            self.model.observables(msg["args"][1])
        )

    def _pyro_post_simulate(self, msg):
        msg["args"] = (
            msg["args"][0],
            {**msg["args"][1], **self._initial_observables},
            msg["args"][2],
            msg["args"][3],
        )

        self.lt._pyro_post_simulate(msg)

        self.observables = {
            k: v for k, v in self.lt.trajectory.items() if k in self._observables_names
        }
        self.state = {
            k: v
            for k, v in self.lt.trajectory.items()
            if k not in self._observables_names
        }

        self.reset()
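
Not part of the diff: a minimal standalone usage sketch of the new handler, mirroring the change to forward above. The model source, the numeric values, and the use of chirho's TorchDiffEq solver handler are assumptions; LogObservables delegates trajectory logging to an internal LogTrajectory, recomputes the model's observables along the logged trajectory (including the initial state, via _squeeze_time_dim), and exposes the results as lo.state and lo.observables.

import torch
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.dynamical.ops import simulate

from pyciemss.compiled_dynamics import CompiledDynamics, LogObservables

# Assumed: an AMR model source accepted by CompiledDynamics.load.
model = CompiledDynamics.load("SIR_param_in_obs.json")
model.instantiate_parameters()  # instantiate parameter values, as forward does before simulating

logging_times = torch.arange(0.0, 10.0, 1.0)
start_time, end_time = torch.tensor(0.0), torch.tensor(10.0)

with TorchDiffEq(), LogObservables(logging_times, model) as lo:
    simulate(model.deriv, model.initial_state(), start_time, end_time)

state = lo.state              # logged state variables, with observables filtered out
observables = lo.observables  # observables recomputed along the logged trajectory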
38 changes: 38 additions & 0 deletions tests/test_interfaces.py
@@ -826,3 +826,41 @@ def test_intervention_on_constant_param(
        torch.arange(start_time, end_time + logging_step_size, logging_step_size)
    )
    assert processed_result.shape[1] >= 2


@pytest.mark.parametrize("sample_method", [sample])
@pytest.mark.parametrize("model_fixture", MODELS)
@pytest.mark.parametrize("end_time", END_TIMES)
@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
@pytest.mark.parametrize("start_time", START_TIMES)
def test_observables_change_with_interventions(
    sample_method,
    model_fixture,
    end_time,
    logging_step_size,
    num_samples,
    start_time,
):
    # Assert that sample returns expected result with intervention on constant parameter
    if "SIR_param" not in model_fixture.url:
        pytest.skip("Only test 'SIR_param_in_obs' model")
    else:
        processed_result = sample_method(
            model_fixture.url,
            end_time,
            logging_step_size,
            num_samples,
            start_time=start_time,
            static_parameter_interventions={
                torch.tensor(2.0): {"beta": torch.tensor(0.001)}
            },
        )["data"]

        # The test will fail if values before and after the intervention are the same
        assert (
            processed_result["beta_param_observable_state"][0]
            > processed_result["beta_param_observable_state"][
                int(end_time / logging_step_size)
            ]
        )
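
A side note on the indexing in the assertion above (illustrative values only, not taken from the test fixtures, and assuming row 0 of the processed result corresponds to the first logged time): with the logging grid used elsewhere in this file, row 0 is a point before the intervention at t = 2.0 and int(end_time / logging_step_size) is the last logged point, well after it.

import torch

start_time, end_time, logging_step_size = 0.0, 10.0, 1.0  # hypothetical fixture values
logging_times = torch.arange(start_time, end_time + logging_step_size, logging_step_size)

first_idx = 0                                 # logged point before the intervention at t = 2.0
last_idx = int(end_time / logging_step_size)  # last logged point, after the intervention
print(logging_times[first_idx], logging_times[last_idx])  # tensor(0.) tensor(10.)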
