Skip to content

Commit

Permalink
Fixed tests to run with new setup of interventions
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jul 16, 2024
1 parent 67d4fbe commit e6a903b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 11 deletions.
12 changes: 5 additions & 7 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
"bounds_interventions": [[0.0, 0.1], [40.0, 0.5]],
}

# Creating a combined intervention
# Creating a combined interventions by combining into list of Callables
intervened_params = ["beta_c", "gamma"]
static_parameter_interventions1 = param_value_objective(
param_name=[intervened_params[0]],
Expand All @@ -139,15 +139,13 @@ def __init__(
param_name=[intervened_params[1]],
param_value=torch.tensor([0.45]),
)
# Combine different intervention templates into a list of Callables
static_parameter_interventions = lambda x: [
static_parameter_interventions1(torch.atleast_1d(x[0])),
static_parameter_interventions2(torch.atleast_1d(x[1])),
]
optkwargs_SEIRHD_paramtimeComb_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 3e5,
"static_parameter_interventions": static_parameter_interventions,
"static_parameter_interventions": lambda x: [
static_parameter_interventions1(torch.atleast_1d(x[0])),
static_parameter_interventions2(torch.atleast_1d(x[1])),
],
"objfun": lambda x: np.abs(0.35 - x[0]) - x[1],
"initial_guess_interventions": [0.35, 5.0],
"bounds_interventions": [[0.1, 1.0], [0.5, 90.0]],
Expand Down
14 changes: 10 additions & 4 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,12 +584,18 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
assert opt_policy[i] <= bounds_interventions[1][i]

if "fixed_static_parameter_interventions" in optimize_kwargs:
opt_intervention = combine_static_parameter_interventions(
[
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]),
intervention_list = [
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"])
]
intervention_list.extend(
[optimize_kwargs["static_parameter_interventions"](opt_result["policy"])]
if not isinstance(
optimize_kwargs["static_parameter_interventions"](opt_result["policy"]),
]
list,
)
else optimize_kwargs["static_parameter_interventions"](opt_result["policy"])
)
opt_intervention = combine_static_parameter_interventions(intervention_list)
else:
opt_intervention = optimize_kwargs["static_parameter_interventions"](
opt_result["policy"]
Expand Down

0 comments on commit e6a903b

Please sign in to comment.