diff --git a/tests/fixtures.py b/tests/fixtures.py index 9d2d75802..c068bd62c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -93,7 +93,7 @@ def __init__( ModelFixture(os.path.join(MODELS_PATH, "SEIRHD_stockflow.json"), "p_cbeta"), ] -optimize_kwargs_SIRstockflow_param = { +optkwargs_SIRstockflow_param = { "qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1), "risk_bound": 300.0, "static_parameter_interventions": param_value_objective( @@ -106,7 +106,7 @@ def __init__( "bounds_interventions": [[0.1], [0.5]], } -optimize_kwargs_SIRstockflow_time = { +optkwargs_SIRstockflow_time = { "qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1), "risk_bound": 300.0, "static_parameter_interventions": start_time_objective( @@ -118,7 +118,7 @@ def __init__( "bounds_interventions": [[0.0], [40.0]], } -optimize_kwargs_SIRstockflow_time_param = { +optkwargs_SIRstockflow_time_param = { "qoi": lambda x: obs_nday_average_qoi(x, ["I_state"], 1), "risk_bound": 300.0, "static_parameter_interventions": start_time_param_value_objective( @@ -129,35 +129,47 @@ def __init__( "bounds_interventions": [[0.0, 0.1], [40.0, 0.5]], } -optimize_kwargs_SEIRHD_param_maxQoI = { +# Creating a combined intervention +intervened_params = ["beta_c", "gamma"] +static_parameter_interventions1 = param_value_objective( + param_name=[intervened_params[0]], + start_time=torch.tensor([10.0]), +) +static_parameter_interventions2 = start_time_objective( + param_name=[intervened_params[1]], + param_value=torch.tensor([0.45]), +) +# Combine different intervention templates into a list of Callables +static_parameter_interventions = lambda x: [ + static_parameter_interventions1(torch.atleast_1d(x[0])), + static_parameter_interventions2(torch.atleast_1d(x[1])), +] +optkwargs_SEIRHD_paramtimeComb_maxQoI = { "qoi": lambda x: obs_max_qoi(x, ["I_state"]), - "risk_bound": 300.0, - "static_parameter_interventions": param_value_objective( - param_name=["beta_c", "gamma"], - start_time=[torch.tensor(10.0), torch.tensor(15.0)], - ), - "objfun": lambda x: np.abs(0.35 - x[0]) + np.abs(0.2 - x[1]), - "initial_guess_interventions": [0.2, 0.4], - "bounds_interventions": [[0.1, 0.1], [0.5, 0.5]], + "risk_bound": 3e5, + "static_parameter_interventions": static_parameter_interventions, + "objfun": lambda x: np.abs(0.35 - x[0]) - x[1], + "initial_guess_interventions": [0.35, 5.0], + "bounds_interventions": [[0.1, 1.0], [0.5, 90.0]], "fixed_static_parameter_interventions": {10.0: {"hosp": torch.tensor(0.1)}}, } OPT_MODELS = [ ModelFixture( os.path.join(MODELS_PATH, "SIR_stockflow.json"), - optimize_kwargs=optimize_kwargs_SIRstockflow_param, + optimize_kwargs=optkwargs_SIRstockflow_param, ), ModelFixture( os.path.join(MODELS_PATH, "SIR_stockflow.json"), - optimize_kwargs=optimize_kwargs_SIRstockflow_time, + optimize_kwargs=optkwargs_SIRstockflow_time, ), ModelFixture( os.path.join(MODELS_PATH, "SIR_stockflow.json"), - optimize_kwargs=optimize_kwargs_SIRstockflow_time_param, + optimize_kwargs=optkwargs_SIRstockflow_time_param, ), ModelFixture( os.path.join(MODELS_PATH, "SEIRHD_NPI_Type1_petrinet.json"), - optimize_kwargs=optimize_kwargs_SEIRHD_param_maxQoI, + optimize_kwargs=optkwargs_SEIRHD_paramtimeComb_maxQoI, ), ]