
Added different types of fixed interventions to optimize (#625)
* Added different types of fixed interventions to optimize

* Adding tests

* Lint and fixes

* Update test_interfaces.py

* Update fixtures.py

* testing

* add random seed ensuring that dynamic intervention always happens

* lint

---------

Co-authored-by: Sam Witty <[email protected]>
anirban-chaudhuri and SamWitty authored Dec 6, 2024
1 parent 357be59 commit 2f910c9
Showing 6 changed files with 173 additions and 28 deletions.
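Before the per-file diffs, a minimal sketch of how the four new "fixed" intervention arguments fit together. This is an illustration only; the state names ("I", "S"), parameter names ("hosp", "p_tr"), times, and thresholds are placeholders borrowed from the test fixtures added in this commit:

```python
import torch

# Static interventions are keyed by intervention time (now a torch.Tensor).
fixed_static_parameter_interventions = {torch.tensor(10.0): {"hosp": torch.tensor(0.1)}}
fixed_static_state_interventions = {torch.tensor(5.0): {"I": torch.tensor(20.0)}}

# Dynamic interventions are keyed by an event function of (time, state);
# the intervention fires when the returned value crosses zero.
def infection_threshold(time, state):
    return state["I"] - torch.tensor(150.0)

fixed_dynamic_parameter_interventions = {infection_threshold: {"p_tr": torch.tensor(10.0)}}
fixed_dynamic_state_interventions = {infection_threshold: {"S": torch.tensor(200.0)}}

# All four dicts are passed to pyciemss.optimize(...), which applies them
# during optimization without treating them as decision variables.
```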
4 changes: 2 additions & 2 deletions docs/source/interfaces.ipynb
@@ -3327,7 +3327,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "askem",
"language": "python",
"name": "python3"
},
@@ -3341,7 +3341,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.13"
}
},
"nbformat": 4,
8 changes: 4 additions & 4 deletions docs/source/optimize_interface.ipynb
@@ -936,7 +936,7 @@
"risk_bound = 3e5\n",
"qoi = lambda y: obs_max_qoi(y, observed_params)\n",
"objfun = lambda x: np.sum(np.abs(param_current - x))\n",
"fixed_interventions = {10.: {\"hosp\": torch.tensor(0.1)}}\n",
"fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n",
"\n",
"# Run optimize interface\n",
"opt_result5 = pyciemss.optimize(\n",
@@ -1273,7 +1273,7 @@
")\n",
"\n",
"# Fixed intervention on hosp parameter\n",
"fixed_interventions = {10.0: {\"hosp\": torch.tensor(0.1)}}\n",
"fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n",
"\n",
"# Run optimize interface\n",
"opt_result6 = pyciemss.optimize(\n",
@@ -1513,7 +1513,7 @@
")\n",
"\n",
"# Fixed intervention on hosp parameter\n",
"fixed_interventions = {10.0: {\"hosp\": torch.tensor(0.1)}}\n",
"fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n",
"\n",
"# Run optimize interface\n",
"opt_result6 = pyciemss.optimize(\n",
@@ -1602,7 +1602,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "askem",
"language": "python",
"name": "python3"
},
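All three notebook cells above make the same change: intervention times move from Python floats to `torch.Tensor` keys, matching the updated `Dict[torch.Tensor, Dict[str, Intervention]]` annotations in `pyciemss/interfaces.py` below. A before/after sketch:

```python
import torch

# Before this commit: float-keyed intervention time.
fixed_interventions = {10.0: {"hosp": torch.tensor(0.1)}}

# After this commit: tensor-keyed intervention time.
fixed_interventions = {torch.tensor(10.0): {"hosp": torch.tensor(0.1)}}
```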
50 changes: 45 additions & 5 deletions pyciemss/interfaces.py
@@ -648,13 +648,13 @@ def calibrate(
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
- By default we set the `start_time` to be 0.
- static_state_interventions: Dict[float, Dict[str, Intervention]]
+ static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
- static_parameter_interventions: Dict[float, Dict[str, Intervention]]
+ static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
@@ -833,7 +833,18 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
- fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
+ fixed_static_parameter_interventions: Dict[
+     torch.Tensor, Dict[str, Intervention]
+ ] = {},
+ fixed_static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {},
+ fixed_dynamic_state_interventions: Dict[
+     Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
+     Dict[str, Intervention],
+ ] = {},
+ fixed_dynamic_parameter_interventions: Dict[
+     Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
+     Dict[str, Intervention],
+ ] = {},
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
@@ -891,12 +902,38 @@
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
- fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]]
- - A dictionary of fixed static interventions to apply to the model and not optimize for.
+ fixed_static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
+ - A dictionary of fixed static parameter interventions to apply to the model and not optimize for.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
+ fixed_static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
+ - A dictionary of static state interventions to apply to the model and not optimize for.
+ - Each key is the time at which the intervention is applied.
+ - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
+ - Note that the `intervention_assignment` can be any type supported by
+   :func:`~chirho.interventional.ops.intervene`, including functions.
+ fixed_dynamic_state_interventions: Dict[
+     Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
+     Dict[str, Intervention]
+ ]
+ - A dictionary of dynamic interventions to apply to the model and not optimize for.
+ - Each key is a function that takes in the current state of the model and returns a tensor.
+   When this function crosses 0, the dynamic intervention is applied.
+ - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
+ - Note that the `intervention_assignment` can be any type supported by
+   :func:`~chirho.interventional.ops.intervene`, including functions.
+ fixed_dynamic_parameter_interventions: Dict[
+     Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
+     Dict[str, Intervention]
+ ]
+ - A dictionary of dynamic interventions to apply to the model and not optimize for.
+ - Each key is a function that takes in the current state of the model and returns a tensor.
+   When this function crosses 0, the dynamic intervention is applied.
+ - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
+ - Note that the `intervention_assignment` can be any type supported by
+   :func:`~chirho.interventional.ops.intervene`, including functions.
n_samples_ouu: int
- The number of samples to draw from the model to estimate risk for each optimization iteration.
maxiter: int
@@ -950,6 +987,9 @@
num_samples=1,
guide=inferred_parameters,
fixed_static_parameter_interventions=fixed_static_parameter_interventions,
+ fixed_static_state_interventions=fixed_static_state_interventions,
+ fixed_dynamic_state_interventions=fixed_dynamic_state_interventions,
+ fixed_dynamic_parameter_interventions=fixed_dynamic_parameter_interventions,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
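The docstrings above note that an `intervention_assignment` can be any type supported by chirho's `intervene`, including functions. A hedged sketch of a functional assignment, under the assumption (per chirho's `intervene` semantics) that a callable assignment is applied to the value it replaces; "hosp" is a placeholder parameter name:

```python
import torch

def halve(current_value: torch.Tensor) -> torch.Tensor:
    # Assumed chirho convention: a callable intervention receives the
    # intervened value and returns its replacement.
    return 0.5 * current_value

# At t = 10.0, replace "hosp" with half its current value instead of a constant.
fixed_static_parameter_interventions = {torch.tensor(10.0): {"hosp": halve}}
```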
63 changes: 60 additions & 3 deletions pyciemss/ouu/ouu.py
@@ -6,6 +6,7 @@
import numpy as np
import pyro
import torch
+ from chirho.dynamical.handlers import DynamicIntervention, StaticIntervention
from chirho.dynamical.handlers.solver import TorchDiffEq
from chirho.interventional.ops import Intervention
from scipy.optimize import basinhopping
@@ -14,6 +15,7 @@
combine_static_parameter_interventions,
)
from pyciemss.interruptions import (
+ DynamicParameterIntervention,
ParameterInterventionTracer,
StaticParameterIntervention,
)
@@ -73,7 +75,20 @@ def __init__(
risk_measure: List[Callable] = [lambda z: alpha_superquantile(z, alpha=0.95)],
num_samples: int = 1000,
guide=None,
- fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
+ fixed_static_parameter_interventions: Dict[
+     torch.Tensor, Dict[str, Intervention]
+ ] = {},
+ fixed_static_state_interventions: Dict[
+     torch.Tensor, Dict[str, Intervention]
+ ] = {},
+ fixed_dynamic_state_interventions: Dict[
+     Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
+     Dict[str, Intervention],
+ ] = {},
+ fixed_dynamic_parameter_interventions: Dict[
+     Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
+     Dict[str, Intervention],
+ ] = {},
solver_method: str = "dopri5",
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
@@ -89,6 +104,11 @@
self.end_time = end_time
self.guide = guide
self.fixed_static_parameter_interventions = fixed_static_parameter_interventions
+ self.fixed_static_state_interventions = fixed_static_state_interventions
+ self.fixed_dynamic_state_interventions = fixed_dynamic_state_interventions
+ self.fixed_dynamic_parameter_interventions = (
+     fixed_dynamic_parameter_interventions
+ )
self.solver_method = solver_method
self.solver_options = solver_options
self.logging_times = torch.arange(
@@ -143,6 +163,34 @@ def propagate_uncertainty(self, x):
for time, static_intervention_assignment in static_parameter_interventions.items()
]

+ static_state_intervention_handlers = [
+     StaticIntervention(time, dict(**static_intervention_assignment))
+     for time, static_intervention_assignment in self.fixed_static_state_interventions.items()
+ ]
+
+ dynamic_state_intervention_handlers = [
+     DynamicIntervention(
+         event_fn, dict(**dynamic_intervention_assignment)
+     )
+     for event_fn, dynamic_intervention_assignment in self.fixed_dynamic_state_interventions.items()
+ ]
+
+ dynamic_parameter_intervention_handlers = [
+     DynamicParameterIntervention(
+         event_fn,
+         dict(**dynamic_intervention_assignment),
+         is_traced=True,
+     )
+     for event_fn, dynamic_intervention_assignment in self.fixed_dynamic_parameter_interventions.items()
+ ]
+
+ intervention_handlers = (
+     static_state_intervention_handlers
+     + static_parameter_intervention_handlers
+     + dynamic_state_intervention_handlers
+     + dynamic_parameter_intervention_handlers
+ )

def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(
@@ -152,7 +200,7 @@ def wrapped_model():
options=self.solver_options,
):
with contextlib.ExitStack() as stack:
- for handler in static_parameter_intervention_handlers:
+ for handler in intervention_handlers:
stack.enter_context(handler)
self.model(
torch.as_tensor(self.start_time),
@@ -161,12 +209,21 @@
is_traced=True,
)

+ parallel = (
+     False
+     if len(
+         dynamic_parameter_intervention_handlers
+         + dynamic_state_intervention_handlers
+     )
+     > 0
+     else True
+ )
# Sample from intervened model
samples = pyro.infer.Predictive(
wrapped_model,
guide=self.guide,
num_samples=self.num_samples,
- parallel=True,
+ parallel=parallel,
)()
return samples

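Two details of `propagate_uncertainty` above are worth unpacking. First, `wrapped_model` enters a runtime-determined list of intervention handlers via `contextlib.ExitStack`; a self-contained sketch of that pattern with stand-in context managers:

```python
import contextlib

@contextlib.contextmanager
def handler(name):
    # Stand-in for a chirho intervention handler.
    print(f"enter {name}")
    yield
    print(f"exit {name}")

# The handler list is only known at runtime, as in wrapped_model above.
handlers = [handler("static_state"), handler("dynamic_parameter")]

with contextlib.ExitStack() as stack:
    for h in handlers:
        stack.enter_context(h)
    print("model runs under every entered handler")
# Handlers exit in reverse order as the stack unwinds.
```

Second, parallel sampling in `pyro.infer.Predictive` is now disabled whenever any dynamic intervention handlers are present, presumably because event-triggered interventions do not compose with vectorized sampling; the commit does not state the reason, so that is an inference.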
30 changes: 28 additions & 2 deletions tests/fixtures.py
@@ -100,6 +100,25 @@ def __init__(
ModelFixture(os.path.join(MODELS_PATH, "SEIRHD_stockflow.json"), "p_cbeta"),
]

+ fixed_static_state_interventions = {torch.tensor(5.0): {"I": torch.tensor(20.0)}}
+
+
+ # Define the threshold for when the intervention should be applied
+ def make_var_threshold(var: str, threshold: torch.Tensor):
+     def var_threshold(time, state):
+         return state[var] - threshold
+
+     return var_threshold
+
+
+ infection_threshold1 = make_var_threshold("I", torch.tensor(150.0))
+ fixed_dynamic_parameter_interventions = {
+     infection_threshold1: {"p_tr": torch.tensor(10.0)}
+ }
+
+ infection_threshold2 = make_var_threshold("I", torch.tensor(400.0))
+ fixed_dynamic_state_interventions = {infection_threshold2: {"S": torch.tensor(200.0)}}

optkwargs_SIRstockflow_param = {
"qoi": [lambda x: obs_nday_average_qoi(x, ["I_state"], 1)],
"risk_bound": [300.0],
@@ -111,6 +130,9 @@
"objfun": lambda x: np.abs(0.35 - x),
"initial_guess_interventions": 0.15,
"bounds_interventions": [[0.1], [0.5]],
"fixed_static_state_interventions": fixed_static_state_interventions,
"fixed_dynamic_parameter_interventions": fixed_dynamic_parameter_interventions,
"fixed_dynamic_state_interventions": fixed_dynamic_state_interventions,
}

optkwargs_SIRstockflow_time = {
@@ -156,7 +178,9 @@ def __init__(
"objfun": lambda x: np.abs(0.35 - x[0]) - x[1],
"initial_guess_interventions": [0.35, 5.0],
"bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],
"fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}},
"fixed_static_parameter_interventions": {
torch.tensor(10.0): {"hosp": torch.tensor(0.1)}
},
}
optkwargs_SEIRHD_multipleConstraints = {
"qoi": [
Expand All @@ -171,7 +195,9 @@ def __init__(
"objfun": lambda x: np.abs(0.35 - x[0]) - x[1],
"initial_guess_interventions": [0.35, 5.0],
"bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],
"fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}},
"fixed_static_parameter_interventions": {
torch.tensor(10.0): {"hosp": torch.tensor(0.1)}
},
"alpha": [0.95, 0.90],
}

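A quick illustration of the `make_var_threshold` fixture above: the returned event function is negative below the threshold and positive above it, so its zero-crossing is where the dynamic intervention fires:

```python
import torch

def make_var_threshold(var: str, threshold: torch.Tensor):
    def var_threshold(time, state):
        return state[var] - threshold

    return var_threshold

var_threshold = make_var_threshold("I", torch.tensor(150.0))
below = var_threshold(torch.tensor(1.0), {"I": torch.tensor(100.0)})  # tensor(-50.)
above = var_threshold(torch.tensor(2.0), {"I": torch.tensor(200.0)})  # tensor(50.)
assert below < 0 < above  # sign change: the intervention fires in between
```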
46 changes: 34 additions & 12 deletions tests/test_interfaces.py
@@ -650,23 +650,45 @@ def __call__(self, x):
opt_intervention = combine_static_parameter_interventions(intervention_list)
else:
opt_intervention = opt_intervention_temp
if "fixed_static_state_interventions" not in optimize_kwargs:
fixed_static_state_interventions = {}
else:
fixed_static_state_interventions = optimize_kwargs[
"fixed_static_state_interventions"
]
if "fixed_dynamic_parameter_interventions" not in optimize_kwargs:
fixed_dynamic_parameter_interventions = {}
else:
fixed_dynamic_parameter_interventions = optimize_kwargs[
"fixed_dynamic_parameter_interventions"
]
if "fixed_dynamic_state_interventions" not in optimize_kwargs:
fixed_dynamic_state_interventions = {}
else:
fixed_dynamic_state_interventions = optimize_kwargs[
"fixed_dynamic_state_interventions"
]

if "alpha" in optimize_kwargs:
alpha = optimize_kwargs["alpha"]
else:
alpha = [0.95]
- result_opt = sample(
-     model_url,
-     end_time,
-     logging_step_size,
-     num_samples,
-     start_time=start_time,
-     static_parameter_interventions=opt_intervention,
-     solver_method=optimize_kwargs["solver_method"],
-     solver_options=optimize_kwargs["solver_options"],
-     alpha=alpha,
-     qoi=optimize_kwargs["qoi"],
- )["unprocessed_result"]
+ with pyro.poutine.seed(rng_seed=0):
+     result_opt = sample(
+         model_url,
+         end_time,
+         logging_step_size,
+         num_samples,
+         start_time=start_time,
+         static_parameter_interventions=opt_intervention,
+         static_state_interventions=fixed_static_state_interventions,
+         dynamic_parameter_interventions=fixed_dynamic_parameter_interventions,
+         dynamic_state_interventions=fixed_dynamic_state_interventions,
+         solver_method=optimize_kwargs["solver_method"],
+         solver_options=optimize_kwargs["solver_options"],
+         alpha=alpha,
+         qoi=optimize_kwargs["qoi"],
+     )["unprocessed_result"]

intervened_result_subset = {
k: v
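The test now wraps `sample` in `pyro.poutine.seed`, matching the commit message's note that a fixed seed ensures the dynamic intervention always happens: with deterministic draws, the sampled trajectories reliably cross the event thresholds set in the fixtures. A minimal, model-agnostic sketch of the seeding behavior:

```python
import pyro
import pyro.distributions as dist
import torch

def model():
    return pyro.sample("x", dist.Normal(0.0, 1.0))

with pyro.poutine.seed(rng_seed=0):
    a = model()
with pyro.poutine.seed(rng_seed=0):
    b = model()

assert torch.equal(a, b)  # same seed, identical draws
```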
