Add test in the no concatenation case.
JoeZiminski committed Dec 18, 2023
1 parent 44c5cef commit f3e1af7
Showing 9 changed files with 1,829 additions and 14 deletions.
1 change: 1 addition & 0 deletions tests/data/small_toy_data/in_container_params.json
@@ -0,0 +1 @@
{}
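(For context: this file is deliberately an empty JSON object. The container script below loads it and unpacks it as `**sorter_params`, so an empty dict applies no parameter overrides and the sorter runs with its defaults.)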
743 changes: 743 additions & 0 deletions tests/data/small_toy_data/in_container_recording.json

Large diffs are not rendered by default.

46 changes: 46 additions & 0 deletions tests/data/small_toy_data/in_container_sorter_script.py
@@ -0,0 +1,46 @@
import json
from pathlib import Path

from spikeinterface import load_extractor
from spikeinterface.sorters import run_sorter_local

if __name__ == "__main__":
    # this __name__ protection helps in some cases with multiprocessing (for instance HS2)
# load recording in container
json_rec = Path(
"/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/in_container_recording.json"
)
pickle_rec = Path(
"/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/in_container_recording.pickle"
)
if json_rec.exists():
recording = load_extractor(json_rec)
else:
recording = load_extractor(pickle_rec)

# load params in container
with open(
"/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/in_container_params.json",
encoding="utf8",
mode="r",
) as f:
sorter_params = json.load(f)

# run in container
output_folder = (
"/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/mountainsort5_output"
)
sorting = run_sorter_local(
"mountainsort5",
recording,
output_folder=output_folder,
remove_existing_folder=False,
delete_output_folder=False,
verbose=False,
raise_error=True,
with_output=True,
**sorter_params,
)
sorting.save_to_folder(
folder="/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/mountainsort5_output/in_container_sorting"
)
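
For reference, a sketch of how the sorting saved above by `save_to_folder` can be reloaded with spikeinterface's generic `load_extractor` (same path as in the script):

from spikeinterface import load_extractor

# reload the sorting that the container script saved with `save_to_folder`
sorting = load_extractor(
    "/fMRIData/git-repo/spikewrap/tests/data/small_toy_data/"
    "mountainsort5_output/in_container_sorting"
)
print(sorting.get_unit_ids())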
Binary file not shown.
@@ -0,0 +1,8 @@
{
"sorter_name": "mountainsort5",
"sorter_version": "0.3.0",
"datetime": "2023-12-18T18:48:45.578078",
"runtime_trace": [],
"error": false,
"run_time": 0.26515789999393746
}
@@ -0,0 +1,25 @@
{
"sorter_name": "mountainsort5",
"sorter_params": {
"scheme": "2",
"detect_threshold": 5.5,
"detect_sign": -1,
"detect_time_radius_msec": 0.5,
"snippet_T1": 20,
"snippet_T2": 20,
"npca_per_channel": 3,
"npca_per_subdivision": 10,
"snippet_mask_radius": 250,
"scheme1_detect_channel_radius": 150,
"scheme2_phase1_detect_channel_radius": 200,
"scheme2_detect_channel_radius": 50,
"scheme2_max_num_snippets_per_training_batch": 200,
"scheme2_training_duration_sec": 300,
"scheme2_training_recording_sampling_mode": "uniform",
"scheme3_block_duration_sec": 1800,
"freq_min": 300,
"freq_max": 6000,
"filter": false,
"whiten": false
}
}
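
These are the mountainsort5 parameters recorded for the run; `filter` and `whiten` are disabled, consistent with preprocessing having been applied upstream. As a sketch (assuming spikeinterface's `get_default_sorter_params` helper), the recorded values can be checked against the sorter defaults:

from spikeinterface.sorters import get_default_sorter_params

defaults = get_default_sorter_params("mountainsort5")

# compare a few values from the JSON above against the sorter defaults
recorded = {"scheme": "2", "filter": False, "whiten": False}
for key, value in recorded.items():
    print(f"{key}: run={value}, default={defaults.get(key)}")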

Large diffs are not rendered by default.

15 changes: 14 additions & 1 deletion tests/test_integration/base.py
@@ -144,7 +144,12 @@ def run_full_pipeline(
)

def check_correct_folders_exist(
self, test_info, concatenate_sessions, concatenate_runs, sorter="kilosort2_5"
self,
test_info,
concatenate_sessions,
concatenate_runs,
sorter="kilosort2_5",
sort_by_group=False,
):
sub_path = test_info[0] / "derivatives" / "spikewrap" / test_info[1]
sessions_and_runs = test_info[2]
@@ -184,6 +189,14 @@ def check_correct_folders_exist(
else:
assert len(run_level_sorting) == 1

if sort_by_group:
sorted_groups = list(run_level_sorting[0].glob("group-*"))
assert len(sorted_groups) > 1

for sorting_output_path in sorted_groups:
assert (sorting_output_path / "sorting").is_dir()
assert (sorting_output_path / "postprocessing").is_dir()

ses_path = sub_path / ses_name / "ephys"

concat_all_run_names = "".join(
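
For orientation, the layout that the new `sort_by_group` assertions expect under each run's sorter folder looks roughly like this (group names are illustrative; the test only requires more than one `group-*` folder, each containing `sorting` and `postprocessing`):

<run_name>/<sorter>/
    group-0/
        sorting/
        postprocessing/
    group-1/
        sorting/
        postprocessing/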
136 changes: 123 additions & 13 deletions tests/test_integration/test_full_pipeline.py
@@ -125,7 +125,8 @@ def test_no_concatenation_all_sorters_single_run(self, test_info, sorter):
self.check_no_concat_results(test_info, loaded_data, sorting_data, sorter)

@pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True)
def test_no_concatenation_single_run(self, test_info):
@pytest.mark.parametrize("sort_by_group", [True, False])
def test_no_concatenation_single_run(self, test_info, sort_by_group):
"""
Run the full pipeline for a single
session and run, and check preprocessing, sorting and waveforms.
@@ -135,14 +136,17 @@ def test_no_concatenation_single_run(self, test_info):
loaded_data, sorting_data = self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
sorter=DEFAULT_SORTER,
concatenate_sessions=False,
concatenate_runs=False,
)

self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER)
self.check_correct_folders_exist(
test_info, False, False, DEFAULT_SORTER, sort_by_group=sort_by_group
)
self.check_no_concat_results(
test_info, loaded_data, sorting_data, DEFAULT_SORTER
test_info, loaded_data, sorting_data, DEFAULT_SORTER, sort_by_group
)

@pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True)
@@ -159,8 +163,6 @@ def test_no_concatenation_multiple_runs(self, test_info):
sorter=DEFAULT_SORTER,
)

self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER)

self.check_correct_folders_exist(test_info, False, False, DEFAULT_SORTER)
self.check_no_concat_results(test_info, loaded_data, sorting_data)

@@ -225,7 +227,8 @@ def test_ses_concat_no_run_concat(self, test_info):
)

@pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True)
def test_existing_output_settings(self, test_info):
@pytest.mark.parametrize("sort_by_group", [True, False])
def test_existing_output_settings(self, test_info, sort_by_group):
"""
In spikewrap existing preprocessed and sorting output data is
handled with options `fail_if_exists`, `skip_if_exists` or
@@ -245,18 +248,22 @@ def test_existing_output_settings(self, test_info):
self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
existing_preprocessed_data="fail_if_exists",
existing_sorting_output="fail_if_exists",
overwrite_postprocessing=False,
sorter=DEFAULT_SORTER,
)

# Test outputs are overwritten if `overwrite` set.
file_paths = self.write_an_empty_file_in_outputs(test_info, ses_name, run_name)
file_paths = self.write_an_empty_file_in_outputs(
test_info, ses_name, run_name, sort_by_group
)

self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
existing_preprocessed_data="overwrite",
existing_sorting_output="overwrite",
overwrite_postprocessing=True,
@@ -266,13 +273,16 @@ def test_existing_output_settings(self, test_info):
for path_ in file_paths:
assert not path_.is_file()

file_paths = self.write_an_empty_file_in_outputs(test_info, ses_name, run_name)
file_paths = self.write_an_empty_file_in_outputs(
test_info, ses_name, run_name, sort_by_group
)

# Test outputs are not overwritten if `skip_if_exists`.
# Postprocessing is always deleted
self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
existing_preprocessed_data="skip_if_exists",
existing_sorting_output="skip_if_exists",
overwrite_postprocessing=True,
@@ -287,6 +297,7 @@ def test_existing_output_settings(self, test_info):
self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
existing_preprocessed_data="fail_if_exists",
existing_sorting_output="skip_if_exists",
overwrite_postprocessing=True,
@@ -307,6 +318,7 @@ def test_existing_output_settings(self, test_info):
self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
existing_preprocessed_data="skip_if_exists",
existing_sorting_output="fail_if_exists",
overwrite_postprocessing=True,
@@ -320,6 +332,7 @@ def test_existing_output_settings(self, test_info):
self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
sort_by_group=sort_by_group,
existing_preprocessed_data="skip_if_exists",
existing_sorting_output="skip_if_exists",
overwrite_postprocessing=False,
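
The existing-output semantics exercised above follow a common pattern. Purely as an illustration (not spikewrap's actual implementation; `handle_existing_output`, `output_path` and `run_step` are hypothetical names):

import shutil

def handle_existing_output(output_path, mode, run_step):
    # mode is one of "fail_if_exists", "skip_if_exists", "overwrite"
    if output_path.exists():
        if mode == "fail_if_exists":
            raise RuntimeError(f"{output_path} already exists")
        if mode == "skip_if_exists":
            return  # leave the existing output untouched
        shutil.rmtree(output_path)  # "overwrite": delete, then recompute
    run_step(output_path)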
@@ -354,7 +367,12 @@ def test_smoke_supply_chunk_size(self, test_info, capsys, specify_chunk_size):
# ----------------------------------------------------------------------------------

def check_no_concat_results(
self, test_info, loaded_data, sorting_data, sorter=DEFAULT_SORTER
self,
test_info,
loaded_data,
sorting_data,
sorter=DEFAULT_SORTER,
sort_by_group=False,
):
"""
After `full_pipeline` is run, check the preprocessing, sorting and postprocessing
@@ -678,7 +696,7 @@ def check_waveforms(
assert np.array_equal(data, first_unit_waveforms[0])

def write_an_empty_file_in_outputs(
self, test_info, ses_name, run_name, sorter=DEFAULT_SORTER
self, test_info, ses_name, run_name, sort_by_group, sorter=DEFAULT_SORTER
):
"""
Write a file called `test_file.txt` with contents `test_file` in
Expand All @@ -689,12 +707,20 @@ def write_an_empty_file_in_outputs(

paths_to_write = []
for output in ["preprocessing", "sorting_path", "postprocessing"]:
paths_to_write.append(paths[output] / "test_file.txt")
if sort_by_group and output in ["sorting_path", "postprocessing"]:
group_paths = paths[output].parent.glob("group-*")
paths_to_extend = [
path_ / paths[output].name / "test_file.txt"
for path_ in group_paths
]
else:
paths_to_extend = [paths[output] / "test_file.txt"]

paths_to_write.extend(paths_to_extend)

for path_ in paths_to_write:
with open(path_, "w") as file:
with open(path_.as_posix(), "w") as file:
file.write("test file.")

return paths_to_write

def get_output_paths(
@@ -754,6 +780,90 @@ def get_output_paths(

return paths

@pytest.mark.parametrize("test_info", [DEFAULT_FORMAT], indirect=True)
def test_sort_by_group(self, test_info):
self.run_full_pipeline(
*test_info,
data_format=DEFAULT_FORMAT,
concatenate_runs=False,
concatenate_sessions=False,
sort_by_group=True,
existing_preprocessed_data="overwrite",
existing_sorting_output="overwrite",
overwrite_postprocessing=True,
sorter=DEFAULT_SORTER,
)

base_path, sub_name, sessions_and_runs = test_info

from spikeinterface import sorters

for ses_name in sessions_and_runs.keys():
for run_name in sessions_and_runs[ses_name]:
_, test_preprocessed = self.get_test_rawdata_and_preprocessed_data(
base_path, sub_name, ses_name, run_name
)

split_recording = test_preprocessed.split_by("group")

if "kilosort" in DEFAULT_SORTER:
singularity_image = True if platform.system() == "Linux" else False
docker_image = not singularity_image
else:
singularity_image = docker_image = False

# TODO: load config!
sortings = {}
for group, sub_recording in split_recording.items():
sorting = sorters.run_sorter(
sorter_name=DEFAULT_SORTER,
recording=sub_recording,
output_folder=None,
docker_image=docker_image,
singularity_image=singularity_image,
remove_existing_folder=True,
**{
"scheme": "2",
"filter": False,
"whiten": False,
"verbose": True,
},
)

sortings[group] = sorting

out_path = (
base_path
/ "derivatives"
/ "spikewrap"
/ sub_name
/ ses_name
/ "ephys"
/ run_name
/ DEFAULT_SORTER
)

from spikewrap.data_classes.postprocessing import (
load_saved_sorting_output,
)

group_paths = list(sorted(out_path.glob("group-*")))

for idx, path_ in enumerate(group_paths):
group_sorting = load_saved_sorting_output(
path_ / "sorting" / "sorter_output", DEFAULT_SORTER
)

assert np.array_equal(
group_sorting.get_unit_ids(), sortings[idx].get_unit_ids()
)

for unit in group_sorting.get_unit_ids():
assert np.array_equal(
group_sorting.get_unit_spike_train(unit),
sortings[idx].get_unit_spike_train(unit),
)
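
`split_by("group")`, used in the test above, is spikeinterface's standard way to split a recording on a channel property into per-group sub-recordings, returned as a dict. A minimal standalone sketch using a generated toy recording (the group assignment here is made up):

from spikeinterface.extractors import toy_example

recording, _ = toy_example(num_channels=8, num_segments=1)
recording.set_property("group", [0, 0, 0, 0, 1, 1, 1, 1])

split = recording.split_by("group")  # {0: sub_recording, 1: sub_recording}
for group, sub_recording in split.items():
    print(group, sub_recording.get_num_channels())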

# ----------------------------------------------------------------------------------
# Getters
# ----------------------------------------------------------------------------------
