Skip to content

Commit

Permalink
Add tests for combined intervention templates
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jul 16, 2024
1 parent 0dcc3e8 commit 67d4fbe
Showing 1 changed file with 28 additions and 16 deletions.
44 changes: 28 additions & 16 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(
ModelFixture(os.path.join(MODELS_PATH, "SEIRHD_stockflow.json"), "p_cbeta"),
]

optimize_kwargs_SIRstockflow_param = {
optkwargs_SIRstockflow_param = {
"qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
Expand All @@ -106,7 +106,7 @@ def __init__(
"bounds_interventions": [[0.1], [0.5]],
}

optimize_kwargs_SIRstockflow_time = {
optkwargs_SIRstockflow_time = {
"qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
"risk_bound": 300.0,
"static_parameter_interventions": start_time_objective(
Expand All @@ -118,7 +118,7 @@ def __init__(
"bounds_interventions": [[0.0], [40.0]],
}

optimize_kwargs_SIRstockflow_time_param = {
optkwargs_SIRstockflow_time_param = {
"qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1),
"risk_bound": 300.0,
"static_parameter_interventions": start_time_param_value_objective(
Expand All @@ -129,35 +129,47 @@ def __init__(
"bounds_interventions": [[0.0, 0.1], [40.0, 0.5]],
}

optimize_kwargs_SEIRHD_param_maxQoI = {
# Creating a combined intervention
intervened_params = ["beta_c", "gamma"]
static_parameter_interventions1 = param_value_objective(
param_name=[intervened_params[0]],
start_time=torch.tensor([10.0]),
)
static_parameter_interventions2 = start_time_objective(
param_name=[intervened_params[1]],
param_value=torch.tensor([0.45]),
)
# Combine different intervention templates into a list of Callables
static_parameter_interventions = lambda x: [
static_parameter_interventions1(torch.atleast_1d(x[0])),
static_parameter_interventions2(torch.atleast_1d(x[1])),
]
optkwargs_SEIRHD_paramtimeComb_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
param_name=["beta_c", "gamma"],
start_time=[torch.tensor(10.0), torch.tensor(15.0)],
),
"objfun": lambda x: np.abs(0.35 - x[0]) + np.abs(0.2 - x[1]),
"initial_guess_interventions": [0.2, 0.4],
"bounds_interventions": [[0.1, 0.1], [0.5, 0.5]],
"risk_bound": 3e5,
"static_parameter_interventions": static_parameter_interventions,
"objfun": lambda x: np.abs(0.35 - x[0]) - x[1],
"initial_guess_interventions": [0.35, 5.0],
"bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],
"fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}},
}

OPT_MODELS = [
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_param,
optimize_kwargs=optkwargs_SIRstockflow_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_time,
optimize_kwargs=optkwargs_SIRstockflow_time,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
optimize_kwargs=optimize_kwargs_SIRstockflow_time_param,
optimize_kwargs=optkwargs_SIRstockflow_time_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SEIRHD_NPI_Type1_petrinet.json"),
optimize_kwargs=optimize_kwargs_SEIRHD_param_maxQoI,
optimize_kwargs=optkwargs_SEIRHD_paramtimeComb_maxQoI,
),
]

Expand Down

0 comments on commit 67d4fbe

Please sign in to comment.