Skip to content

Commit

Permalink
optimize with fixed set of interventions along with optimizing over…
Browse files Browse the repository at this point in the history
… specific interventions (#563)

* Adding fixed interventions as input to `optimize`

* Update optimize_interface.ipynb

* Adding integration utility to combine interventions

* Adding deepcopy for fixed interventions

* Update intervention_builder.py

* Sw output (#565)

* simplify intervention splicing

* lint

* removed test

---------

Co-authored-by: Sam Witty <[email protected]>

* Update optimize_interface.ipynb

* Adding tests for using fixed intervention with optimize

* Simplify interchange dictionary processing (#564) (#566)

* simplify intervention splicing

* lint

* removed test

Co-authored-by: Sam Witty <[email protected]>

* Update test_interfaces.py

* updating name for combining interventions

* Update ouu.py

* lint

* update tests to use new name for combining interventions

* lint

* Update optimize_interface.ipynb

---------

Co-authored-by: Sam Witty <[email protected]>
  • Loading branch information
anirban-chaudhuri and SamWitty authored Apr 12, 2024
1 parent f5454df commit 366e755
Show file tree
Hide file tree
Showing 6 changed files with 486 additions and 154 deletions.
543 changes: 420 additions & 123 deletions docs/source/optimize_interface.ipynb

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions pyciemss/integration_utils/intervention_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,16 @@ def intervention_generator(
return static_parameter_interventions

return intervention_generator


def combine_static_parameter_interventions(
interventions: List[Dict[torch.Tensor, Dict[str, Intervention]]]
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}
for intervention in interventions:
for key, value in intervention.items():
if key in static_parameter_interventions:
static_parameter_interventions[key].update(value)
else:
static_parameter_interventions.update({key: value})
return static_parameter_interventions
26 changes: 11 additions & 15 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,9 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
Expand Down Expand Up @@ -818,6 +821,12 @@ def optimize(
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of fixed static interventions to apply to the model and not optimize for.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
n_samples_ouu: int
- The number of samples to draw from the model to estimate risk for each optimization iteration.
maxiter: int
Expand Down Expand Up @@ -855,6 +864,7 @@ def optimize(
risk_measure=lambda z: alpha_superquantile(z, alpha=alpha),
num_samples=1,
guide=inferred_parameters,
fixed_static_parameter_interventions=fixed_static_parameter_interventions,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
Expand All @@ -872,21 +882,7 @@ def optimize(
print(f"Time taken: ({forward_time/1.:.2e} seconds per model evaluation).")

# Assign the required number of MC samples for each OUU iteration
RISK = computeRisk(
model=control_model,
interventions=static_parameter_interventions,
qoi=qoi,
end_time=end_time,
logging_step_size=logging_step_size,
start_time=start_time,
risk_measure=lambda z: alpha_superquantile(z, alpha=alpha),
num_samples=n_samples_ouu,
guide=inferred_parameters,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
risk_bound=risk_bound,
)
RISK.num_samples = n_samples_ouu
# Define constraints >= 0
constraints = (
# risk constraint
Expand Down
16 changes: 15 additions & 1 deletion pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import warnings
from copy import deepcopy
from typing import Any, Callable, Dict, List, Tuple

import numpy as np
Expand All @@ -10,6 +11,9 @@
from scipy.optimize import basinhopping
from tqdm import tqdm

from pyciemss.integration_utils.intervention_builder import (
combine_static_parameter_interventions,
)
from pyciemss.interruptions import (
ParameterInterventionTracer,
StaticParameterIntervention,
Expand Down Expand Up @@ -70,6 +74,9 @@ def __init__(
risk_measure: Callable = lambda z: alpha_superquantile(z, alpha=0.95),
num_samples: int = 1000,
guide=None,
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
solver_method: str = "dopri5",
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
Expand All @@ -84,6 +91,7 @@ def __init__(
self.start_time = start_time
self.end_time = end_time
self.guide = guide
self.fixed_static_parameter_interventions = fixed_static_parameter_interventions
self.solver_method = solver_method
self.solver_options = solver_options
self.logging_times = torch.arange(
Expand Down Expand Up @@ -119,7 +127,13 @@ def propagate_uncertainty(self, x):
with pyro.poutine.seed(rng_seed=0):
with torch.no_grad():
x = np.atleast_1d(x)
static_parameter_interventions = self.interventions(torch.from_numpy(x))
# Combine existing interventions with intervention being optimized
static_parameter_interventions = combine_static_parameter_interventions(
[
deepcopy(self.fixed_static_parameter_interventions),
self.interventions(torch.from_numpy(x)),
]
)
static_parameter_intervention_handlers = [
StaticParameterIntervention(
time, dict(**static_intervention_assignment)
Expand Down
21 changes: 9 additions & 12 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,34 +110,31 @@ def __init__(
"bounds_interventions": [[0.0], [40.0]],
}

optimize_kwargs_SIRstockflow_param_maxQoI = {
optimize_kwargs_SEIRHD_param_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
param_name=["p_cbeta"],
param_value=[lambda x: torch.tensor([x])],
start_time=[torch.tensor(1.0)],
param_name=["beta_c", "gamma"],
start_time=[torch.tensor(10.0), torch.tensor(15.0)],
),
"objfun": lambda x: np.abs(0.35 - x),
"initial_guess_interventions": 0.15,
"bounds_interventions": [[0.1], [0.5]],
"objfun": lambda x: np.abs(0.35 - x[0]) + np.abs(0.2 - x[1]),
"initial_guess_interventions": [0.2, 0.4],
"bounds_interventions": [[0.1, 0.1], [0.5, 0.5]],
"fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}},
}

OPT_MODELS = [
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_time,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_param_maxQoI,
os.path.join(MODELS_PATH, "SEIRHD_NPI_Type1_petrinet.json"),
optimize_kwargs=optimize_kwargs_SEIRHD_param_maxQoI,
),
]

Expand Down
21 changes: 18 additions & 3 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import pyro
import pytest
import torch

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.integration_utils.intervention_builder import (
combine_static_parameter_interventions,
)
from pyciemss.integration_utils.observation import load_data
from pyciemss.interfaces import (
calibrate,
Expand Down Expand Up @@ -578,15 +583,25 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
assert bounds_interventions[0][i] <= opt_policy[i]
assert opt_policy[i] <= bounds_interventions[1][i]

if "fixed_static_parameter_interventions" in optimize_kwargs:
opt_intervention = combine_static_parameter_interventions(
[
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]),
optimize_kwargs["static_parameter_interventions"](opt_result["policy"]),
]
)
else:
opt_intervention = optimize_kwargs["static_parameter_interventions"](
opt_result["policy"]
)

result_opt = sample(
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
static_parameter_interventions=optimize_kwargs[
"static_parameter_interventions"
](opt_result["policy"]),
static_parameter_interventions=opt_intervention,
solver_method=optimize_kwargs["solver_method"],
solver_options=optimize_kwargs["solver_options"],
)["unprocessed_result"]
Expand Down

0 comments on commit 366e755

Please sign in to comment.