diff --git a/docs/source/optimize_interface.ipynb b/docs/source/optimize_interface.ipynb index 2dd686b0..54e2b597 100644 --- a/docs/source/optimize_interface.ipynb +++ b/docs/source/optimize_interface.ipynb @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1460,7 +1460,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -1473,7 +1473,9 @@ } ], "source": [ - "intervened_params = [\"beta_c\", \"gamma\"]\n", + "from pyciemss.integration_utils.intervention_builder import intervention_func_combinator\n", + "\n", + "intervened_params = [\"beta_c\", \"gamma\", \"gamma_c\"]\n", "static_parameter_interventions1 = param_value_objective(\n", " param_name=[intervened_params[0]],\n", " start_time=torch.tensor([10.0]),\n", @@ -1482,6 +1484,7 @@ " param_name=[intervened_params[1]],\n", " param_value=torch.tensor([0.4]),\n", ")\n", + "\n", "# Combine different intervention templates into a list of Callables\n", "static_parameter_interventions = lambda x: [\n", " static_parameter_interventions1(torch.atleast_1d(x[0])),\n", @@ -1494,6 +1497,30 @@ "print(static_parameter_interventions(torch.tensor([0.4, 5.0])))" ] }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{10.0: {'beta_c': tensor(0.4000)}, 5.0: {'gamma': tensor(0.4000)}}" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "static_parameter_interventions_comb = intervention_func_combinator(\n", + " [static_parameter_interventions1, static_parameter_interventions2],\n", + " [1, 1],\n", + ")\n", + "static_parameter_interventions_comb(torch.tensor([0.4, 5.0]))" + ] + }, { "cell_type": "code", "execution_count": 11, diff --git a/pyciemss/integration_utils/intervention_builder.py b/pyciemss/integration_utils/intervention_builder.py index 
def intervention_func_combinator(
    intervention_funcs: List[
        Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]
    ],
    intervention_func_lengths: List[int],
) -> Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]:
    """Combine several intervention-builder callables into a single callable.

    Each element of ``intervention_funcs`` maps a 1-D tensor of decision
    variables to a static-parameter intervention dictionary; the matching
    entry of ``intervention_func_lengths`` is the number of decision
    variables that callable consumes.  The returned callable slices its
    input tensor into consecutive chunks of those lengths, applies each
    callable to its chunk, and merges the results with
    ``combine_static_parameter_interventions``.

    :param intervention_funcs: intervention-builder callables, applied in order.
    :param intervention_func_lengths: number of input elements consumed by
        each callable; must have the same length as ``intervention_funcs``.
    :return: a callable mapping a tensor of ``sum(intervention_func_lengths)``
        elements to the merged intervention dictionary.
    :raises ValueError: if the two argument lists differ in length, or — when
        the returned callable is invoked — if ``x`` has the wrong number of
        elements.  (Plain ``raise`` is used instead of ``assert`` so the
        checks survive ``python -O``.)
    """
    if len(intervention_funcs) != len(intervention_func_lengths):
        raise ValueError(
            "intervention_funcs and intervention_func_lengths must have the "
            f"same length; got {len(intervention_funcs)} and "
            f"{len(intervention_func_lengths)}"
        )

    total_length = sum(intervention_func_lengths)

    # Note: This only works for combining static parameter interventions.
    def intervention_generator(
        x: torch.Tensor,
    ) -> Dict[float, Dict[str, Intervention]]:
        x = torch.atleast_1d(x)
        if x.size()[0] != total_length:
            raise ValueError(
                f"expected {total_length} decision variables, "
                f"got {x.size()[0]}"
            )
        # Slice x into consecutive chunks, one chunk per intervention callable.
        interventions: List[Dict[float, Dict[str, Intervention]]] = []
        offset = 0
        for length, func in zip(intervention_func_lengths, intervention_funcs):
            interventions.append(func(x[offset : offset + length]))
            offset += length
        return combine_static_parameter_interventions(interventions)

    return intervention_generator