
Commit adaaa94

added intervention function combinator and example usage in the notebook (which should be deleted) (#595)

Co-authored-by: Sam Witty <[email protected]>
anirban-chaudhuri and SamWitty authored Jul 30, 2024
1 parent e6a903b commit adaaa94
Showing 3 changed files with 58 additions and 5 deletions.
33 changes: 30 additions & 3 deletions docs/source/optimize_interface.ipynb
@@ -49,7 +49,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1460,7 +1460,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
@@ -1473,7 +1473,9 @@
     }
    ],
    "source": [
-    "intervened_params = [\"beta_c\", \"gamma\"]\n",
+    "from pyciemss.integration_utils.intervention_builder import intervention_func_combinator\n",
+    "\n",
+    "intervened_params = [\"beta_c\", \"gamma\", \"gamma_c\"]\n",
     "static_parameter_interventions1 = param_value_objective(\n",
     "    param_name=[intervened_params[0]],\n",
     "    start_time=torch.tensor([10.0]),\n",
@@ -1482,6 +1484,7 @@
     "    param_name=[intervened_params[1]],\n",
     "    param_value=torch.tensor([0.4]),\n",
     ")\n",
+    "\n",
     "# Combine different intervention templates into a list of Callables\n",
     "static_parameter_interventions = lambda x: [\n",
     "    static_parameter_interventions1(torch.atleast_1d(x[0])),\n",
@@ -1494,6 +1497,30 @@
     "print(static_parameter_interventions(torch.tensor([0.4, 5.0])))"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{10.0: {'beta_c': tensor(0.4000)}, 5.0: {'gamma': tensor(0.4000)}}"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "static_parameter_interventions_comb = intervention_func_combinator(\n",
+    "    [static_parameter_interventions1, static_parameter_interventions2],\n",
+    "    [1, 1],\n",
+    ")\n",
+    "static_parameter_interventions_comb(torch.tensor([0.4, 5.0]))"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 11,
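Note: the lambda-based combination in the earlier cell returns a list of per-template dictionaries, one per intervention template, whereas the new combinator cell merges them into a single time-keyed dictionary. A minimal self-contained sketch of the list-based form, with toy callables standing in for the pyciemss templates (the names value_intervention and time_intervention are illustrative, not library API):

    import torch

    # Toy stand-ins for the two templates built in the notebook cell above:
    # the first optimizes the value of beta_c applied at a fixed time 10.0,
    # the second optimizes the start time of a fixed gamma value of 0.4.
    value_intervention = lambda x: {10.0: {"beta_c": x[0]}}
    time_intervention = lambda x: {float(x[0]): {"gamma": torch.tensor(0.4)}}

    # Lambda-based combination: a list of Callables, each fed its slice of x.
    as_list = lambda x: [
        value_intervention(torch.atleast_1d(x[0])),
        time_intervention(torch.atleast_1d(x[1])),
    ]
    print(as_list(torch.tensor([0.4, 5.0])))
    # [{10.0: {'beta_c': tensor(0.4000)}}, {5.0: {'gamma': tensor(0.4000)}}]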
26 changes: 26 additions & 0 deletions pyciemss/integration_utils/intervention_builder.py
@@ -73,6 +73,32 @@ def intervention_generator(
     return intervention_generator


+def intervention_func_combinator(
+    intervention_funcs: List[
+        Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]
+    ],
+    intervention_func_lengths: List[int],
+) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
+    assert len(intervention_funcs) == len(intervention_func_lengths)
+
+    total_length = sum(intervention_func_lengths)
+
+    # Note: This only works for combining static parameter interventions.
+    def intervention_generator(
+        x: torch.Tensor,
+    ) -> Dict[float, Dict[str, Intervention]]:
+        x = torch.atleast_1d(x)
+        assert x.size()[0] == total_length
+        interventions = [None] * len(intervention_funcs)
+        i = 0
+        for j, (input_length, intervention_func) in enumerate(zip(intervention_func_lengths, intervention_funcs)):
+            interventions[j] = intervention_func(x[i : i + input_length])
+            i += input_length
+        return combine_static_parameter_interventions(interventions)
+
+    return intervention_generator
+
+
 def start_time_param_value_objective(
     param_name: List[str],
     param_value: List[Intervention] = [None],
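Note: the combinator slices the flat decision vector x according to the declared per-template lengths, feeds each slice to its intervention function, and merges the resulting time-keyed dictionaries. A self-contained sketch of that behavior, reusing the toy callables from the note above (the merge logic here is illustrative; the library delegates it to combine_static_parameter_interventions):

    import torch

    value_intervention = lambda x: {10.0: {"beta_c": x[0]}}
    time_intervention = lambda x: {float(x[0]): {"gamma": torch.tensor(0.4)}}

    def combine(funcs, lengths):
        total = sum(lengths)

        def generator(x):
            x = torch.atleast_1d(x)
            assert x.size(0) == total
            merged = {}
            i = 0
            for func, n in zip(funcs, lengths):
                # Each function consumes its declared slice of x.
                for t, params in func(x[i : i + n]).items():
                    merged.setdefault(t, {}).update(params)
                i += n
            return merged

        return generator

    combined = combine([value_intervention, time_intervention], [1, 1])
    print(combined(torch.tensor([0.4, 5.0])))
    # {10.0: {'beta_c': tensor(0.4000)}, 5.0: {'gamma': tensor(0.4000)}}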
4 changes: 2 additions & 2 deletions pyciemss/ouu/ouu.py
@@ -139,8 +139,8 @@ def propagate_uncertainty(self, x):
         ]
         intervention_list.extend(
             [self.interventions(torch.from_numpy(x))]
-            if not isinstance(self.interventions(torch.from_numpy(x)), list)
-            else self.interventions(torch.from_numpy(x))
+            # if not isinstance(self.interventions(torch.from_numpy(x)), list)
+            # else self.interventions(torch.from_numpy(x))
         )
         static_parameter_interventions = combine_static_parameter_interventions(
             intervention_list
