Skip to content

Commit

Permalink
Added different types of fixed interventions to optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Nov 13, 2024
1 parent 08cf9fc commit 798cade
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 13 deletions.
4 changes: 2 additions & 2 deletions docs/source/interfaces.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3327,7 +3327,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "askem",
"language": "python",
"name": "python3"
},
Expand All @@ -3341,7 +3341,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions docs/source/optimize_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,7 @@
"risk_bound = 3e5\n",
"qoi = lambda y: obs_max_qoi(y, observed_params)\n",
"objfun = lambda x: np.sum(np.abs(param_current - x))\n",
"fixed_interventions = {10.: {\"hosp\": torch.tensor(0.1)}}\n",
"fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n",
"\n",
"# Run optimize interface\n",
"opt_result5 = pyciemss.optimize(\n",
Expand Down Expand Up @@ -1273,7 +1273,7 @@
")\n",
"\n",
"# Fixed intervention on hosp parameter\n",
"fixed_interventions = {10.0: {\"hosp\": torch.tensor(0.1)}}\n",
"fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n",
"\n",
"# Run optimize interface\n",
"opt_result6 = pyciemss.optimize(\n",
Expand Down Expand Up @@ -1513,7 +1513,7 @@
")\n",
"\n",
"# Fixed intervention on hosp parameter\n",
"fixed_interventions = {10.0: {\"hosp\": torch.tensor(0.1)}}\n",
"fixed_interventions = {torch.tensor(10.0): {\"hosp\": torch.tensor(0.1)}}\n",
"\n",
"# Run optimize interface\n",
"opt_result6 = pyciemss.optimize(\n",
Expand Down Expand Up @@ -1602,7 +1602,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "askem",
"language": "python",
"name": "python3"
},
Expand Down
50 changes: 45 additions & 5 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,13 +627,13 @@ def calibrate(
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
- By default we set the `start_time` to be 0.
static_state_interventions: Dict[float, Dict[str, Intervention]]
static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
static_parameter_interventions: Dict[float, Dict[str, Intervention]]
static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
Expand Down Expand Up @@ -812,7 +812,18 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
fixed_static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {},
fixed_dynamic_state_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention],
] = {},
fixed_dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention],
] = {},
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
Expand Down Expand Up @@ -870,12 +881,38 @@ def optimize(
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of fixed static interventions to apply to the model and not optimize for.
fixed_static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
- A dictionary of fixed static parameter interventions to apply to the model and not optimize for.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
fixed_static_state_interventions: Dict[torch.Tensor, Dict[str, Intervention]]
- A dictionary of static state interventions to apply to the model and not optimize for.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
fixed_dynamic_state_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention]
]
- A dictionary of dynamic interventions to apply to the model and not optimize for.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
fixed_dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention]
]
- A dictionary of dynamic interventions to apply to the model and not optimize for.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
n_samples_ouu: int
- The number of samples to draw from the model to estimate risk for each optimization iteration.
maxiter: int
Expand Down Expand Up @@ -929,6 +966,9 @@ def optimize(
num_samples=1,
guide=inferred_parameters,
fixed_static_parameter_interventions=fixed_static_parameter_interventions,
fixed_static_state_interventions=fixed_static_state_interventions,
fixed_dynamic_state_interventions=fixed_dynamic_state_interventions,
fixed_dynamic_parameter_interventions=fixed_dynamic_parameter_interventions,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
Expand Down
59 changes: 57 additions & 2 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,20 @@ def __init__(
risk_measure: List[Callable] = [lambda z: alpha_superquantile(z, alpha=0.95)],
num_samples: int = 1000,
guide=None,
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]] = {},
fixed_static_parameter_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
fixed_static_state_interventions: Dict[
torch.Tensor, Dict[str, Intervention]
] = {},
fixed_dynamic_state_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention],
] = {},
fixed_dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention],
] = {},
solver_method: str = "dopri5",
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
Expand All @@ -89,6 +102,11 @@ def __init__(
self.end_time = end_time
self.guide = guide
self.fixed_static_parameter_interventions = fixed_static_parameter_interventions
self.fixed_static_state_interventions = fixed_static_state_interventions
self.fixed_dynamic_state_interventions = fixed_dynamic_state_interventions
self.fixed_dynamic_parameter_interventions = (
fixed_dynamic_parameter_interventions
)
self.solver_method = solver_method
self.solver_options = solver_options
self.logging_times = torch.arange(
Expand Down Expand Up @@ -143,6 +161,34 @@ def propagate_uncertainty(self, x):
for time, static_intervention_assignment in static_parameter_interventions.items()
]

static_state_intervention_handlers = [
StaticIntervention(time, dict(**static_intervention_assignment))
for time, static_intervention_assignment in self.fixed_static_state_interventions.items()
]

dynamic_state_intervention_handlers = [
DynamicIntervention(
event_fn, dict(**dynamic_intervention_assignment)
)
for event_fn, dynamic_intervention_assignment in self.fixed_dynamic_state_interventions.items()
]

dynamic_parameter_intervention_handlers = [
DynamicParameterIntervention(
event_fn,
dict(**dynamic_intervention_assignment),
is_traced=True,
)
for event_fn, dynamic_intervention_assignment in self.fixed_dynamic_parameter_interventions.items()
]

intervention_handlers = (
static_state_intervention_handlers
+ static_parameter_intervention_handlers
+ dynamic_state_intervention_handlers
+ dynamic_parameter_intervention_handlers
)

def wrapped_model():
with ParameterInterventionTracer():
with TorchDiffEq(
Expand All @@ -152,7 +198,7 @@ def wrapped_model():
options=self.solver_options,
):
with contextlib.ExitStack() as stack:
for handler in static_parameter_intervention_handlers:
for handler in intervention_handlers:
stack.enter_context(handler)
self.model(
torch.as_tensor(self.start_time),
Expand All @@ -161,6 +207,15 @@ def wrapped_model():
is_traced=True,
)

parallel = (
False
if len(
dynamic_parameter_intervention_handlers
+ dynamic_state_intervention_handlers
)
> 0
else True
)
# Sample from intervened model
samples = pyro.infer.Predictive(
wrapped_model,
Expand Down

0 comments on commit 798cade

Please sign in to comment.