Add basic output formatting #405

Merged on Nov 6, 2023 (5 commits)
1 change: 1 addition & 0 deletions .gitignore
@@ -339,3 +339,4 @@ venv/
 *~
 *_schema.json
 *data_with_missing_entries.csv
+*.nix
3 changes: 2 additions & 1 deletion pyciemss/__init__.py
@@ -1 +1,2 @@
-from .mira_integration import compiled_dynamics  # noqa: F401
+from pyciemss.interfaces import sample  # noqa: F401
+from pyciemss.mira_integration import compiled_dynamics  # noqa: F401
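With the re-export above, `sample` becomes available from the package root. A quick check (not part of the diff):

```python
from pyciemss import sample  # now importable at the top level
```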
66 changes: 66 additions & 0 deletions pyciemss/integration_utils/result_processing.py
@@ -0,0 +1,66 @@
+from typing import Any, Dict
+
+import numpy as np
+import pandas as pd
+import torch
+
+
+def prepare_interchange_dictionary(
+    samples: Dict[str, torch.Tensor],
+) -> Dict[str, Any]:
+    processed_samples = convert_to_output_format(samples)
+
+    result = {"data": processed_samples, "unprocessed_result": samples}
+
+    return result
+
+
+def convert_to_output_format(samples: Dict[str, torch.Tensor]) -> pd.DataFrame:
+    """
+    Convert the samples from the Pyro model to a DataFrame in the TA4 requested format.
+    """
+
+    pyciemss_results: Dict[str, Dict[str, torch.Tensor]] = {
+        "parameters": {},
+        "states": {},
+    }
+
+    for name, sample in samples.items():
+        if sample.ndim == 1:
+            # Any 1D array is a sample from the distribution over parameters.
+            # Any 2D array is a sample from the distribution over states, unless it's a model weight.
+            name = name + "_param"
+            pyciemss_results["parameters"][name] = (
+                sample.data.detach().cpu().numpy().astype(np.float64)
+            )
+        else:
+            pyciemss_results["states"][name] = (
+                sample.data.detach().cpu().numpy().astype(np.float64)
+            )
+
+    num_samples, num_timepoints = next(iter(pyciemss_results["states"].values())).shape
+    output = {
+        "timepoint_id": np.tile(np.array(range(num_timepoints)), num_samples),
+        "sample_id": np.repeat(np.array(range(num_samples)), num_timepoints),
+    }
+
+    # Parameters
+    output = {
+        **output,
+        **{
+            k: np.repeat(v, num_timepoints)
+            for k, v in pyciemss_results["parameters"].items()
+        },
+    }
+
+    # Solution (state variables)
+    output = {
+        **output,
+        **{
+            k: np.squeeze(v.reshape((num_timepoints * num_samples, 1)))
+            for k, v in pyciemss_results["states"].items()
+        },
+    }
+
+    result = pd.DataFrame(output)
+    return result
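For orientation (not part of the diff), here is a minimal sketch of what `convert_to_output_format` produces for a toy samples dictionary; the key names `beta` and `S_state` are hypothetical stand-ins for a parameter site and a deterministic state node:

```python
import torch

from pyciemss.integration_utils.result_processing import convert_to_output_format

num_samples, num_timepoints = 2, 3
samples = {
    # 1D tensor: one draw per sample -> treated as a parameter, renamed "beta_param"
    "beta": torch.rand(num_samples),
    # 2D tensor (samples x timepoints) -> treated as a state trajectory
    "S_state": torch.rand(num_samples, num_timepoints),
}

df = convert_to_output_format(samples)
# df has num_samples * num_timepoints = 6 rows and columns:
#   timepoint_id, sample_id, beta_param, S_state
print(df)
```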
12 changes: 6 additions & 6 deletions pyciemss/interfaces.py
@@ -14,6 +14,7 @@
 
 from pyciemss.compiled_dynamics import CompiledDynamics
 from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper
+from pyciemss.integration_utils.result_processing import prepare_interchange_dictionary
 
 
 @pyciemss_logging_wrapper
@@ -44,9 +45,6 @@ def sample(
         - The step size to use for logging the trajectory.
     num_samples: int
         - The number of samples to draw from the model.
-    interventions: Optional[Iterable[Tuple[float, str, float]]]
-        - A list of interventions to apply to the model.
-          Each intervention is a tuple of the form (time, parameter_name, value).
     solver_method: str
         - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
         - If performance is incredibly slow, we suggest using `euler` to debug.
@@ -57,7 +55,7 @@
         - The start time of the model. This is used to align the `start_state` from the
           AMR model with the simulation timepoints.
         - By default we set the `start_time` to be 0.
-    inferred_parameters:
+    inferred_parameters: Optional[pyro.nn.PyroModule]
         - A Pyro module that contains the inferred parameters of the model.
           This is typically the result of `calibrate`.
         - If not provided, we will use the default values from the AMR model.
@@ -106,12 +104,14 @@ def wrapped_model():
             TorchDiffEq(method=solver_method, options=solver_options),
         )
         # Adding deterministic nodes to the model so that we can access the trajectory in the Predictive object.
-        [pyro.deterministic(f"state_{k}", v) for k, v in lt.trajectory.items()]
+        [pyro.deterministic(f"{k}_state", v) for k, v in lt.trajectory.items()]
 
-    return pyro.infer.Predictive(
+    samples = pyro.infer.Predictive(
         wrapped_model, guide=inferred_parameters, num_samples=num_samples
     )()
 
+    return prepare_interchange_dictionary(samples)
+
 
 # # TODO
 # def calibrate(
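Downstream callers now receive a dictionary rather than the raw `Predictive` output. A sketch of the new contract (not part of the diff), assuming a placeholder model URL and the argument names shown in the docstring above:

```python
from pyciemss import sample

result = sample(
    "https://example.com/sir_model.json",  # hypothetical AMR model URL/path
    end_time=10.0,
    logging_step_size=1.0,
    num_samples=3,
    start_time=0.0,
)

df = result["data"]                 # pandas DataFrame in the TA4 format
raw = result["unprocessed_result"]  # dict of torch.Tensors from pyro.infer.Predictive
```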
16 changes: 6 additions & 10 deletions setup.cfg
@@ -13,26 +13,22 @@ install_requires =
     chirho @ git+https://github.com/BasisResearch/chirho@f3019d4b22f4e49261efbf8da90a30095af2afbc
     sympytorch
     torchdiffeq
+    pandas
+    numpy
 
 zip_safe = false
 include_package_data = true
 python_requires = >=3.8
 
-packages = find:
-
-package_dir =
-    = pyciemss
+packages =
+    pyciemss
+    pyciemss.integration_utils
+    pyciemss.mira_integration
 
 [options.package_data]
 * = *.json
 
-
-[options.packages.find]
-
-where =
-    pyciemss
-
 
 [options.extras_require]
 tests =
     pytest
4 changes: 2 additions & 2 deletions tests/fixtures.py
@@ -46,7 +46,7 @@ def check_states_match_in_all_but_values(
     assert check_keys_match(traj1, traj2)
 
     for k in traj1.keys():
-        if k[:5] == "state":
+        if k[-5:] == "state":
             assert not torch.allclose(
                 traj2[k], traj1[k]
             ), f"Trajectories are identical in state trajectory of variable {k}, but should differ."
@@ -65,7 +65,7 @@ def check_result_sizes(
         assert isinstance(k, str)
         assert isinstance(v, torch.Tensor)
 
-        if k[:5] == "state":
+        if k[-5:] == "state":
             assert v.shape == (
                 num_samples,
                 len(torch.arange(start_time, end_time, logging_step_size))
39 changes: 32 additions & 7 deletions tests/test_interfaces.py
@@ -1,3 +1,5 @@
+import numpy as np
+import pandas as pd
 import pytest
 import torch
 
@@ -25,7 +27,7 @@ def test_sample_no_interventions(
 ):
     result = sample(
         url, end_time, logging_step_size, num_samples, start_time=start_time
-    )
+    )["unprocessed_result"]
     assert isinstance(result, dict)
     check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
 
@@ -58,11 +60,11 @@ def test_sample_with_static_interventions(
         num_samples,
         start_time=start_time,
         static_interventions=static_interventions,
-    )
+    )["unprocessed_result"]
 
     result = sample(
         url, end_time, logging_step_size, num_samples, start_time=start_time
-    )
+    )["unprocessed_result"]
 
     check_states_match_in_all_but_values(result, intervened_result)
     check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
@@ -106,11 +108,11 @@ def intervention_event_fn_2(time: torch.Tensor, *args, **kwargs):
         num_samples,
         start_time=start_time,
         dynamic_interventions=dynamic_interventions,
-    )
+    )["unprocessed_result"]
 
     result = sample(
         url, end_time, logging_step_size, num_samples, start_time=start_time
-    )
+    )["unprocessed_result"]
 
     check_states_match_in_all_but_values(result, intervened_result)
     check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
@@ -151,14 +153,37 @@ def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs):
         start_time=start_time,
         static_interventions=static_interventions,
         dynamic_interventions=dynamic_interventions,
-    )
+    )["unprocessed_result"]
 
     result = sample(
         url, end_time, logging_step_size, num_samples, start_time=start_time
-    )
+    )["unprocessed_result"]
 
     check_states_match_in_all_but_values(result, intervened_result)
     check_result_sizes(result, start_time, end_time, logging_step_size, num_samples)
     check_result_sizes(
         intervened_result, start_time, end_time, logging_step_size, num_samples
     )
+
+
+@pytest.mark.parametrize("url", MODEL_URLS)
+@pytest.mark.parametrize("start_time", START_TIMES)
+@pytest.mark.parametrize("end_time", END_TIMES)
+@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
+@pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+def test_output_format(url, start_time, end_time, logging_step_size, num_samples):
+    processed_result = sample(
+        url, end_time, logging_step_size, num_samples, start_time=start_time
+    )["data"]
+    assert isinstance(processed_result, pd.DataFrame)
+    assert processed_result.shape[0] == num_samples * len(
+        torch.arange(start_time + logging_step_size, end_time, logging_step_size)
+    )
+    assert processed_result.shape[1] >= 2
+    assert list(processed_result.columns)[:2] == ["timepoint_id", "sample_id"]
+    for col_name in processed_result.columns[2:]:
+        assert col_name.split("_")[-1] in ("param", "state", "(unknown)")
+        assert processed_result[col_name].dtype == np.float64
+
+    assert processed_result["timepoint_id"].dtype == np.int64
+    assert processed_result["sample_id"].dtype == np.int64
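The row-count assertion above follows directly from the logging grid. A small worked example (hypothetical values, not part of the diff):

```python
import torch

# With start_time=0, end_time=10, logging_step_size=1, the logged timepoints
# are torch.arange(1.0, 10.0, 1.0) -> 1.0 .. 9.0, i.e. 9 timepoints.
start_time, end_time, logging_step_size, num_samples = 0.0, 10.0, 1.0, 3
n_timepoints = len(torch.arange(start_time + logging_step_size, end_time, logging_step_size))
assert n_timepoints == 9
assert num_samples * n_timepoints == 27  # expected DataFrame row count
```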