From e6a903bc6fe8978e06a9c50c90f9d7b4c7c4b513 Mon Sep 17 00:00:00 2001 From: Anirban Chaudhuri <75496534+anirban-chaudhuri@users.noreply.github.com> Date: Tue, 16 Jul 2024 18:14:26 -0400 Subject: [PATCH] Fixed tests to run with new setup of interventions --- tests/fixtures.py | 12 +++++------- tests/test_interfaces.py | 14 ++++++++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/fixtures.py b/tests/fixtures.py index c068bd62c..a51373bae 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -129,7 +129,7 @@ def __init__( "bounds_interventions": [[0.0, 0.1], [40.0, 0.5]], } -# Creating a combined intervention +# Creating a combined interventions by combining into list of Callables intervened_params = ["beta_c", "gamma"] static_parameter_interventions1 = param_value_objective( param_name=[intervened_params[0]], @@ -139,15 +139,13 @@ def __init__( param_name=[intervened_params[1]], param_value=torch.tensor([0.45]), ) -# Combine different intervention templates into a list of Callables -static_parameter_interventions = lambda x: [ - static_parameter_interventions1(torch.atleast_1d(x[0])), - static_parameter_interventions2(torch.atleast_1d(x[1])), -] optkwargs_SEIRHD_paramtimeComb_maxQoI = { "qoi": lambda x: obs_max_qoi(x, ["I_state"]), "risk_bound": 3e5, - "static_parameter_interventions": static_parameter_interventions, + "static_parameter_interventions": lambda x: [ + static_parameter_interventions1(torch.atleast_1d(x[0])), + static_parameter_interventions2(torch.atleast_1d(x[1])), + ], "objfun": lambda x: np.abs(0.35 - x[0]) - x[1], "initial_guess_interventions": [0.35, 5.0], "bounds_interventions": [[0.1, 1.0], [0.5, 90.0]], diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 4b58a5710..7e647c5b3 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -584,12 +584,18 @@ def test_optimize(model_fixture, start_time, end_time, num_samples): assert opt_policy[i] <= bounds_interventions[1][i] if "fixed_static_parameter_interventions" in optimize_kwargs: - opt_intervention = combine_static_parameter_interventions( - [ - deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]), + intervention_list = [ + deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]) + ] + intervention_list.extend( + [optimize_kwargs["static_parameter_interventions"](opt_result["policy"])] + if not isinstance( optimize_kwargs["static_parameter_interventions"](opt_result["policy"]), - ] + list, + ) + else optimize_kwargs["static_parameter_interventions"](opt_result["policy"]) ) + opt_intervention = combine_static_parameter_interventions(intervention_list) else: opt_intervention = optimize_kwargs["static_parameter_interventions"]( opt_result["policy"]