Skip to content

Commit

Permalink
Optimize API Changes (#124)
Browse files Browse the repository at this point in the history
* optimize_interventions to list.

* update initial_guess_interventions

* update test payload
  • Loading branch information
Tom-Szendrey authored Nov 4, 2024
1 parent 066f4b1 commit 93ec358
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 48 deletions.
108 changes: 64 additions & 44 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import ClassVar, List, Optional, Union
from typing import ClassVar, List, Optional, Union, Callable, Dict
from chirho.interventional.ops import Intervention
from enum import Enum
from utils.rabbitmq import OptimizeHook
from pika.exceptions import AMQPConnectionError
Expand All @@ -13,6 +14,7 @@
from pydantic import BaseModel, Field, Extra
from models.base import OperationRequest, Timespan, HMIIntervention
from pyciemss.integration_utils.intervention_builder import (
intervention_func_combinator,
param_value_objective,
start_time_objective,
start_time_param_value_objective,
Expand Down Expand Up @@ -59,23 +61,26 @@ def gen_risk_bound(self):
return -self.risk_bound


def objfun(x, initial_guess, objective_function_option, relative_importance):
def objfun(x, optimize_interventions: list[InterventionObjective]):
"""
Calculate the weighted sum of objective functions based on the given parameters.
Parameters:
x (list): The current values of the variables.
initial_guess (list): The initial guess values of the variables.
objective_function_option (list): List of options specifying the type of objective function for each variable.
relative_importance (list): List of weights indicating the relative importance of each variable.
optimize_interventions: The interventions which are being optimized over
Returns:
float: The weighted sum of the objective functions.
"""

# Initialize the total sum to zero
total_sum = 0

# TODO: Will be cleaning this up in the next PR. I want minimal changes per PR for readability.
relative_importance = [i.relative_importance for i in optimize_interventions]
initial_guess = [i.initial_guess for i in optimize_interventions]
objective_function_option = [
i.objective_function_option for i in optimize_interventions
]
# Calculate the sum of all weights, fallback to 1 if the sum is 0
sum_of_all_weights = np.sum(relative_importance) or 1

Expand Down Expand Up @@ -116,9 +121,9 @@ class InterventionObjective(BaseModel):
param_names: list[str]
param_values: Optional[list[Optional[float]]] = None
start_time: Optional[list[float]] = None
objective_function_option: Optional[list[str]] = None
objective_function_option: Optional[str] = None
initial_guess: Optional[list[float]] = None
relative_importance: Optional[list[float]] = None
relative_importance: Optional[float] = None


class OptimizeExtra(BaseModel):
Expand Down Expand Up @@ -150,7 +155,9 @@ class Optimize(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "optimize"
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
optimize_interventions: InterventionObjective # These are the interventions to be optimized.
optimize_interventions: list[
InterventionObjective
] # These are the interventions to be optimized.
fixed_interventions: list[HMIIntervention] = Field(
None
) # Theses are interventions provided that will not be optimized
Expand All @@ -170,33 +177,47 @@ def gen_pyciemss_args(self, job_id):
fixed_static_state_interventions,
) = convert_static_interventions(self.fixed_interventions)

intervention_type = self.optimize_interventions.intervention_type
if intervention_type == "param_value":
assert self.optimize_interventions.start_time is not None
start_time = [
torch.tensor(time) for time in self.optimize_interventions.start_time
]
param_value = [None] * len(self.optimize_interventions.param_names)

optimize_interventions = param_value_objective(
start_time=start_time,
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)
if intervention_type == "start_time":
assert self.optimize_interventions.param_values is not None
param_value = [
torch.tensor(value)
for value in self.optimize_interventions.param_values
]
optimize_interventions = start_time_objective(
param_name=self.optimize_interventions.param_names,
param_value=param_value,
)
if intervention_type == "start_time_param_value":
optimize_interventions = start_time_param_value_objective(
param_name=self.optimize_interventions.param_names
)
transformed_optimize_interventions: list[
Callable[[torch.Tensor], Dict[float, Dict[str, Intervention]]]
] = []
intervention_func_lengths: list[int] = []
for i in range(len(self.optimize_interventions)):
currentIntervention = self.optimize_interventions[i]
intervention_type = currentIntervention.intervention_type
if intervention_type == "param_value":
assert currentIntervention.start_time is not None
start_time = [
torch.tensor(time) for time in currentIntervention.start_time
]
param_value = [None] * len(currentIntervention.param_names)

transformed_optimize_interventions.append(
param_value_objective(
start_time=start_time,
param_name=currentIntervention.param_names,
param_value=param_value,
)
)
intervention_func_lengths.append(1)
if intervention_type == "start_time":
assert currentIntervention.param_values is not None
param_value = [
torch.tensor(value) for value in currentIntervention.param_values
]
transformed_optimize_interventions.append(
start_time_objective(
param_name=currentIntervention.param_names,
param_value=param_value,
)
)
intervention_func_lengths.append(1)
if intervention_type == "start_time_param_value":
transformed_optimize_interventions.append(
start_time_param_value_objective(
param_name=currentIntervention.param_names
)
)
intervention_func_lengths.append(2)

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
Expand Down Expand Up @@ -236,17 +257,16 @@ def progress_hook(current_results):
"logging_step_size": self.logging_step_size,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"objfun": lambda x: objfun(
x,
self.optimize_interventions.initial_guess,
self.optimize_interventions.objective_function_option,
self.optimize_interventions.relative_importance,
),
"objfun": lambda x: objfun(x, self.optimize_interventions),
"qoi": qoi_methods,
"risk_bound": risk_bounds,
"initial_guess_interventions": self.optimize_interventions.initial_guess,
"initial_guess_interventions": [
i.initial_guess for i in self.optimize_interventions
],
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": optimize_interventions,
"static_parameter_interventions": intervention_func_combinator(
transformed_optimize_interventions, intervention_func_lengths
),
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
# https://github.com/DARPA-ASKEM/terarium/issues/4612
# "fixed_static_state_interventions": fixed_static_state_interventions,
Expand Down
8 changes: 4 additions & 4 deletions tests/examples/optimize/input/request.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
"engine": "ciemss",
"user_id": "not_provided",
"model_config_id": "sidarthe",
"optimize_interventions": {
"optimize_interventions": [{
"intervention_type": "param_value",
"objective_function_option": ["lower_bound"],
"objective_function_option": "lower_bound",
"start_time": [2],
"param_names": ["beta"],
"param_values": [0.02],
"initial_guess": [0],
"relative_importance": [1]
},
"relative_importance": 1
}],
"timespan": {
"start": 0,
"end": 90
Expand Down

0 comments on commit 93ec358

Please sign in to comment.