From cc564b09bf537672ce09eabf471c0f519043fa54 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Thu, 2 Nov 2023 15:09:43 -0400 Subject: [PATCH 1/5] initial pass at output processing --- .../integration_utils/result_processing.py | 69 +++++++++++++++++++ pyciemss/interfaces.py | 12 ++-- tests/fixtures.py | 4 +- tests/test_interfaces.py | 36 ++++++++-- 4 files changed, 106 insertions(+), 15 deletions(-) create mode 100644 pyciemss/integration_utils/result_processing.py diff --git a/pyciemss/integration_utils/result_processing.py b/pyciemss/integration_utils/result_processing.py new file mode 100644 index 000000000..ad737baf0 --- /dev/null +++ b/pyciemss/integration_utils/result_processing.py @@ -0,0 +1,69 @@ +from typing import Any, Dict + +import numpy as np +import pandas as pd +import torch + + +def prepare_interchange_dictionary( + samples: Dict[str, torch.Tensor], +) -> Dict[str, Any]: + processed_samples = convert_to_output_format(samples) + + result = {"data": processed_samples, "unprocessed_result": samples} + + return result + + +def convert_to_output_format(samples: Dict[str, torch.Tensor]) -> pd.DataFrame: + """ + Convert the samples from the Pyro model to a DataFrame in the TA4 requested format. + + time_unit -- Label timepoints in a semantically relevant way `timepoint_`. + If None, a `timepoint_` field is not provided. + """ + + pyciemss_results: Dict[str, Dict[str, torch.Tensor]] = { + "parameters": {}, + "states": {}, + } + + for name, sample in samples.items(): + if sample.ndim == 1: + # Any 1D array is a sample from the distribution over parameters. + # Any 2D array is a sample from the distribution over states, unless it's a model weight. + name = name + "_param" + pyciemss_results["parameters"][name] = ( + sample.data.detach().cpu().numpy().astype(np.float64) + ) + else: + pyciemss_results["states"][name] = ( + sample.data.detach().cpu().numpy().astype(np.float64) + ) + + num_samples, num_timepoints = next(iter(pyciemss_results["states"].values())).shape + output = { + "timepoint_id": np.tile(np.array(range(num_timepoints)), num_samples), + "sample_id": np.repeat(np.array(range(num_samples)), num_timepoints), + } + + # Parameters + output = { + **output, + **{ + k: np.repeat(v, num_timepoints) + for k, v in pyciemss_results["parameters"].items() + }, + } + + # Solution (state variables) + output = { + **output, + **{ + k: np.squeeze(v.reshape((num_timepoints * num_samples, 1))) + for k, v in pyciemss_results["states"].items() + }, + } + + result = pd.DataFrame(output) + return result diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 4532bd918..063b06821 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -14,6 +14,7 @@ from pyciemss.compiled_dynamics import CompiledDynamics from pyciemss.integration_utils.custom_decorators import pyciemss_logging_wrapper +from pyciemss.integration_utils.result_processing import prepare_interchange_dictionary @pyciemss_logging_wrapper @@ -44,9 +45,6 @@ def sample( - The step size to use for logging the trajectory. num_samples: int - The number of samples to draw from the model. - interventions: Optional[Iterable[Tuple[float, str, float]]] - - A list of interventions to apply to the model. - Each intervention is a tuple of the form (time, parameter_name, value). solver_method: str - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details. - If performance is incredibly slow, we suggest using `euler` to debug. @@ -57,7 +55,7 @@ def sample( - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints. - By default we set the `start_time` to be 0. - inferred_parameters: + inferred_parameters: Optional[pyro.nn.PyroModule] - A Pyro module that contains the inferred parameters of the model. This is typically the result of `calibrate`. - If not provided, we will use the default values from the AMR model. @@ -106,12 +104,14 @@ def wrapped_model(): TorchDiffEq(method=solver_method, options=solver_options), ) # Adding deterministic nodes to the model so that we can access the trajectory in the Predictive object. - [pyro.deterministic(f"state_{k}", v) for k, v in lt.trajectory.items()] + [pyro.deterministic(f"{k}_state", v) for k, v in lt.trajectory.items()] - return pyro.infer.Predictive( + samples = pyro.infer.Predictive( wrapped_model, guide=inferred_parameters, num_samples=num_samples )() + return prepare_interchange_dictionary(samples) + # # TODO # def calibrate( diff --git a/tests/fixtures.py b/tests/fixtures.py index 061101789..36ddfbfb5 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -46,7 +46,7 @@ def check_states_match_in_all_but_values( assert check_keys_match(traj1, traj2) for k in traj1.keys(): - if k[:5] == "state": + if k[-5:] == "state": assert not torch.allclose( traj2[k], traj1[k] ), f"Trajectories are identical in state trajectory of variable {k}, but should differ." @@ -65,7 +65,7 @@ def check_result_sizes( assert isinstance(k, str) assert isinstance(v, torch.Tensor) - if k[:5] == "state": + if k[-5:] == "state": assert v.shape == ( num_samples, len(torch.arange(start_time, end_time, logging_step_size)) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index 4ea1d6957..bb8950ea1 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -1,5 +1,7 @@ import pytest import torch +import pandas as pd +import numpy as np from pyciemss.compiled_dynamics import CompiledDynamics from pyciemss.interfaces import sample @@ -25,7 +27,7 @@ def test_sample_no_interventions( ): result = sample( url, end_time, logging_step_size, num_samples, start_time=start_time - ) + )["unprocessed_result"] assert isinstance(result, dict) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) @@ -58,11 +60,11 @@ def test_sample_with_static_interventions( num_samples, start_time=start_time, static_interventions=static_interventions, - ) + )["unprocessed_result"] result = sample( url, end_time, logging_step_size, num_samples, start_time=start_time - ) + )["unprocessed_result"] check_states_match_in_all_but_values(result, intervened_result) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) @@ -106,11 +108,11 @@ def intervention_event_fn_2(time: torch.Tensor, *args, **kwargs): num_samples, start_time=start_time, dynamic_interventions=dynamic_interventions, - ) + )["unprocessed_result"] result = sample( url, end_time, logging_step_size, num_samples, start_time=start_time - ) + )["unprocessed_result"] check_states_match_in_all_but_values(result, intervened_result) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) @@ -151,14 +153,34 @@ def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs): start_time=start_time, static_interventions=static_interventions, dynamic_interventions=dynamic_interventions, - ) + )["unprocessed_result"] result = sample( url, end_time, logging_step_size, num_samples, start_time=start_time - ) + )["unprocessed_result"] check_states_match_in_all_but_values(result, intervened_result) check_result_sizes(result, start_time, end_time, logging_step_size, num_samples) check_result_sizes( intervened_result, start_time, end_time, logging_step_size, num_samples ) + +@pytest.mark.parametrize("url", MODEL_URLS) +@pytest.mark.parametrize("start_time", START_TIMES) +@pytest.mark.parametrize("end_time", END_TIMES) +@pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES) +@pytest.mark.parametrize("num_samples", NUM_SAMPLES) +def test_output_format(url, start_time, end_time, logging_step_size, num_samples): + processed_result = sample( + url, end_time, logging_step_size, num_samples, start_time=start_time + )["data"] + assert isinstance(processed_result, pd.DataFrame) + assert processed_result.shape[0] == num_samples * len(torch.arange(start_time+logging_step_size, end_time, logging_step_size)) + assert processed_result.shape[1] >= 2 + assert list(processed_result.columns)[:2] == ["timepoint_id", "sample_id"] + for col_name in processed_result.columns[2:]: + assert col_name.split("_")[-1] in ("param", "state", "(unknown)") + assert processed_result[col_name].dtype == np.float64 + + assert processed_result["timepoint_id"].dtype == np.int64 + assert processed_result["sample_id"].dtype == np.int64 \ No newline at end of file From 133c11ce0d61889e985707ee2aa02f665b65a5e0 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Thu, 2 Nov 2023 15:10:06 -0400 Subject: [PATCH 2/5] lint --- tests/test_interfaces.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index bb8950ea1..de476ab80 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -1,7 +1,7 @@ +import numpy as np +import pandas as pd import pytest import torch -import pandas as pd -import numpy as np from pyciemss.compiled_dynamics import CompiledDynamics from pyciemss.interfaces import sample @@ -165,6 +165,7 @@ def intervention_event_fn_1(time: torch.Tensor, *args, **kwargs): intervened_result, start_time, end_time, logging_step_size, num_samples ) + @pytest.mark.parametrize("url", MODEL_URLS) @pytest.mark.parametrize("start_time", START_TIMES) @pytest.mark.parametrize("end_time", END_TIMES) @@ -175,7 +176,9 @@ def test_output_format(url, start_time, end_time, logging_step_size, num_samples url, end_time, logging_step_size, num_samples, start_time=start_time )["data"] assert isinstance(processed_result, pd.DataFrame) - assert processed_result.shape[0] == num_samples * len(torch.arange(start_time+logging_step_size, end_time, logging_step_size)) + assert processed_result.shape[0] == num_samples * len( + torch.arange(start_time + logging_step_size, end_time, logging_step_size) + ) assert processed_result.shape[1] >= 2 assert list(processed_result.columns)[:2] == ["timepoint_id", "sample_id"] for col_name in processed_result.columns[2:]: @@ -183,4 +186,4 @@ def test_output_format(url, start_time, end_time, logging_step_size, num_samples assert processed_result[col_name].dtype == np.float64 assert processed_result["timepoint_id"].dtype == np.int64 - assert processed_result["sample_id"].dtype == np.int64 \ No newline at end of file + assert processed_result["sample_id"].dtype == np.int64 From adc42779c17f231123bcb52663888372bdc1a5b5 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Thu, 2 Nov 2023 16:23:46 -0400 Subject: [PATCH 3/5] add pandas and numpy dependencies --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index 25a911f84..3cbbab777 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,8 @@ install_requires = chirho @ git+https://github.com/BasisResearch/chirho@f3019d4b22f4e49261efbf8da90a30095af2afbc sympytorch torchdiffeq + pandas + numpy zip_safe = false include_package_data = true From ccd1e4a84cdcee29e9097abb72b470fa86dd5fe0 Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Thu, 2 Nov 2023 22:49:05 -0400 Subject: [PATCH 4/5] remove outdated docstring --- pyciemss/integration_utils/result_processing.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyciemss/integration_utils/result_processing.py b/pyciemss/integration_utils/result_processing.py index ad737baf0..704c05ace 100644 --- a/pyciemss/integration_utils/result_processing.py +++ b/pyciemss/integration_utils/result_processing.py @@ -18,9 +18,6 @@ def prepare_interchange_dictionary( def convert_to_output_format(samples: Dict[str, torch.Tensor]) -> pd.DataFrame: """ Convert the samples from the Pyro model to a DataFrame in the TA4 requested format. - - time_unit -- Label timepoints in a semantically relevant way `timepoint_`. - If None, a `timepoint_` field is not provided. """ pyciemss_results: Dict[str, Dict[str, torch.Tensor]] = { From 1512d4fde5f93280ce537a3af1e5dfbb91eac84b Mon Sep 17 00:00:00 2001 From: Sam Witty Date: Fri, 3 Nov 2023 10:54:20 -0400 Subject: [PATCH 5/5] Remove package finding AND add to module exports (#407) (#409) * Remove package finding AND add to module exports * Clear out top-level `__init__.py` * Reinclude explicit exports * Include subpackages * Skip incorrect lint error * reorder imports for lint * revert previous commit and add missing space before noqa --------- Co-authored-by: Five Grant <5@fivegrant.com> --- .gitignore | 1 + pyciemss/__init__.py | 3 ++- setup.cfg | 14 ++++---------- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/.gitignore b/.gitignore index 7437f008c..604ab9be1 100644 --- a/.gitignore +++ b/.gitignore @@ -339,3 +339,4 @@ venv/ *~ *_schema.json *data_with_missing_entries.csv +*.nix diff --git a/pyciemss/__init__.py b/pyciemss/__init__.py index 03fbc1db8..b2d16ada0 100644 --- a/pyciemss/__init__.py +++ b/pyciemss/__init__.py @@ -1 +1,2 @@ -from .mira_integration import compiled_dynamics # noqa: F401 +from pyciemss.interfaces import sample # noqa: F401 +from pyciemss.mira_integration import compiled_dynamics # noqa: F401 diff --git a/setup.cfg b/setup.cfg index 3cbbab777..963a51bef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,21 +20,15 @@ zip_safe = false include_package_data = true python_requires = >=3.8 -packages = find: - -package_dir = - = pyciemss +packages = + pyciemss + pyciemss.integration_utils + pyciemss.mira_integration [options.package_data] * = *.json -[options.packages.find] - -where = - pyciemss - - [options.extras_require] tests = pytest