Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify interchange dictionary processing #564

Merged
merged 3 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 6 additions & 23 deletions pyciemss/integration_utils/result_processing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import re
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
from typing import Any, Dict, Iterable, Mapping, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -136,26 +135,6 @@ def get_times_for(intervention: str, intervention_times: Mapping[str, Iterable[f
return intervention_times[valid[0]]


def find_target_col(var: str, options: List[str]):
    """
    Find the column that corresponds to the var

    var -- The parsed variable name
    options -- Column names to search for the variable name

    Returns the single entry of `options` that contains `var` (at the start
    of the name or right after an underscore) immediately followed by
    "_state" or "_param".

    Raises KeyError if no column matches and ValueError if more than one
    column matches.
    """
    # TODO: This "underscore-trailing-name matching" seems very fragile....
    #       It is done this way since you can intervene on params & states
    #       and that will match either.
    # The variable name is escaped so that it is always matched literally:
    # regex metacharacters in the name (".", "(", "+", ...) must neither
    # widen the match nor make the pattern fail to compile.
    pattern = re.compile(f"(?:^|_){re.escape(var)}_(state|param)")
    matches = [c for c in options if pattern.search(c)]
    if not matches:
        raise KeyError(f"No target column match found for '{var}'.")
    if len(matches) > 1:
        raise ValueError(
            f"Could not uniquely determine target column for '{var}'. Found: {matches}"
        )
    return matches[0]


def set_intervention_values(
df: pd.DataFrame,
intervention: str,
Expand All @@ -171,7 +150,11 @@ def set_intervention_values(
"""
times = get_times_for(intervention, intervention_times)
target_var = "_".join(intervention.split("_")[3:-1])
target_col = find_target_col(target_var, df.columns)
target_col = f"persistent_{target_var}_param"

if target_col not in df.columns:
raise KeyError(f"Could not find target column for '{target_var}'")

time_col = [
c for c in df.columns if c.startswith("timepoint_") and c != "timepoint_id"
][0]
Expand Down
40 changes: 3 additions & 37 deletions tests/integration_utils/test_result_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,40 +57,6 @@ def test_get_times_for(intervention):
)


@pytest.mark.parametrize("name", ["underscored", "with_underscore", "I", "i"])
def test_find_target_col(name):
    """Check the unique-match, ambiguous-match and no-match paths of find_target_col."""
    lookup = result_processing.find_target_col

    # Exactly one column matches each parametrized name.
    unique_cols = [
        "before_underscored_param",
        "underscored_after_state",
        "sample_with_underscore_state",
        "i_state",
        "sampli_id_state",
        "persistent_I_param",
    ]
    # Every parametrized name matches at least two of these columns.
    ambiguous_cols = [
        "i_state",
        "persistent_i_param",
        "before_underscored_param",
        "underscored_param",
        "with_underscore_param",
        "not_with_underscore_state",
        "With_I_param",
        "I_state",
    ]
    # None of the parametrized names matches any of these columns.
    unmatched_cols = [
        "stuff_I_stuff_state",
        "sampli_state",
        "before_with_underscore_after_param",
        "underscored_after_state",
    ]

    found = lookup(name, unique_cols)
    assert name in found

    with pytest.raises(ValueError):
        lookup(name, ambiguous_cols)

    with pytest.raises(KeyError):
        lookup(name, unmatched_cols)


@pytest.mark.parametrize("logging_step_size", [1, 5, 10, 12, 23])
def test_set_intervention_values(logging_step_size):
model_1_path = (
Expand All @@ -114,10 +80,10 @@ def test_set_intervention_values(logging_step_size):
"parameter_intervention_value_beta_c_0": torch.tensor([0.0, 100.0, 200.0])
}

raw_internention_times = [logging_step_size * (n + 1) for n in range(num_samples)]
raw_intervention_times = [logging_step_size * (n + 1) for n in range(num_samples)]

intervention_times = {
"parameter_intervention_time_0": torch.tensor(raw_internention_times)
"parameter_intervention_time_0": torch.tensor(raw_intervention_times)
}
intervention = "parameter_intervention_value_beta_c_0"
df = result_processing.set_intervention_values(
Expand All @@ -129,7 +95,7 @@ def test_set_intervention_values(logging_step_size):

for name, group in df.groupby("sample_id"):
group = group.set_index("timepoint_nominal")
time = raw_internention_times[name]
time = raw_intervention_times[name]
expected = name * 100

if time - logging_step_size > 0:
Expand Down
Loading