Skip to content

Commit

Permalink
Add template for max value of QoI within simulated time (#561)
Browse files Browse the repository at this point in the history
* Add template for max value of QoI within simulated time

* lint

* Update optimize_interface.ipynb

* Update optimize_interface.ipynb

* Adding tests

* Lint
  • Loading branch information
anirban-chaudhuri authored Apr 3, 2024
1 parent 3f16962 commit 8283eca
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 14 deletions.
187 changes: 175 additions & 12 deletions docs/source/optimize_interface.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions pyciemss/ouu/qoi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,14 @@ def obs_nday_average_qoi(
dataQoI = samples[contexts[0]].detach().numpy()

return np.mean(dataQoI[:, -ndays:], axis=1)


def obs_max_qoi(samples: Dict[str, torch.Tensor], contexts: List) -> np.ndarray:
"""
Return maximum value over simulated time.
samples is is the output from a Pyro Predictive object.
samples[VARIABLE] is expected to have dimension (nreplicates, ntimepoints)
"""
dataQoI = samples[contexts[0]].detach().numpy()

return np.max(dataQoI, axis=1)
20 changes: 19 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
param_value_objective,
start_time_objective,
)
from pyciemss.ouu.qoi import obs_nday_average_qoi
from pyciemss.ouu.qoi import obs_max_qoi, obs_nday_average_qoi

T = TypeVar("T")

Expand Down Expand Up @@ -110,6 +110,19 @@ def __init__(
"bounds_interventions": [[0.0], [40.0]],
}

optimize_kwargs_SIRstockflow_param_maxQoI = {
"qoi": lambda x: obs_max_qoi(x, ["I_state"]),
"risk_bound": 300.0,
"static_parameter_interventions": param_value_objective(
param_name=["p_cbeta"],
param_value=[lambda x: torch.tensor([x])],
start_time=[torch.tensor(1.0)],
),
"objfun": lambda x: np.abs(0.35 - x),
"initial_guess_interventions": 0.15,
"bounds_interventions": [[0.1], [0.5]],
}

OPT_MODELS = [
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
Expand All @@ -121,6 +134,11 @@ def __init__(
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_time,
),
ModelFixture(
os.path.join(MODELS_PATH, "SIR_stockflow.json"),
important_parameter="p_cbeta",
optimize_kwargs=optimize_kwargs_SIRstockflow_param_maxQoI,
),
]

BAD_AMRS = [
Expand Down
1 change: 0 additions & 1 deletion tests/visuals/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ def test_nested_mark_sources(schema_file):


if __name__ == "__main__":

parser = argparse.ArgumentParser("Utility to generate reference images")
parser.add_argument(
"schema",
Expand Down

0 comments on commit 8283eca

Please sign in to comment.