diff --git a/docs/source/interfaces.ipynb b/docs/source/interfaces.ipynb index b9e70833..dd4de0d3 100644 --- a/docs/source/interfaces.ipynb +++ b/docs/source/interfaces.ipynb @@ -3327,7 +3327,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "askem", "language": "python", "name": "python3" }, @@ -3341,7 +3341,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/docs/source/optimize_interface.ipynb b/docs/source/optimize_interface.ipynb index 0246a28b..7004d7ba 100644 --- a/docs/source/optimize_interface.ipynb +++ b/docs/source/optimize_interface.ipynb @@ -936,7 +936,7 @@ "risk_bound = 3e5\n", "qoi = lambda y: obs_max_qoi(y, observed_params)\n", "objfun = lambda x: np.sum(np.abs(param_current - x))\n", - "fixed_interventions = {10.: {\"hosp\": torch.tensor(0.1)}}\n", + "fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n", "\n", "# Run optimize interface\n", "opt_result5 = pyciemss.optimize(\n", @@ -1273,7 +1273,7 @@ ")\n", "\n", "# Fixed intervention on hosp parameter\n", - "fixed_interventions = {10.0: {\"hosp\": torch.tensor(0.1)}}\n", + "fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n", "\n", "# Run optimize interface\n", "opt_result6 = pyciemss.optimize(\n", @@ -1513,7 +1513,7 @@ ")\n", "\n", "# Fixed intervention on hosp parameter\n", - "fixed_interventions = {10.0: {\"hosp\": torch.tensor(0.1)}}\n", + "fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n", "\n", "# Run optimize interface\n", "opt_result6 = pyciemss.optimize(\n", @@ -1602,7 +1602,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "askem", "language": "python", "name": "python3" }, diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 9ffe13f2..18513df8 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -627,13 +627,13 @@ def calibrate( - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. - By default we set the `start_time` to be 0. - static_state_interventions: Dict[float, Dict[str, Intervention]] + static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]] - A dictionary of static interventions to apply to the model. - Each key is the time at which the intervention is applied. - Each value is a dictionary of the form {state_variable_name: intervention_assignment}. - Note that the `intervention_assignment` can be any type supported by :func:`~chirho.interventional.ops.intervene`, including functions. - static_parameter_interventions: Dict[float, Dict[str, Intervention]] + static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] - A dictionary of static interventions to apply to the model. - Each key is the time at which the intervention is applied. - Each value is a dictionary of the form {parameter_name: intervention_assignment}. @@ -812,7 +812,18 @@ def optimize( solver_options: Dict[str, Any] = {}, start_time: float = 0.0, inferred_parameters: Optional[pyro.nn.PyroModule] = None, - fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}, + fixed_static_parameter_interventions: Dict[ + torch.Tensor, Dict[str, Intervention] + ] = {}, + fixed_static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {}, + fixed_dynamic_state_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, Intervention], + ] = {}, + fixed_dynamic_parameter_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, Intervention], + ] = {}, n_samples_ouu: int = int(1e3), maxiter: int = 5, maxfeval: int = 25, @@ -870,12 +881,38 @@ def optimize( - A Pyro module that contains the inferred parameters of the model. This is typically the result of `calibrate`. - If not provided, we will use the default values from the AMR model. - fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] - - A dictionary of fixed static interventions to apply to the model and not optimize for. + fixed_static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] + - A dictionary of fixed static parameter interventions to apply to the model and not optimize for. - Each key is the time at which the intervention is applied. - Each value is a dictionary of the form {parameter_name: intervention_assignment}. - Note that the `intervention_assignment` can be any type supported by :func:`~chirho.interventional.ops.intervene`, including functions. + fixed_static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]] + - A dictionary of static state interventions to apply to the model and not optimize for. + - Each key is the time at which the intervention is applied. + - Each value is a dictionary of the form {state_variable_name: intervention_assignment}. + - Note that the `intervention_assignment` can be any type supported by + :func:`~chirho.interventional.ops.intervene`, including functions. + fixed_dynamic_state_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, Intervention] + ] + - A dictionary of dynamic interventions to apply to the model and not optimize for. + - Each key is a function that takes in the current state of the model and returns a tensor. + When this function crosses 0, the dynamic intervention is applied. + - Each value is a dictionary of the form {state_variable_name: intervention_assignment}. + - Note that the `intervention_assignment` can be any type supported by + :func:`~chirho.interventional.ops.intervene`, including functions. + fixed_dynamic_parameter_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, Intervention] + ] + - A dictionary of dynamic interventions to apply to the model and not optimize for. + - Each key is a function that takes in the current state of the model and returns a tensor. + When this function crosses 0, the dynamic intervention is applied. + - Each value is a dictionary of the form {parameter_name: intervention_assignment}. + - Note that the `intervention_assignment` can be any type supported by + :func:`~chirho.interventional.ops.intervene`, including functions. n_samples_ouu: int - The number of samples to draw from the model to estimate risk for each optimization iteration. maxiter: int @@ -929,6 +966,9 @@ def optimize( num_samples=1, guide=inferred_parameters, fixed_static_parameter_interventions=fixed_static_parameter_interventions, + fixed_static_state_interventions=fixed_static_state_interventions, + fixed_dynamic_state_interventions=fixed_dynamic_state_interventions, + fixed_dynamic_parameter_interventions=fixed_dynamic_parameter_interventions, solver_method=solver_method, solver_options=solver_options, u_bounds=bounds_np, diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py index 28629d4f..4579295c 100644 --- a/pyciemss/ouu/ouu.py +++ b/pyciemss/ouu/ouu.py @@ -73,7 +73,20 @@ def __init__( risk_measure: List[Callable] = [lambda z: alpha_superquantile(z, alpha=0.95)], num_samples: int = 1000, guide=None, - fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {}, + fixed_static_parameter_interventions: Dict[ + torch.Tensor, Dict[str, Intervention] + ] = {}, + fixed_static_state_interventions: Dict[ + torch.Tensor, Dict[str, Intervention] + ] = {}, + fixed_dynamic_state_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, Intervention], + ] = {}, + fixed_dynamic_parameter_interventions: Dict[ + Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + Dict[str, Intervention], + ] = {}, solver_method: str = "dopri5", solver_options: Dict[str, Any] = {}, u_bounds: np.ndarray = np.atleast_2d([[0], [1]]), @@ -89,6 +102,11 @@ def __init__( self.end_time = end_time self.guide = guide self.fixed_static_parameter_interventions = fixed_static_parameter_interventions + self.fixed_static_state_interventions = fixed_static_state_interventions + self.fixed_dynamic_state_interventions = fixed_dynamic_state_interventions + self.fixed_dynamic_parameter_interventions = ( + fixed_dynamic_parameter_interventions + ) self.solver_method = solver_method self.solver_options = solver_options self.logging_times = torch.arange( @@ -143,6 +161,34 @@ def propagate_uncertainty(self, x): for time, static_intervention_assignment in static_parameter_interventions.items() ] + static_state_intervention_handlers = [ + StaticIntervention(time, dict(**static_intervention_assignment)) + for time, static_intervention_assignment in self.fixed_static_state_interventions.items() + ] + + dynamic_state_intervention_handlers = [ + DynamicIntervention( + event_fn, dict(**dynamic_intervention_assignment) + ) + for event_fn, dynamic_intervention_assignment in self.fixed_dynamic_state_interventions.items() + ] + + dynamic_parameter_intervention_handlers = [ + DynamicParameterIntervention( + event_fn, + dict(**dynamic_intervention_assignment), + is_traced=True, + ) + for event_fn, dynamic_intervention_assignment in self.fixed_dynamic_parameter_interventions.items() + ] + + intervention_handlers = ( + static_state_intervention_handlers + + static_parameter_intervention_handlers + + dynamic_state_intervention_handlers + + dynamic_parameter_intervention_handlers + ) + def wrapped_model(): with ParameterInterventionTracer(): with TorchDiffEq( @@ -152,7 +198,7 @@ def wrapped_model(): options=self.solver_options, ): with contextlib.ExitStack() as stack: - for handler in static_parameter_intervention_handlers: + for handler in intervention_handlers: stack.enter_context(handler) self.model( torch.as_tensor(self.start_time), @@ -161,6 +207,15 @@ def wrapped_model(): is_traced=True, ) + parallel = ( + False + if len( + dynamic_parameter_intervention_handlers + + dynamic_state_intervention_handlers + ) + > 0 + else True + ) # Sample from intervened model samples = pyro.infer.Predictive( wrapped_model,