Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkilic committed Aug 28, 2024
1 parent a4f74cd commit d8de567
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 25 deletions.
22 changes: 7 additions & 15 deletions bluepyemodel/emodel_pipeline/plotting_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from bluepyemodel.evaluation.protocols import ThresholdBasedProtocol
from bluepyemodel.evaluation.recordings import FixedDtRecordingCustom
from bluepyemodel.evaluation.recordings import FixedDtRecordingStimulus
from bluepyemodel.tools.utils import get_curr_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_protocol_name

logger = logging.getLogger("__main__")

Expand Down Expand Up @@ -379,12 +382,7 @@ def get_simulated_FI_curve_for_plotting(evaluator, responses, prot_name):
simulated_amp = []
for val in values:
if prot_name.lower() in val.lower():
# val is e.g. IV_40.soma.maximum_voltage_from_voltagebase
n = val.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(val)
amp_temp = float(protocol_name.split("_")[-1])
if "mean_frequency" in val:
simulated_freq.append(values[val])
Expand Down Expand Up @@ -593,17 +591,11 @@ def get_ordered_currentscape_keys(keys):

ordered_keys = {}
for name in keys:
n = name.split(".")
# case where protocol has '.' in its name, e.g. IV_-100.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
prot_name = n[0]
prot_name = get_protocol_name(name)
# prot_name can be e.g. RMPProtocol, or RMPProtocol_apical055
if not any(to_skip_ in prot_name for to_skip_ in to_skip):
if len(n) != 3:
raise ValueError(f"Expected 3 elements in {n}")
loc_name = n[1]
curr_name = n[2]
loc_name = get_loc_name(name)
curr_name = get_curr_name(name)

if prot_name not in ordered_keys:
ordered_keys[prot_name] = {}
Expand Down
56 changes: 51 additions & 5 deletions bluepyemodel/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,56 @@ def get_amplitude_from_feature_key(feat_key):
Args:
feat_key (str): feature key, e.g. IV_40.soma.maximum_voltage_from_voltagebase
"""
n = feat_key.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if len(n) == 4 and n[1].isdigit():
n = [".".join(n[:2]), n[2], n[3]]
protocol_name = n[0]
protocol_name = get_protocol_name(feat_key)

return float(protocol_name.split("_")[-1])


def combine_parts_if_dot_in_protocol(feature_name):
"""
Combine the first two elements of a list if the second element is numeric,
indicating the presence of a dot in the protocol.
Args:
feature_name (list): The list of split parts from the feature name.
"""
if len(feature_name) > 1 and feature_name[1].isdigit():
return [".".join(feature_name[:2])] + feature_name[2:]
return feature_name


def get_protocol_name(feature_name):
"""
Extract the protocol name from the feature name.
Args:
feature_name (str): The full feature name string.
"""
n = combine_parts_if_dot_in_protocol(feature_name.split("."))
return n[0]


def get_loc_name(feature_name):
"""
Extract the location name from the feature name.
Args:
feature_name (str): The full feature name string.
"""
n = combine_parts_if_dot_in_protocol(feature_name.split("."))
if len(n) < 2:
raise IndexError("cannot get location name from feature name")
return n[1]


def get_curr_name(feature_name):
"""
Extract the current name from the feature name.
Args:
feature_name (str): The full feature name string.
"""
n = combine_parts_if_dot_in_protocol(feature_name.split("."))
if len(n) < 3:
raise IndexError("cannot get current name from feature name")
return n[2]
7 changes: 2 additions & 5 deletions bluepyemodel/validation/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from bluepyemodel.evaluation.evaluation import compute_responses
from bluepyemodel.evaluation.evaluation import get_evaluator_from_access_point
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.validation import validation_functions

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -76,11 +77,7 @@ def compute_scores(model, validation_protocols):

scores = model.evaluator.fitness_calculator.calculate_scores(model.responses)
for feature_name in scores:
n = feature_name.split(".")
# case where protocol has '.' in its name, e.g. IV_40.0
if n[1].isdigit():
n = [".".join(n[:2])] + n[2:]
protocol_name = n[0]
protocol_name = get_protocol_name(feature_name)
if any(are_same_protocol(p, protocol_name) for p in validation_protocols):
model.scores_validation[feature_name] = scores[feature_name]
else:
Expand Down
38 changes: 38 additions & 0 deletions tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from bluepyemodel.tools.utils import are_same_protocol
from bluepyemodel.tools.utils import format_protocol_name_to_list
from bluepyemodel.tools.utils import select_rec_for_thumbnail
from bluepyemodel.tools.utils import get_protocol_name
from bluepyemodel.tools.utils import get_loc_name
from bluepyemodel.tools.utils import get_curr_name
from tests.utils import DATA


Expand Down Expand Up @@ -136,3 +139,38 @@ def test_select_rec_for_thumbnail():
assert (
select_rec_for_thumbnail(rec_names, thumbnail_rec="sAHP_20.soma.v") == "IDrest_130.soma.v"
)


def test_get_protocol_name():
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_protocol_name(feature_name) == "IV_40.0"

feature_name = "IV_40.soma.v.voltage_base"
assert get_protocol_name(feature_name) == "IV_40"

feature_name = "ProtocolA.1.soma.some_feature"
assert get_protocol_name(feature_name) == "ProtocolA.1"


def test_get_loc_name():
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.soma.v.voltage_base"
assert get_loc_name(feature_name) == "soma"

feature_name = "IV_40.0"
with pytest.raises(IndexError, match="cannot get location name from feature name"):
get_loc_name(feature_name)


def test_get_curr_name():
feature_name = "IV_40.0.soma.v.voltage_base"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.soma.v.voltage_base"
assert get_curr_name(feature_name) == "v"

feature_name = "IV_40.0.soma"
with pytest.raises(IndexError, match="cannot get current name from feature name"):
get_curr_name(feature_name)

0 comments on commit d8de567

Please sign in to comment.