Commit
Simplify interchange dictionary processing (#564)
* simplify intervention splicing

* lint

* removed test
SamWitty authored Apr 8, 2024
1 parent 8283eca commit f5454df
Showing 2 changed files with 9 additions and 60 deletions.
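
In short, the regex-based find_target_col helper is removed and set_intervention_values now derives the target column name directly from the interchange key. A minimal sketch of the new lookup, traced with the key used in the updated test (this restates logic from the diff below; it is not additional repository code):

# Interchange keys in the test look like "parameter_intervention_value_<variable>_<index>".
intervention = "parameter_intervention_value_beta_c_0"

# Drop the three-word prefix and the trailing index to recover the variable name.
target_var = "_".join(intervention.split("_")[3:-1])    # -> "beta_c"

# The target column is now built directly rather than searched for with a regex.
target_col = f"persistent_{target_var}_param"           # -> "persistent_beta_c_param"
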
29 changes: 6 additions & 23 deletions pyciemss/integration_utils/result_processing.py
@@ -1,5 +1,4 @@
-import re
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
+from typing import Any, Dict, Iterable, Mapping, Optional, Union

import numpy as np
import pandas as pd
@@ -136,26 +135,6 @@ def get_times_for(intervention: str, intervention_times: Mapping[str, Iterable[f
    return intervention_times[valid[0]]


-def find_target_col(var: str, options: List[str]):
-    """
-    Find the column that corresponds to the var
-    var -- The parsed variable name
-    options -- Column names to search for the variable name
-    """
-    # TODO: This "underscore-trailing-name matching" seems very fragile....
-    # It is done this way since you can intervene on params & states
-    # and that will match either.
-    pattern = re.compile(f"(?:^|_){var}_(state|param)")
-    options = [c for c in options if pattern.search(c)]
-    if len(options) == 0:
-        raise KeyError(f"No target column match found for '{var}'.")
-    if len(options) > 1:
-        raise ValueError(
-            f"Could not uniquely determine target column for '{var}'. Found: {options}"
-        )
-    return options[0]
-
-
def set_intervention_values(
    df: pd.DataFrame,
    intervention: str,
@@ -171,7 +150,11 @@ def set_intervention_values(
    """
    times = get_times_for(intervention, intervention_times)
    target_var = "_".join(intervention.split("_")[3:-1])
-    target_col = find_target_col(target_var, df.columns)
+    target_col = f"persistent_{target_var}_param"
+
+    if target_col not in df.columns:
+        raise KeyError(f"Could not find target column for '{target_var}'")
+
    time_col = [
        c for c in df.columns if c.startswith("timepoint_") and c != "timepoint_id"
    ][0]
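
One consequence worth noting: the old helper matched either a _state or a _param column (and raised ValueError on ambiguity), while the new rule consults only the persistent_<var>_param name. A small illustration with hypothetical column names (not taken from the repository):

columns = ["beta_c_state", "persistent_beta_c_param"]

# Old behaviour: the regex (?:^|_)beta_c_(state|param) matched both entries,
# so find_target_col("beta_c", columns) raised ValueError for ambiguity.

# New behaviour: only the parameter column is looked up.
target_col = "persistent_beta_c_param"
assert target_col in columns   # found -> used as the splice target
# If only "beta_c_state" were present, set_intervention_values would raise KeyError.
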
40 changes: 3 additions & 37 deletions tests/integration_utils/test_result_processing.py
@@ -57,40 +57,6 @@ def test_get_times_for(intervention):
)


-@pytest.mark.parametrize("name", ["underscored", "with_underscore", "I", "i"])
-def test_find_target_col(name):
-    good_columns = [
-        "before_underscored_param",
-        "underscored_after_state",
-        "sample_with_underscore_state",
-        "i_state",
-        "sampli_id_state",
-        "persistent_I_param",
-    ]
-    result = result_processing.find_target_col(name, good_columns)
-    assert name in result
-    multiple_match_columns = [
-        "i_state",
-        "persistent_i_param",
-        "before_underscored_param",
-        "underscored_param",
-        "with_underscore_param",
-        "not_with_underscore_state",
-        "With_I_param",
-        "I_state",
-    ]
-    with pytest.raises(ValueError):
-        result_processing.find_target_col(name, multiple_match_columns)
-    no_match_columns = [
-        "stuff_I_stuff_state",
-        "sampli_state",
-        "before_with_underscore_after_param",
-        "underscored_after_state",
-    ]
-    with pytest.raises(KeyError):
-        result_processing.find_target_col(name, no_match_columns)
-
-
@pytest.mark.parametrize("logging_step_size", [1, 5, 10, 12, 23])
def test_set_intervention_values(logging_step_size):
model_1_path = (
@@ -114,10 +80,10 @@ def test_set_intervention_values(logging_step_size):
"parameter_intervention_value_beta_c_0": torch.tensor([0.0, 100.0, 200.0])
}

raw_internention_times = [logging_step_size * (n + 1) for n in range(num_samples)]
raw_intervention_times = [logging_step_size * (n + 1) for n in range(num_samples)]

intervention_times = {
"parameter_intervention_time_0": torch.tensor(raw_internention_times)
"parameter_intervention_time_0": torch.tensor(raw_intervention_times)
}
intervention = "parameter_intervention_value_beta_c_0"
df = result_processing.set_intervention_values(
@@ -129,7 +95,7 @@

    for name, group in df.groupby("sample_id"):
        group = group.set_index("timepoint_nominal")
-        time = raw_internention_times[name]
+        time = raw_intervention_times[name]
        expected = name * 100

        if time - logging_step_size > 0:
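
For reference, the per-sample fixtures above line up so that sample n receives intervention value n * 100 at time logging_step_size * (n + 1); a quick standalone check of that correspondence (num_samples is defined in the folded part of the test and is assumed here to be 3, matching the three tensor entries):

logging_step_size = 5        # one of the parametrized values
num_samples = 3              # assumed; matches torch.tensor([0.0, 100.0, 200.0])
values = [0.0, 100.0, 200.0]
times = [logging_step_size * (n + 1) for n in range(num_samples)]   # [5, 10, 15]

for name in range(num_samples):
    assert values[name] == name * 100                     # mirrors `expected = name * 100`
    assert times[name] == logging_step_size * (name + 1)
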