Skip to content

Commit

Permalink
Adding fixed interventions as input to optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Apr 8, 2024
1 parent 8283eca commit e085e19
Show file tree
Hide file tree
Showing 3 changed files with 428 additions and 30 deletions.
410 changes: 396 additions & 14 deletions docs/source/optimize_interface.ipynb

Large diffs are not rendered by default.

41 changes: 26 additions & 15 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ def optimize(
solver_options: Dict[str, Any] = {},
start_time: float = 0.0,
inferred_parameters: Optional[pyro.nn.PyroModule] = None,
fixed_static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {},
n_samples_ouu: int = int(1e3),
maxiter: int = 5,
maxfeval: int = 25,
Expand Down Expand Up @@ -818,6 +819,12 @@ def optimize(
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
fixed_static_parameter_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of fixed static interventions to apply to the model and not optimize for.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
n_samples_ouu: int
- The number of samples to draw from the model to estimate risk for each optimization iteration.
maxiter: int
Expand All @@ -838,6 +845,7 @@ def optimize(
- Optimization results as scipy object.
"""
check_solver(solver_method, solver_options)
print(fixed_static_parameter_interventions)

with torch.no_grad():
control_model = CompiledDynamics.load(model_path_or_json)
Expand All @@ -855,6 +863,7 @@ def optimize(
risk_measure=lambda z: alpha_superquantile(z, alpha=alpha),
num_samples=1,
guide=inferred_parameters,
fixed_static_parameter_interventions=fixed_static_parameter_interventions,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
Expand All @@ -872,21 +881,23 @@ def optimize(
print(f"Time taken: ({forward_time/1.:.2e} seconds per model evaluation).")

# Assign the required number of MC samples for each OUU iteration
RISK = computeRisk(
model=control_model,
interventions=static_parameter_interventions,
qoi=qoi,
end_time=end_time,
logging_step_size=logging_step_size,
start_time=start_time,
risk_measure=lambda z: alpha_superquantile(z, alpha=alpha),
num_samples=n_samples_ouu,
guide=inferred_parameters,
solver_method=solver_method,
solver_options=solver_options,
u_bounds=bounds_np,
risk_bound=risk_bound,
)
# RISK = computeRisk(
# model=control_model,
# interventions=static_parameter_interventions,
# qoi=qoi,
# end_time=end_time,
# logging_step_size=logging_step_size,
# start_time=start_time,
# risk_measure=lambda z: alpha_superquantile(z, alpha=alpha),
# num_samples=n_samples_ouu,
# guide=inferred_parameters,
# fixed_static_parameter_interventions=fixed_static_parameter_interventions,
# solver_method=solver_method,
# solver_options=solver_options,
# u_bounds=bounds_np,
# risk_bound=risk_bound,
# )
RISK.num_samples = n_samples_ouu
# Define constraints >= 0
constraints = (
# risk constraint
Expand Down
7 changes: 6 additions & 1 deletion pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(
risk_measure: Callable = lambda z: alpha_superquantile(z, alpha=0.95),
num_samples: int = 1000,
guide=None,
fixed_static_parameter_interventions: Dict[torch.Tensor, Dict[str, Intervention]] = {},
solver_method: str = "dopri5",
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
Expand All @@ -84,6 +85,7 @@ def __init__(
self.start_time = start_time
self.end_time = end_time
self.guide = guide
self.fixed_static_parameter_interventions = fixed_static_parameter_interventions
self.solver_method = solver_method
self.solver_options = solver_options
self.logging_times = torch.arange(
Expand Down Expand Up @@ -119,7 +121,10 @@ def propagate_uncertainty(self, x):
with pyro.poutine.seed(rng_seed=0):
with torch.no_grad():
x = np.atleast_1d(x)
static_parameter_interventions = self.interventions(torch.from_numpy(x))
# Existing interventions
static_parameter_interventions = self.fixed_static_parameter_interventions
# Intervention being optimized
static_parameter_interventions.update(self.interventions(torch.from_numpy(x)))
static_parameter_intervention_handlers = [
StaticParameterIntervention(
time, dict(**static_intervention_assignment)
Expand Down

0 comments on commit e085e19

Please sign in to comment.