diff --git a/tests/fixtures.py b/tests/fixtures.py index c068bd62..a51373ba 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -129,7 +129,7 @@ def __init__( "bounds_interventions": [[0.0, 0.1], [40.0, 0.5]], } -# Creating a combined intervention +# Creating a combined interventions by combining into list of Callables intervened_params = ["beta_c", "gamma"] static_parameter_interventions1 = param_value_objective( param_name=[intervened_params[0]], @@ -139,15 +139,13 @@ def __init__( param_name=[intervened_params[1]], param_value=torch.tensor([0.45]), ) -# Combine different intervention templates into a list of Callables -static_parameter_interventions = lambda x: [ - static_parameter_interventions1(torch.atleast_1d(x[0])), - static_parameter_interventions2(torch.atleast_1d(x[1])), -] optkwargs_SEIRHD_paramtimeComb_maxQoI = { "qoi": lambda x: obs_max_qoi(x, ["I_state"]), "risk_bound": 3e5, - "static_parameter_interventions": static_parameter_interventions, + "static_parameter_interventions": lambda x: [ + static_parameter_interventions1(torch.atleast_1d(x[0])), + static_parameter_interventions2(torch.atleast_1d(x[1])), + ], "objfun": lambda x: np.abs(0.35 - x[0]) - x[1], "initial_guess_interventions": [0.35, 5.0], "bounds_interventions": [[0.1, 1.0], [0.5, 90.0]], diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 4b58a571..7e647c5b 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -584,12 +584,18 @@ def test_optimize(model_fixture, start_time, end_time, num_samples): assert opt_policy[i] <= bounds_interventions[1][i] if "fixed_static_parameter_interventions" in optimize_kwargs: - opt_intervention = combine_static_parameter_interventions( - [ - deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]), + intervention_list = [ + deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]) + ] + intervention_list.extend( + [optimize_kwargs["static_parameter_interventions"](opt_result["policy"])] + if not isinstance( optimize_kwargs["static_parameter_interventions"](opt_result["policy"]), - ] + list, + ) + else optimize_kwargs["static_parameter_interventions"](opt_result["policy"]) ) + opt_intervention = combine_static_parameter_interventions(intervention_list) else: opt_intervention = optimize_kwargs["static_parameter_interventions"]( opt_result["policy"]