From d8de5677a2003cbf139484b0a5da0602fafe6b7e Mon Sep 17 00:00:00 2001 From: ilkilic Date: Wed, 28 Aug 2024 14:12:41 +0200 Subject: [PATCH] refactoring --- .../emodel_pipeline/plotting_utils.py | 22 +++----- bluepyemodel/tools/utils.py | 56 +++++++++++++++++-- bluepyemodel/validation/validation.py | 7 +-- tests/unit_tests/test_tools.py | 38 +++++++++++++ 4 files changed, 98 insertions(+), 25 deletions(-) diff --git a/bluepyemodel/emodel_pipeline/plotting_utils.py b/bluepyemodel/emodel_pipeline/plotting_utils.py index dc473c41..e213cda1 100644 --- a/bluepyemodel/emodel_pipeline/plotting_utils.py +++ b/bluepyemodel/emodel_pipeline/plotting_utils.py @@ -32,6 +32,9 @@ from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus +from bluepyemodel.tools.utils import get_curr_name +from bluepyemodel.tools.utils import get_loc_name +from bluepyemodel.tools.utils import get_protocol_name logger = logging.getLogger("__main__") @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name): simulated_amp = [] for val in values: if prot_name.lower() in val.lower(): - # val is e.g. IV_40.soma.maximum_voltage_from_voltagebase - n = val.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - protocol_name = n[0] + protocol_name = get_protocol_name(val) amp_temp = float(protocol_name.split("_")[-1]) if "mean_frequency" in val: simulated_freq.append(values[val]) @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys): ordered_keys = {} for name in keys: - n = name.split(".") - # case where protocol has '.' in its name, e.g. IV_-100.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - prot_name = n[0] + prot_name = get_protocol_name(name) # prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055 if not any(to_skip_ in prot_name for to_skip_ in to_skip): - if len(n) != 3: - raise ValueError(f"Expected 3 elements in {n}") - loc_name = n[1] - curr_name = n[2] + loc_name = get_loc_name(name) + curr_name = get_curr_name(name) if prot_name not in ordered_keys: ordered_keys[prot_name] = {} diff --git a/bluepyemodel/tools/utils.py b/bluepyemodel/tools/utils.py index bef41fc7..bc962819 100644 --- a/bluepyemodel/tools/utils.py +++ b/bluepyemodel/tools/utils.py @@ -256,10 +256,56 @@ def get_amplitude_from_feature_key(feat_key): Args: feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase """ - n = feat_key.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if len(n) == 4 and n[1].isdigit(): - n = [".".join(n[:2]), n[2], n[3]] - protocol_name = n[0] + protocol_name = get_protocol_name(feat_key) return float(protocol_name.split("_")[-1]) + + +def combine_parts_if_dot_in_protocol(feature_name): + """ + Combine the first two elements of a list if the second element is numeric, + indicating the presence of a dot in the protocol. + + Args: + feature_name (list): The list of split parts from the feature name. + """ + if len(feature_name) > 1 and feature_name[1].isdigit(): + return [".".join(feature_name[:2])] + feature_name[2:] + return feature_name + + +def get_protocol_name(feature_name): + """ + Extract the protocol name from the feature name. + + Args: + feature_name (str): The full feature name string. + """ + n = combine_parts_if_dot_in_protocol(feature_name.split(".")) + return n[0] + + +def get_loc_name(feature_name): + """ + Extract the location name from the feature name. + + Args: + feature_name (str): The full feature name string. + """ + n = combine_parts_if_dot_in_protocol(feature_name.split(".")) + if len(n) < 2: + raise IndexError("cannot get location name from feature name") + return n[1] + + +def get_curr_name(feature_name): + """ + Extract the current name from the feature name. + + Args: + feature_name (str): The full feature name string. + """ + n = combine_parts_if_dot_in_protocol(feature_name.split(".")) + if len(n) < 3: + raise IndexError("cannot get current name from feature name") + return n[2] diff --git a/bluepyemodel/validation/validation.py b/bluepyemodel/validation/validation.py index 6d2af34b..0d217c45 100644 --- a/bluepyemodel/validation/validation.py +++ b/bluepyemodel/validation/validation.py @@ -25,6 +25,7 @@ from bluepyemodel.evaluation.evaluation import compute_responses from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point from bluepyemodel.tools.utils import are_same_protocol +from bluepyemodel.tools.utils import get_protocol_name from bluepyemodel.validation import validation_functions logger = logging.getLogger(__name__) @@ -76,11 +77,7 @@ def compute_scores(model, validation_protocols): scores = model.evaluator.fitness_calculator.calculate_scores(model.responses) for feature_name in scores: - n = feature_name.split(".") - # case where protocol has '.' in its name, e.g. IV_40.0 - if n[1].isdigit(): - n = [".".join(n[:2])] + n[2:] - protocol_name = n[0] + protocol_name = get_protocol_name(feature_name) if any(are_same_protocol(p, protocol_name) for p in validation_protocols): model.scores_validation[feature_name] = scores[feature_name] else: diff --git a/tests/unit_tests/test_tools.py b/tests/unit_tests/test_tools.py index 85e1cb84..f565a270 100644 --- a/tests/unit_tests/test_tools.py +++ b/tests/unit_tests/test_tools.py @@ -20,6 +20,9 @@ from bluepyemodel.tools.utils import are_same_protocol from bluepyemodel.tools.utils import format_protocol_name_to_list from bluepyemodel.tools.utils import select_rec_for_thumbnail +from bluepyemodel.tools.utils import get_protocol_name +from bluepyemodel.tools.utils import get_loc_name +from bluepyemodel.tools.utils import get_curr_name from tests.utils import DATA @@ -136,3 +139,38 @@ def test_select_rec_for_thumbnail(): assert ( select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v" ) + + +def test_get_protocol_name(): + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_protocol_name(feature_name) == "IV_40.0" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_protocol_name(feature_name) == "IV_40" + + feature_name = "ProtocolA.1.soma.some_feature" + assert get_protocol_name(feature_name) == "ProtocolA.1" + + +def test_get_loc_name(): + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_loc_name(feature_name) == "soma" + + feature_name = "IV_40.0" + with pytest.raises(IndexError, match="cannot get location name from feature name"): + get_loc_name(feature_name) + + +def test_get_curr_name(): + feature_name = "IV_40.0.soma.v.voltage_base" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.soma.v.voltage_base" + assert get_curr_name(feature_name) == "v" + + feature_name = "IV_40.0.soma" + with pytest.raises(IndexError, match="cannot get current name from feature name"): + get_curr_name(feature_name) \ No newline at end of file