From dc96430a693e5abf8e2555e0acee59901dea26b1 Mon Sep 17 00:00:00 2001
From: Sam Witty
Date: Thu, 21 Nov 2024 13:48:11 -0500
Subject: [PATCH] Modify observables computations to account for dependence on
 parameters that change via intervention (#631)

* fix issue of parameter interventions not making their way to observables
* another pass at the implementation
* fix
* lint
* unsqueeze time to handle shape errors
* adding failing test for observables defined by parameter values (#632)
* set initial_observables before running simulate
* lint
* add reset
* remove print

---------

Co-authored-by: sabinala <130604122+sabinala@users.noreply.github.com>
---
 pyciemss/compiled_dynamics.py | 62 ++++++++++++++++++++++++++++++++---
 tests/test_interfaces.py      | 38 +++++++++++++++++++++
 2 files changed, 95 insertions(+), 5 deletions(-)

diff --git a/pyciemss/compiled_dynamics.py b/pyciemss/compiled_dynamics.py
index 85248440..57038627 100644
--- a/pyciemss/compiled_dynamics.py
+++ b/pyciemss/compiled_dynamics.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import functools
-from typing import Callable, Dict, Optional, Tuple, TypeVar, Union
+from typing import Callable, Dict, List, Optional, Tuple, TypeVar, Union
 
 import mira
 import mira.metamodel
@@ -11,6 +11,7 @@
 import pyro
 import torch
 from chirho.dynamical.handlers import LogTrajectory
+from chirho.dynamical.internals._utils import _squeeze_time_dim
 from chirho.dynamical.ops import State, simulate
 
 S = TypeVar("S")
@@ -121,7 +122,7 @@ def forward(
         self.instantiate_parameters()
 
         if logging_times is not None:
-            with LogTrajectory(logging_times) as lt:
+            with LogObservables(logging_times, self) as lo:
                 try:
                     simulate(self.deriv, self.initial_state(), start_time, end_time)
                 except AssertionError as e:
@@ -135,11 +136,11 @@ def forward(
                     else:
                         raise e
 
-            state = lt.trajectory
+            state = lo.state
+            observables = lo.observables
         else:
             state = simulate(self.deriv, self.initial_state(), start_time, end_time)
-
-        observables = self.observables(state)
+            observables = self.observables(state)
 
         if is_traced:
             # Add the observables to the trace so that they can be accessed later.
@@ -229,3 +230,54 @@ def get_name(obj) -> str:
 @get_name.register
 def _get_name_str(name: str) -> str:
     return name
+
+
+class LogObservables(pyro.poutine.messenger.Messenger):
+    def __init__(self, times: torch.Tensor, model: CompiledDynamics):
+        super().__init__()
+        self.model = model
+        self.observables: State[torch.Tensor] = {}
+        self.state: State[torch.Tensor] = {}
+
+        # This gets around the issue of the LogTrajectory handler blocking `self`
+        self.lt = LogTrajectory(times)
+
+        self.reset()
+
+    def reset(self):
+        self._observables_names: List[str] = []
+        self._initial_observables: State[torch.Tensor] = {}
+
+    def _pyro_simulate_point(self, msg):
+        self.lt._pyro_simulate_point(msg)
+
+    def _pyro_post_simulate_trajectory(self, msg):
+        observables = self.model.observables(msg["value"])
+        self._observables_names = list(observables.keys())
+        msg["value"] = {**msg["value"], **observables}
+
+    def _pyro_simulate(self, msg):
+        self._initial_observables = _squeeze_time_dim(
+            self.model.observables(msg["args"][1])
+        )
+
+    def _pyro_post_simulate(self, msg):
+        msg["args"] = (
+            msg["args"][0],
+            {**msg["args"][1], **self._initial_observables},
+            msg["args"][2],
+            msg["args"][3],
+        )
+
+        self.lt._pyro_post_simulate(msg)
+
+        self.observables = {
+            k: v for k, v in self.lt.trajectory.items() if k in self._observables_names
+        }
+        self.state = {
+            k: v
+            for k, v in self.lt.trajectory.items()
+            if k not in self._observables_names
+        }
+
+        self.reset()
diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py
index e4563b8b..00b605de 100644
--- a/tests/test_interfaces.py
+++ b/tests/test_interfaces.py
@@ -826,3 +826,41 @@ def test_intervention_on_constant_param(
         torch.arange(start_time, end_time + logging_step_size, logging_step_size)
     )
     assert processed_result.shape[1] >= 2
+
+
+@pytest.mark.parametrize("sample_method", [sample])
+@pytest.mark.parametrize("model_fixture", MODELS)
+@pytest.mark.parametrize("end_time", END_TIMES)
+@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
+@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+@pytest.mark.parametrize("start_time", START_TIMES)
+def test_observables_change_with_interventions(
+    sample_method,
+    model_fixture,
+    end_time,
+    logging_step_size,
+    num_samples,
+    start_time,
+):
+    # Assert that sample returns expected result with intervention on constant parameter
+    if "SIR_param" not in model_fixture.url:
+        pytest.skip("Only test 'SIR_param_in_obs' model")
+    else:
+        processed_result = sample_method(
+            model_fixture.url,
+            end_time,
+            logging_step_size,
+            num_samples,
+            start_time=start_time,
+            static_parameter_interventions={
+                torch.tensor(2.0): {"beta": torch.tensor(0.001)}
+            },
+        )["data"]
+
+        # The test will fail if values before and after the intervention are the same
+        assert (
+            processed_result["beta_param_observable_state"][0]
+            > processed_result["beta_param_observable_state"][
+                int(end_time / logging_step_size)
+            ]
+        )
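
Note: for context, below is a minimal sketch of how the behavior this patch fixes can be exercised from the public `sample` interface, mirroring the added test. It is an illustration under stated assumptions, not part of the patch: the model URL is a placeholder, and it assumes an AMR model (like the `SIR_param_in_obs` test fixture) whose observable is defined directly in terms of the intervened parameter `beta`.

import torch

from pyciemss.interfaces import sample

# Placeholder model URL; assumes an observable defined in terms of `beta`,
# analogous to the `SIR_param_in_obs` fixture used in the added test.
MODEL_URL = "https://example.org/SIR_param_in_obs.json"

result = sample(
    MODEL_URL,
    10.0,  # end_time
    1.0,   # logging_step_size
    3,     # num_samples
    start_time=0.0,
    # Intervene on the constant parameter `beta` at t = 2.0; with this patch,
    # observables that reference `beta` reflect the new value after t = 2.0.
    static_parameter_interventions={torch.tensor(2.0): {"beta": torch.tensor(0.001)}},
)["data"]

# Observable columns (named like "<name>_observable_state") should now differ
# before and after the intervention time.
print(result.filter(like="_observable_state").head())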