Skip to content

Commit

Permalink
Finalise testing.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Dec 15, 2023
1 parent 57d865b commit b306ca6
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 9 deletions.
23 changes: 23 additions & 0 deletions spikewrap/data_classes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
Tuple,
)

import numpy as np

from spikewrap.utils import utils

if TYPE_CHECKING:
import fnmatch
from datashuttle.utils.utils import get_values_from_bids_formatted_name


@dataclass
Expand Down Expand Up @@ -76,6 +79,23 @@ def raise_if_only_and_has_more_than_one_folder(
f"set to 'only'."
)

# TODO:
def check_and_sort_globbed_names(self, all_names: List[str]) -> List[str]:
""""""
all_names = sorted(all_names)

# TODO: rename
values = get_values_from_bids_formatted_name(
all_names, "ses", return_as_int=True
)
name_nums = [int(name.split("_")[0].split("-")[1]) for name in all_names]
if name_nums[0] != 1 or np.any(np.diff(values) != 1):
raise RuntimeError(
"Using the 'all' key has made session names go out of order. Please"
"get in contact and this can be quickly resolved."
)
return all_names

def _convert_session_and_run_keywords_to_foldernames( # TODO: this is called from preprocessing and sorting.
self, get_sub_path: Callable, get_ses_path: Callable
) -> None:
Expand All @@ -96,6 +116,9 @@ def _convert_session_and_run_keywords_to_foldernames( # TODO: this is called fr
path_.stem for path_ in ses_name_filepaths if path_.is_dir()
]

all_session_names = self.check_and_sort_globbed_names(all_session_names)
# TODO: need to sort all session names and then check these names are still in the correct order! If they are not we need to do something....

self.raise_if_only_and_has_more_than_one_folder(
ses_keyword, sub_path, all_session_names, "session"
)
Expand Down
13 changes: 13 additions & 0 deletions spikewrap/utils/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,19 @@ def check_function_arguments(arguments):
if not typecheck(arg_value, bool):
raise TypeError(f"`{arg_name}` must be a bool.")

if (
"concat_runs_for_sorting" in arguments
and arguments["concat_runs_for_sorting"]
) or ("concatenate_runs" in arguments and arguments["concatenate_runs"]):
for ses_name in arguments["sessions_and_runs"].keys():
for run_name in arguments["sessions_and_runs"][ses_name]:
if run_name == "all":
raise ValueError(
"Using the 'all' option for `sessions_and_runs` "
"is currently not supported when concatenating runs for sorting. "
"If you only have one run, use 'only', otherwise please get in contact. "
)

elif arg_name == "existing_preprocessed_data":
if not typecheck(arg_value, HandleExisting):
raise TypeError(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration/test_full_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
class TestFullPipeline(BaseTest):
# TODO: naming now confusing between test format and SI format
@pytest.mark.parametrize("test_info", ["multi_segment"], indirect=True)
def test_multi_segment(self, test_info, all_session, all_runs):
def test_multi_segment(self, test_info):
with pytest.raises(ValueError) as e:
load_data(*test_info[:3], data_format="spikeinterface")

Expand Down
27 changes: 19 additions & 8 deletions tests/test_integration/test_load_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ class TestLoadData(BaseTest):
@pytest.mark.parametrize("mode", ["all_sessions_and_runs", "all_runs"])
def test_all_keyword(self, test_info, mode):
"""
There are 4 cases to test, sesesions on and off, runs and off.
Test the 'all' keyword for the `sessions_and_runs` dictionary which
can take "all" in the sessions key or "all" as a run input. The 'only'
keyword is not tested here.
Session on: if run is on, then everything is discovered as below
if run is off, then there is a crash
session off: if run is on, we are okay (discover per run)
if run is off, then we just have normal use.
The two cases that are tested here are when session level is 'all'
or when the run level for each session is 'all'. In either case,
the data should be loaded as expected from the toy example file data.
"""
base_path, sub_name, sessions_and_runs = test_info

# Set the 'all' key on the sessions and runs
if mode == "all_sessions_and_runs":
new_sessions_and_runs = {"all": ["all"]}
else:
Expand All @@ -40,6 +41,8 @@ def test_all_keyword(self, test_info, mode):
base_path, sub_name, new_sessions_and_runs, data_format="spikeinterface"
)

# Check the preprocess_data dict contains inputs
# in the order that is expected
assert list(preprocess_data.keys()) == ["ses-001", "ses-002", "ses-003"]

for ses_name in preprocess_data.keys():
Expand All @@ -51,6 +54,8 @@ def test_all_keyword(self, test_info, mode):
for run in ["run-001", "run-002"]:
run_name = f"{ses_name}_{run}"

# for each run, check the expected run is loaded into
# `preprocess_data`.
test_run_data = load_extractor(
base_path
/ "rawdata"
Expand All @@ -69,7 +74,11 @@ def test_all_keyword(self, test_info, mode):
@pytest.mark.parametrize("test_info", ["spikeinterface"], indirect=True)
def test_all_keyword_session_all_run_normal(self, test_info):
"""
TODO: document, this is stupid
This is a bad test in the case that sessions is "all"
and runs is not "all". See issue #166 to resolve this.
For the time being, on the current test set this will duplicate
the runs for ses-001 to ses-002 and ses-003. The run names are
different across sessions so this raises and error.
"""
base_path, sub_name, sessions_and_runs = test_info

Expand All @@ -88,7 +97,9 @@ def test_all_keyword_session_all_run_normal(self, test_info):
@pytest.mark.parametrize("session_or_run", ["session", "run"])
def test_only_keyword(self, test_info, session_or_run):
"""
That is raises an error
Test that when the 'only' keyword is used in sessions and runs,
either at the session level, then and error is raised if there is more
than one session or run.
"""
base_path, sub_name, sessions_and_runs = test_info

Expand Down
19 changes: 19 additions & 0 deletions tests/test_integration/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,22 @@ def test_validate_empty_sessions_and_runs(self, test_info):
)

assert str(e.value) == "`sessions_and_runs` cannot contain empty runs."

@pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True)
def test_run_all_with_concatenate_is_blocked(self, test_info):
base_path, sub_name, sessions_and_runs = test_info
sessions_and_runs["ses-001"] = ["all"]

with pytest.raises(ValueError) as e:
self.run_full_pipeline(
base_path,
sub_name,
sessions_and_runs,
DEFAULT_FORMAT,
concatenate_runs=True,
)

assert (
"Using the 'all' option for `sessions_and_runs` is currently "
"not supported when concatenating runs for sorting." in str(e.value)
)

0 comments on commit b306ca6

Please sign in to comment.