Commit

Updated optimize to handle combination of multiple intervention templates (#591)

* Updated optimize to handle multiple intervention templates

- updated typing for interventions input to optimize
- created a lambda function to combine different intervention templates

* Update optimize_interface.ipynb

* Update optimize_interface.ipynb

* fixed how callable list is read in 'ouu.py'

* Update optimize_interface.ipynb

* Add tests for combined intervention templates

* Fixed tests to run with new setup of interventions

* added intervention function combinator and example usage in the notebook (which should be deleted) (#595)

Co-authored-by: Sam Witty <[email protected]>

* Update intervention_builder.py

* Update intervention_builder.py

* update data type

* fix data type

* updated notebook

* typing

* typing for interventions changed from float to Tensor

* Lint

* Typing

* Update ouu.py

* Update intervention_builder.py

* Updates for type matching

* Update ouu.py

---------

Co-authored-by: Sam Witty <[email protected]>
anirban-chaudhuri and SamWitty authored Jul 31, 2024
1 parent a44045f commit adeb6b9
Showing 6 changed files with 392 additions and 38 deletions.
304 changes: 301 additions & 3 deletions docs/source/optimize_interface.ipynb

Large diffs are not rendered by default.

54 changes: 48 additions & 6 deletions pyciemss/integration_utils/intervention_builder.py
@@ -20,6 +20,11 @@ def param_value_objective(
def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == param_size, (
f"Size mismatch between input size ('{x.size()[0]}') and param_name size ('{param_size}'): "
"check size for initial_guess_interventions and/or bounds_interventions."
)
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if start_time[count].item() in static_parameter_interventions:
@@ -48,6 +53,11 @@ def start_time_objective(
def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == param_size, (
f"Size mismatch between input size ('{x.size()[0]}') and param_name size ('{param_size}'): "
"check size for initial_guess_interventions and/or bounds_interventions."
)
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if x[count].item() in static_parameter_interventions:
@@ -78,9 +88,11 @@ def start_time_param_value_objective(
def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
assert (
x.size()[0] == param_size * 2
), "Size mismatch: check size for initial_guess_interventions and/or bounds_interventions"
x = torch.atleast_1d(x)
assert x.size()[0] == param_size * 2, (
f"Size mismatch between input size ('{x.size()[0]}') and param_name size ('{param_size * 2}'): "
"check size for initial_guess_interventions and/or bounds_interventions."
)
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for count in range(param_size):
if x[count * 2].item() in static_parameter_interventions:
@@ -102,10 +114,40 @@ def intervention_generator(
return intervention_generator


def intervention_func_combinator(
intervention_funcs: List[
Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]
],
intervention_func_lengths: List[int],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
assert len(intervention_funcs) == len(intervention_func_lengths)

total_length = sum(intervention_func_lengths)

# Note: This only works for combining static parameter interventions.
def intervention_generator(
x: torch.Tensor,
) -> Dict[float, Dict[str, Intervention]]:
x = torch.atleast_1d(x)
assert x.size()[0] == total_length
interventions: List[Dict[float, Dict[str, Intervention]]] = [
{} for _ in range(len(intervention_funcs))
]
i = 0
for j, (input_length, intervention_func) in enumerate(
zip(intervention_func_lengths, intervention_funcs)
):
interventions[j] = intervention_func(x[i : i + input_length])
i += input_length
return combine_static_parameter_interventions(interventions)

return intervention_generator


def combine_static_parameter_interventions(
interventions: List[Dict[torch.Tensor, Dict[str, Intervention]]]
) -> Dict[torch.Tensor, Dict[str, Intervention]]:
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}
interventions: List[Dict[float, Dict[str, Intervention]]]
) -> Dict[float, Dict[str, Intervention]]:
static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}
for intervention in interventions:
for key, value in intervention.items():
if key in static_parameter_interventions:
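For context outside the diff: a minimal usage sketch of the new combinator, mirroring the fixture added in tests/fixtures.py below. Parameter names and values are illustrative, taken from that fixture.

import torch
from pyciemss.integration_utils.intervention_builder import (
    intervention_func_combinator,
    param_value_objective,
    start_time_objective,
)

# Template 1: optimize the value of "beta_c" applied at a fixed start time of 10.0.
value_template = param_value_objective(
    param_name=["beta_c"], start_time=torch.tensor([10.0])
)
# Template 2: optimize the start time of a fixed "gamma" intervention of 0.45.
time_template = start_time_objective(
    param_name=["gamma"], param_value=torch.tensor([0.45])
)

# Combine into a single callable; each template consumes one entry of the input tensor.
combined = intervention_func_combinator([value_template, time_template], [1, 1])

# The first entry (0.35) is the candidate value for "beta_c"; the second (5.0)
# is the candidate start time for "gamma". Roughly:
# {10.0: {"beta_c": tensor(0.35)}, 5.0: {"gamma": tensor(0.45)}}
interventions = combined(torch.tensor([0.35, 5.0]))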
4 changes: 1 addition & 3 deletions pyciemss/interfaces.py
@@ -783,9 +783,7 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
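With this change, fixed interventions passed to optimize() are keyed by plain float times rather than tensors. A short sketch of the new format, using the value from the test fixture below:

import torch

# Previously annotated as: Dict[torch.Tensor, Dict[str, Intervention]]
# Now: intervention times are plain floats, values remain tensors.
fixed_static_parameter_interventions = {10.0: {"hosp": torch.tensor(0.1)}}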
13 changes: 6 additions & 7 deletions pyciemss/ouu/ouu.py
@@ -74,9 +74,7 @@ def __init__(
risk_measure: Callable = lambda z: alpha_superquantile(z, alpha=0.95),
num_samples: int = 1000,
guide=None,
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
solver_method: str = "dopri5",
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
@@ -128,11 +126,12 @@ def propagate_uncertainty(self, x):
with torch.no_grad():
x = np.atleast_1d(x)
# Combine existing interventions with intervention being optimized
intervention_list = [
deepcopy(self.fixed_static_parameter_interventions)
]
intervention_list.extend([self.interventions(torch.from_numpy(x))])
static_parameter_interventions = combine_static_parameter_interventions(
[
deepcopy(self.fixed_static_parameter_interventions),
self.interventions(torch.from_numpy(x)),
]
intervention_list
)
static_parameter_intervention_handlers = [
StaticParameterIntervention(
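During optimization, the candidate intervention is merged with any fixed interventions before simulation. A sketch of that merge, assuming interventions that share a time key are combined into a single inner dict (the branch below the lines shown in intervention_builder.py is collapsed):

from copy import deepcopy

import torch
from pyciemss.integration_utils.intervention_builder import (
    combine_static_parameter_interventions,
)

fixed = {10.0: {"hosp": torch.tensor(0.1)}}
candidate = {10.0: {"beta_c": torch.tensor(0.35)}, 5.0: {"gamma": torch.tensor(0.45)}}

merged = combine_static_parameter_interventions([deepcopy(fixed), candidate])
# Expected shape of the result:
# {10.0: {"hosp": tensor(0.1), "beta_c": tensor(0.35)}, 5.0: {"gamma": tensor(0.45)}}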
41 changes: 26 additions & 15 deletions tests/fixtures.py
@@ -7,6 +7,7 @@
import torch

from pyciemss.integration_utils.intervention_builder import (
intervention_func_combinator,
param_value_objective,
start_time_objective,
start_time_param_value_objective,
@@ -93,7 +94,7 @@ def __init__(
ModelFixture(os.path.join(MODELS_PATH, "SEIRHD_stockflow.json"), "p_cbeta"),
]

optimize_kwargs_SIRstockflow_param = {
optkwargs_SIRstockflow_param = {
"qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
@@ -106,7 +107,7 @@ def __init__(
"bounds_interventions": [[0.1], [0.5]],
}

optimize_kwargs_SIRstockflow_time = {
optkwargs_SIRstockflow_time = {
"qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
"risk_bound": 300.0,
"static_parameter_interventions": start_time_objective(
@@ -118,7 +119,7 @@ def __init__(
"bounds_interventions": [[0.0], [40.0]],
}

optimize_kwargs_SIRstockflow_time_param = {
optkwargs_SIRstockflow_time_param = {
"qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
"risk_bound": 300.0,
"static_parameter_interventions": start_time_param_value_objective(
@@ -129,35 +130,45 @@ def __init__(
"bounds_interventions": [[0.0, 0.1], [40.0, 0.5]],
}

optimize_kwargs_SEIRHD_param_maxQoI = {
# Create a combined intervention by assembling the templates into a list of Callables
intervened_params = ["beta_c", "gamma"]
static_parameter_interventions1 = param_value_objective(
param_name=[intervened_params[0]],
start_time=torch.tensor([10.0]),
)
static_parameter_interventions2 = start_time_objective(
param_name=[intervened_params[1]],
param_value=torch.tensor([0.45]),
)
optkwargs_SEIRHD_paramtimeComb_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
param_name=["beta_c", "gamma"],
start_time=[torch.tensor(10.0), torch.tensor(15.0)],
"risk_bound": 3e5,
"static_parameter_interventions": intervention_func_combinator(
[static_parameter_interventions1, static_parameter_interventions2],
[1, 1],
),
"objfun": lambda x: np.abs(0.35 - x[0]) + np.abs(0.2 - x[1]),
"initial_guess_interventions": [0.2, 0.4],
"bounds_interventions": [[0.1, 0.1], [0.5, 0.5]],
"objfun": lambda x: np.abs(0.35 - x[0]) - x[1],
"initial_guess_interventions": [0.35, 5.0],
"bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],
"fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}},
}

OPT_MODELS = [
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_param,
optimize_kwargs=optkwargs_SIRstockflow_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_time,
optimize_kwargs=optkwargs_SIRstockflow_time,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_time_param,
optimize_kwargs=optkwargs_SIRstockflow_time_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SEIRHD_NPI_Type1_petrinet.json"),
optimize_kwargs=optimize_kwargs_SEIRHD_param_maxQoI,
optimize_kwargs=optkwargs_SEIRHD_paramtimeComb_maxQoI,
),
]

14 changes: 10 additions & 4 deletions tests/test_interfaces.py
@@ -584,12 +584,18 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
assert opt_policy[i] <= bounds_interventions[1][i]

if "fixed_static_parameter_interventions" in optimize_kwargs:
opt_intervention = combine_static_parameter_interventions(
[
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]),
intervention_list = [
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"])
]
intervention_list.extend(
[optimize_kwargs["static_parameter_interventions"](opt_result["policy"])]
if not isinstance(
optimize_kwargs["static_parameter_interventions"](opt_result["policy"]),
]
list,
)
else optimize_kwargs["static_parameter_interventions"](opt_result["policy"])
)
opt_intervention = combine_static_parameter_interventions(intervention_list)
else:
opt_intervention = optimize_kwargs["static_parameter_interventions"](
opt_result["policy"]
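The branch above handles templates that return either a single interventions dict or a list of them. For comparison, an equivalent sketch that calls the template once and normalizes its output to a list before merging:

result = optimize_kwargs["static_parameter_interventions"](opt_result["policy"])
intervention_list = [deepcopy(optimize_kwargs["fixed_static_parameter_interventions"])]
intervention_list.extend(result if isinstance(result, list) else [result])
opt_intervention = combine_static_parameter_interventions(intervention_list)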
