
wip refactoring main.py
luiztauffer committed Oct 13, 2023
1 parent 63962cb commit 38a2b5a
Showing 4 changed files with 78 additions and 56 deletions.
113 changes: 64 additions & 49 deletions containers/main.py
@@ -127,33 +127,12 @@ def main(
filterwarnings(action="ignore", message="No cached namespaces found in .*")
filterwarnings(action="ignore", message="Ignoring cached namespace .*")

# Set SpikeInterface global job kwargs
si.set_global_job_kwargs(
n_jobs=os.cpu_count(),
chunk_duration="1s",
progress_bar=False
)

# Create folders
data_folder = Path("data/")
data_folder.mkdir(exist_ok=True)

scratch_folder = Path("scratch/")
scratch_folder.mkdir(exist_ok=True)

results_folder = Path("results/")
results_folder.mkdir(exist_ok=True)
tmp_folder = scratch_folder / "tmp"
if tmp_folder.is_dir():
shutil.rmtree(tmp_folder)
tmp_folder.mkdir()

# Checks
if source_name not in ["local", "s3", "dandi"]:
logger.error(f"Source {source_name} not supported. Choose from: local, s3, dandi.")
raise ValueError(f"Source {source_name} not supported. Choose from: local, s3, dandi.")

# TODO: here we could leverage spikeinterface and add more options
# TODO: here we could eventually leverage spikeinterface and add more options
if source_data_type not in ["nwb", "spikeglx"]:
logger.error(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.")
raise ValueError(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.")
@@ -174,7 +153,28 @@ def main(
output_s3_bucket = output_path_parsed.split("/")[0]
output_s3_bucket_folder = "/".join(output_path_parsed.split("/")[1:])

s3_client = boto3.client("s3")
# Set SpikeInterface global job kwargs
si.set_global_job_kwargs(
n_jobs=os.cpu_count(),
chunk_duration="1s",
progress_bar=False
)

# Create SpikeInterface folders
data_folder = Path("data/")
data_folder.mkdir(exist_ok=True)
scratch_folder = Path("scratch/")
scratch_folder.mkdir(exist_ok=True)
results_folder = Path("results/")
results_folder.mkdir(exist_ok=True)
tmp_folder = scratch_folder / "tmp"
if tmp_folder.is_dir():
shutil.rmtree(tmp_folder)
tmp_folder.mkdir()

# S3 client
if source_name == "s3" or output_destination == "s3":
s3_client = boto3.client("s3")

# Test with toy recording
if test_with_toy_recording:
@@ -198,7 +198,6 @@ def main(
bucket_name=bucket_name,
file_path=file_path,
)

logger.info("Reading recording...")
# E.g.: se.read_spikeglx(folder_path="/data", stream_id="imec.ap")
if source_data_type == "spikeglx":
@@ -207,44 +206,49 @@ def main(
recording = se.read_nwb_recording(file_path=f"/data/{file_name}", **recording_kwargs)
recording_name = "recording_on_s3"

# Load data from DANDI archive
elif source_name == "dandi":
dandiset_s3_file_url = source_data_paths["file"]
if not dandiset_s3_file_url.startswith("https://dandiarchive"):
raise Exception(
f"DANDISET_S3_FILE_URL should be a valid Dandiset S3 url. Value received was: {dandiset_s3_file_url}"
)

if not test_with_subrecording:
logger.info(f"Downloading dataset: {dandiset_s3_file_url}")
download_file_from_url(dandiset_s3_file_url)

logger.info("Reading recording from NWB...")
recording = se.read_nwb_recording(file_path="/data/filename.nwb", **recording_kwargs)
else:
logger.info("Reading recording from NWB...")
recording = se.read_nwb_recording(file_path=dandiset_s3_file_url, stream_mode="fsspec", **recording_kwargs)
recording_name = "recording_on_dandi"

# TODO - Load data from local files
elif source_name == "local":
pass

# Run with subrecording
if test_with_subrecording:
n_frames = int(min(test_subrecording_n_frames, recording.get_num_frames()))
recording = recording.frame_slice(start_frame=0, end_frame=n_frames)

# ------------------------------------------------------------------------------------
# Preprocessing
# ------------------------------------------------------------------------------------
logger.info("Starting preprocessing...")
preprocessing_notes = ""
preprocessed_folder = tmp_folder / "preprocessed"
t_preprocessing_start = time.perf_counter()
logger.info(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s")

if "inter_sample_shift" in recording.get_property_keys():
recording_ps_full = spre.phase_shift(recording, **preprocessing_params["phase_shift"])
recording_ps_full = spre.phase_shift(recording, **preprocessing_kwargs["phase_shift"])
else:
recording_ps_full = recording

recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_params["highpass_filter"])
recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_kwargs["highpass_filter"])
# IBL bad channel detection
_, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_params["detect_bad_channels"])
_, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_kwargs["detect_bad_channels"])
dead_channel_mask = channel_labels == "dead"
noise_channel_mask = channel_labels == "noise"
out_channel_mask = channel_labels == "out"
@@ -257,7 +261,7 @@ def main(
out_channel_ids = recording_hp_full.channel_ids[out_channel_mask]

all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids))
max_bad_channel_fraction_to_remove = preprocessing_params["max_bad_channel_fraction_to_remove"]
max_bad_channel_fraction_to_remove = preprocessing_kwargs["max_bad_channel_fraction_to_remove"]
if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()):
logger.info(
f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). "
@@ -267,28 +271,28 @@ def main(
# in this case, we don't bother sorting
return
else:
if preprocessing_params["remove_out_channels"]:
if preprocessing_kwargs["remove_out_channels"]:
logger.info(f"\tRemoving {len(out_channel_ids)} out channels")
recording_rm_out = recording_hp_full.remove_channels(out_channel_ids)
preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain."
else:
recording_rm_out = recording_hp_full

recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_params["common_reference"])
recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_kwargs["common_reference"])

bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids))
recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids)
recording_hp_spatial = spre.highpass_spatial_filter(
recording_interp, **preprocessing_params["highpass_spatial_filter"]
recording_interp, **preprocessing_kwargs["highpass_spatial_filter"]
)

preproc_strategy = preprocessing_params["preprocessing_strategy"]
preproc_strategy = preprocessing_kwargs["preprocessing_strategy"]
if preproc_strategy == "cmr":
recording_processed = recording_processed_cmr
else:
recording_processed = recording_hp_spatial

if preprocessing_params["remove_bad_channels"]:
if preprocessing_kwargs["remove_bad_channels"]:
logger.info(f"\tRemoving {len(bad_channel_ids)} channels after {preproc_strategy} preprocessing")
recording_processed = recording_processed.remove_channels(bad_channel_ids)
preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n"
@@ -300,6 +304,8 @@ def main(

# ------------------------------------------------------------------------------------
# Spike Sorting
# ------------------------------------------------------------------------------------
sorter_name = sorter_kwargs["sorter_name"]
logger.info(f"\n\nStarting spike sorting with {sorter_name}")
spikesorting_notes = ""
sorting_params = None
@@ -313,7 +319,7 @@ def main(
if recording_processed.get_num_segments() > 1:
recording_processed = si.concatenate_recordings([recording_processed])

# run ks2.5
# Run sorter
try:
sorting = ss.run_sorter(
sorter_name,
@@ -334,6 +340,7 @@ def main(

# remove empty units
sorting = sorting.remove_empty_units()

# remove spikes beyond num_samples (if any)
sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording_processed)
logger.info(f"\tSorting output without empty units: {sorting}")
@@ -353,14 +360,15 @@ def main(

# ------------------------------------------------------------------------------------
# Postprocessing
# ------------------------------------------------------------------------------------
logger.info("\n\Starting postprocessing...")
postprocessing_notes = ""
t_postprocessing_start = time.perf_counter()

# first extract some raw waveforms in memory to deduplicate based on peak alignment
wf_dedup_folder = tmp_folder / "postprocessed" / recording_name
we_raw = si.extract_waveforms(
recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_params["waveforms_deduplicate"]
recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_kwargs["waveforms_deduplicate"]
)
# de-duplication
sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_params["duplicate_threshold"])
@@ -389,32 +397,34 @@ def main(
sparsity=sparsity,
sparse=True,
overwrite=True,
**postprocessing_params["waveforms"],
**postprocessing_kwargs["waveforms"],
)
logger.info("\tComputing spike amplitides")
spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_params["spike_amplitudes"])
spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_kwargs["spike_amplitudes"])
logger.info("\tComputing unit locations")
unit_locations = spost.compute_unit_locations(we, **postprocessing_params["locations"])
unit_locations = spost.compute_unit_locations(we, **postprocessing_kwargs["locations"])
logger.info("\tComputing spike locations")
spike_locations = spost.compute_spike_locations(we, **postprocessing_params["locations"])
spike_locations = spost.compute_spike_locations(we, **postprocessing_kwargs["locations"])
logger.info("\tComputing correlograms")
ccg, bins = spost.compute_correlograms(we, **postprocessing_params["correlograms"])
ccg, bins = spost.compute_correlograms(we, **postprocessing_kwargs["correlograms"])
logger.info("\tComputing ISI histograms")
isi, bins = spost.compute_isi_histograms(we, **postprocessing_params["isis"])
isi, bins = spost.compute_isi_histograms(we, **postprocessing_kwargs["isis"])
logger.info("\tComputing template similarity")
sim = spost.compute_template_similarity(we, **postprocessing_params["similarity"])
sim = spost.compute_template_similarity(we, **postprocessing_kwargs["similarity"])
logger.info("\tComputing template metrics")
tm = spost.compute_template_metrics(we, **postprocessing_params["template_metrics"])
tm = spost.compute_template_metrics(we, **postprocessing_kwargs["template_metrics"])
logger.info("\tComputing PCA")
pca = spost.compute_principal_components(we, **postprocessing_params["principal_components"])
pca = spost.compute_principal_components(we, **postprocessing_kwargs["principal_components"])
logger.info("\tComputing quality metrics")
qm = sqm.compute_quality_metrics(we, **postprocessing_params["quality_metrics"])
qm = sqm.compute_quality_metrics(we, **postprocessing_kwargs["quality_metrics"])

t_postprocessing_end = time.perf_counter()
elapsed_time_postprocessing = np.round(t_postprocessing_end - t_postprocessing_start, 2)
logger.info(f"Postprocessing time: {elapsed_time_postprocessing}s")

###### CURATION ##############
# ------------------------------------------------------------------------------------
# Curation
# ------------------------------------------------------------------------------------
logger.info("\n\Starting curation...")
curation_notes = ""
t_curation_start = time.perf_counter()
@@ -452,10 +462,15 @@ def main(
elapsed_time_curation = np.round(t_curation_end - t_curation_start, 2)
logger.info(f"Curation time: {elapsed_time_curation}s")

# ------------------------------------------------------------------------------------
# TODO: Visualization with FIGURL (needs credentials)
# ------------------------------------------------------------------------------------



# ------------------------------------------------------------------------------------
# Conversion and upload
# ------------------------------------------------------------------------------------
logger.info("Writing sorting results to NWB...")
metadata = {
"NWBFile": {
@@ -603,7 +618,7 @@ def main(
visualization_kwargs = json.loads(os.environ.get("SI_VISUALIZATION_KWARGS", "{}"))

# Get output kwargs from ENV variables
output_kwargs = json.loads(os.environ.get("SI_OUTPUT_KWARGS", "{}"))
output_kwargs = json.loads(os.environ.get("SI_OUTPUT_DATA_KWARGS", "{}"))
output_destination = validate_not_none(output_kwargs, "output_destination")
output_path = validate_not_none(output_kwargs, "output_path")

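The refactor routes every kwargs group into the container as a JSON string in an environment variable, and this commit renames the output group from SI_OUTPUT_KWARGS to SI_OUTPUT_DATA_KWARGS. A minimal sketch of the container-side read, assuming only the variable name and the two keys visible in the hunks above (main.py's own validate_not_none helper is not reproduced here):

import json
import os

# Each kwargs group arrives as a JSON string in its own environment variable,
# mirroring the pydantic .json() serialization done by the REST client.
output_kwargs = json.loads(os.environ.get("SI_OUTPUT_DATA_KWARGS", "{}"))
output_destination = output_kwargs.get("output_destination")  # "local" or "s3"
output_path = output_kwargs.get("output_path")                # local folder, or "bucket/prefix" on s3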
3 changes: 3 additions & 0 deletions rest/clients/local_docker.py
@@ -11,6 +11,7 @@
PostprocessingKwargs,
CurationKwargs,
VisualizationKwargs,
OutputDataKwargs
)


@@ -30,6 +31,7 @@ def run_sorting(
postprocessing_kwargs: PostprocessingKwargs,
curation_kwargs: CurationKwargs,
visualization_kwargs: VisualizationKwargs,
output_data_kwargs: OutputDataKwargs,
) -> None:
# Pass kwargs as environment variables to the container
env_vars = dict(
@@ -41,6 +43,7 @@ def run_sorting(
SI_POSTPROCESSING_KWARGS=postprocessing_kwargs.json(),
SI_CURATION_KWARGS=curation_kwargs.json(),
SI_VISUALIZATION_KWARGS=visualization_kwargs.json(),
SI_OUTPUT_DATA_KWARGS=output_data_kwargs.json(),
)

# Local volumes to mount
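local_docker.py now forwards an OutputDataKwargs model as SI_OUTPUT_DATA_KWARGS, but the model's definition is not part of this diff. A plausible shape, inferred only from the output_destination and output_path keys that main.py reads (the enum below is this sketch's assumption, not the project's code):

from enum import Enum
from pydantic import BaseModel, Field

class OutputDestination(str, Enum):
    local = "local"
    s3 = "s3"

class OutputDataKwargs(BaseModel):
    output_destination: OutputDestination = Field(..., description="Where to write results. Choose from: local, s3.")
    output_path: str = Field(..., description="Local folder, or 'bucket/prefix' when writing to s3.")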
9 changes: 6 additions & 3 deletions rest/models/sorting.py
@@ -1,6 +1,7 @@
from pydantic import BaseModel, Field, Extra
from typing import Optional, Dict, List, Union, Tuple
from enum import Enum
from datetime import datetime


# ------------------------------
@@ -10,10 +11,13 @@ class RunAt(str, Enum):
aws = "aws"
local = "local"

def default_run_identifier():
return datetime.now().strftime("%Y%m%d-%H%M%S")

class RunKwargs(BaseModel):
run_at: RunAt = Field(..., description="Where to run the sorting job. Choose from: aws, local.")
run_identifier: str = Field(..., description="Unique identifier for the run.")
run_description: str = Field(..., description="Description of the run.")
run_identifier: str = Field(default_factory=default_run_identifier, description="Unique identifier for the run.")
run_description: str = Field(default="", description="Description of the run.")
test_with_toy_recording: bool = Field(default=False, description="Whether to test with a toy recording.")
test_with_subrecording: bool = Field(default=False, description="Whether to test with a subrecording.")
test_subrecording_n_frames: Optional[int] = Field(default=30000, description="Number of frames to use for the subrecording.")
@@ -278,4 +282,3 @@ class VisualizationKwargs(BaseModel):
timeseries: Timeseries
drift: Drift
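The RunKwargs change above replaces two required fields with defaults: run_description becomes an empty string and run_identifier is generated by a default_factory, which is also why the timestamp fallback in rest/routes/sorting.py below is deleted. A trimmed sketch of just those two fields and their behavior:

from datetime import datetime
from pydantic import BaseModel, Field

def default_run_identifier():
    return datetime.now().strftime("%Y%m%d-%H%M%S")

class RunKwargs(BaseModel):
    # Trimmed to the two fields this commit changes; the real model carries many more.
    run_identifier: str = Field(default_factory=default_run_identifier, description="Unique identifier for the run.")
    run_description: str = Field(default="", description="Description of the run.")

# default_factory is evaluated per instantiation, so omitting run_identifier yields a fresh timestamp.
kwargs = RunKwargs()
print(kwargs.run_identifier)  # e.g. "20231013-153045"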


9 changes: 5 additions & 4 deletions rest/routes/sorting.py
@@ -17,6 +17,7 @@
PostprocessingKwargs,
CurationKwargs,
VisualizationKwargs,
OutputDataKwargs
)


@@ -32,6 +33,7 @@ def sorting_background_task(
postprocessing_kwargs: PostprocessingKwargs,
curation_kwargs: CurationKwargs,
visualization_kwargs: VisualizationKwargs,
output_data_kwargs: OutputDataKwargs,
):
# Run sorting and update db entry status
db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING)
@@ -49,6 +51,7 @@ def sorting_background_task(
postprocessing_kwargs=postprocessing_kwargs,
curation_kwargs=curation_kwargs,
visualization_kwargs=visualization_kwargs,
output_data_kwargs=output_data_kwargs,
)
elif run_at == "aws":
# TODO: Implement this
@@ -77,12 +80,9 @@ async def route_run_sorting(
postprocessing_kwargs: PostprocessingKwargs,
curation_kwargs: CurationKwargs,
visualization_kwargs: VisualizationKwargs,
output_data_kwargs: OutputDataKwargs,
background_tasks: BackgroundTasks
) -> JSONResponse:
if not run_kwargs.run_identifier:
run_identifier = datetime.now().strftime("%Y%m%d%H%M%S")
else:
run_identifier = run_kwargs.run_identifier
try:
# Create Database entries
db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING)
@@ -121,6 +121,7 @@ async def route_run_sorting(
postprocessing_kwargs=postprocessing_kwargs,
curation_kwargs=curation_kwargs,
visualization_kwargs=visualization_kwargs,
output_data_kwargs=output_data_kwargs,
)

except Exception as e:
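Adding output_data_kwargs to route_run_sorting means request bodies now carry one more top-level group: with several Pydantic parameters in one route, FastAPI nests each group under its parameter name. A toy sketch of that behavior with simplified stand-in models (the real route path and models are not shown in this diff):

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class RunKwargs(BaseModel):
    run_at: str = "local"

class OutputDataKwargs(BaseModel):
    output_destination: str
    output_path: str

@app.post("/run")
async def run_sorting(run_kwargs: RunKwargs, output_data_kwargs: OutputDataKwargs):
    # FastAPI expects a body like:
    # {"run_kwargs": {"run_at": "local"}, "output_data_kwargs": {"output_destination": "local", "output_path": "results/"}}
    return {"run_at": run_kwargs.run_at, "output_path": output_data_kwargs.output_path}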
