Skip to content

Commit

Permalink
update tests to use new name for combining interventions
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Apr 11, 2024
1 parent 13c768c commit ae40bb3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 34 deletions.
40 changes: 8 additions & 32 deletions docs/source/optimize_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1298,37 +1298,13 @@
"name": "stderr",
"output_type": "stream",
"text": [
" 28%|██▊ | 33/120 [03:04<08:17, 5.72s/it]C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\ouu\\ouu.py:106: UserWarning: Selected interventions are out of bounds. Will use a penalty instead of estimating risk.\n",
" 28%|██▊ | 33/120 [04:14<09:45, 6.73s/it]C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\ouu\\ouu.py:108: UserWarning: Selected interventions are out of bounds. Will use a penalty instead of estimating risk.\n",
" warnings.warn(\n",
"124it [10:02, 4.86s/it] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Optimal policy: tensor([0.3500, 0.4602], dtype=torch.float64)\n",
"{'policy': tensor([0.3500, 0.4602], dtype=torch.float64), 'OptResults': message: ['requested number of basinhopping iterations completed successfully']\n",
" success: False\n",
" fun: 0.26016707274061507\n",
" x: [ 3.500e-01 4.602e-01]\n",
" nit: 3\n",
" minimization_failures: 4\n",
" nfev: 120\n",
" lowest_optimization_result: message: Maximum number of function evaluations has been exceeded.\n",
" success: False\n",
" status: 2\n",
" fun: 0.26016707274061507\n",
" x: [ 3.500e-01 4.601e-01]\n",
" nfev: 30\n",
" maxcv: 0.0}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
" 52%|█████▎ | 63/120 [08:35<07:14, 7.62s/it]C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\ouu\\ouu.py:108: UserWarning: Selected interventions are out of bounds. Will use a penalty instead of estimating risk.\n",
" warnings.warn(\n",
"C:\\Users\\Anirban\\Documents\\GitHub\\pyciemss\\pyciemss\\ouu\\ouu.py:108: UserWarning: Selected interventions are out of bounds. Will use a penalty instead of estimating risk.\n",
" warnings.warn(\n",
" 64%|██████▍ | 77/120 [10:16<06:02, 8.42s/it]"
]
}
],
Expand All @@ -1341,7 +1317,7 @@
"initial_guess_interventions = [0.2, 0.4]\n",
"bounds_interventions = [[0.1, 0.1], [0.5, 0.5]]\n",
"# Note that param_value is not passed in below and defaults to None.\n",
"# User can also pass ina list of lambda x: torch.tensor(x) for each intervention.\n",
"# User can also pass in a list of lambda x: torch.tensor(x) for each intervention.\n",
"static_parameter_interventions = param_value_objective(\n",
" param_name=intervened_params,\n",
" start_time=intervention_time,\n",
Expand Down Expand Up @@ -1384,7 +1360,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down
4 changes: 2 additions & 2 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch

from pyciemss.compiled_dynamics import CompiledDynamics
from pyciemss.integration_utils.intervention_builder import combine_interventions
from pyciemss.integration_utils.intervention_builder import combine_static_parameter_interventions
from pyciemss.integration_utils.observation import load_data
from pyciemss.interfaces import (
calibrate,
Expand Down Expand Up @@ -582,7 +582,7 @@ def test_optimize(model_fixture, start_time, end_time, num_samples):
assert opt_policy[i] <= bounds_interventions[1][i]

if "fixed_static_parameter_interventions" in optimize_kwargs:
opt_intervention = combine_interventions(
opt_intervention = combine_static_parameter_interventions(
[
deepcopy(optimize_kwargs["fixed_static_parameter_interventions"]),
optimize_kwargs["static_parameter_interventions"](opt_result["policy"]),
Expand Down

0 comments on commit ae40bb3

Please sign in to comment.