Skip to content

Commit

Permalink
Adding tests for using fixed intervention with optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Apr 8, 2024
1 parent f41cf11 commit 114712d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 115 deletions.
136 changes: 36 additions & 100 deletions docs/source/optimize_interface.ipynb

Large diffs are not rendered by default.

21 changes: 9 additions & 12 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,34 +110,31 @@ def __init__(
"bounds_interventions": [[0.0], [40.0]],
}

optimize_kwargs_SIRstockflow_param_maxQoI = {
optimize_kwargs_SEIRHD_param_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
param_name=["p_cbeta"],
param_value=[lambda x: torch.tensor([x])],
start_time=[torch.tensor(1.0)],
param_name=["beta_c", "gamma"],
start_time=[torch.tensor(10.0), torch.tensor(15.0)],
),
"objfun": lambda x: np.abs(0.35 - x),
"initial_guess_interventions": 0.15,
"bounds_interventions": [[0.1], [0.5]],
"objfun": lambda x: np.abs(0.35 - x[0]) + np.abs(0.2 - x[1]),
"initial_guess_interventions": [0.2, 0.4],
"bounds_interventions": [[0.1, 0.1], [0.5, 0.5]],
"fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}},
}

OPT_MODELS = [
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_param,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_time,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_param_maxQoI,
os.path.join(MODELS_PATH, "SEIRHD_NPI_Type1_petrinet.json"),
optimize_kwargs=optimize_kwargs_SEIRHD_param_maxQoI,
),
]

Expand Down
19 changes: 16 additions & 3 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from copy import deepcopy

import numpy as np
import pandas as pd
import pyro
import pytest
import torch

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.integration_utils.intervention_builder import combine_interventions
from pyciemss.integration_utils.observation import load_data
from pyciemss.interfaces import (
calibrate,
Expand Down Expand Up @@ -578,15 +581,25 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
assert bounds_interventions[0][i] <= opt_policy[i]
assert opt_policy[i] <= bounds_interventions[1][i]

if optimize_kwargs["fixed_static_parameter_interventions"]:
opt_intervention = combine_interventions(
[
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]),
optimize_kwargs["static_parameter_interventions"](opt_result["policy"]),
]
)
else:
opt_intervention = optimize_kwargs["static_parameter_interventions"](
opt_result["policy"]
)

result_opt = sample(
model_url,
end_time,
logging_step_size,
num_samples,
start_time=start_time,
static_parameter_interventions=optimize_kwargs[
"static_parameter_interventions"
](opt_result["policy"]),
static_parameter_interventions=opt_intervention,
solver_method=optimize_kwargs["solver_method"],
solver_options=optimize_kwargs["solver_options"],
)["unprocessed_result"]
Expand Down

0 comments on commit 114712d

Please sign in to comment.