Skip to content

Commit

Permalink
improve filter_units execution
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Feb 1, 2024
1 parent 2231173 commit 411c312
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,19 @@ def filter_units(
"""
Filter units based on labels
"""
include_labels = set(include_labels)
exclude_labels = set(exclude_labels)
include_labels = np.unique(include_labels)
exclude_labels = np.unique(exclude_labels)

include_mask = np.zeros(len(labels), dtype=bool)
for ind, unit_labels in enumerate(labels):
if include_labels and not include_labels.intersection(unit_labels):
if isinstance(unit_labels, str):
unit_labels = [unit_labels]
if (
include_labels
and not np.isin(include_labels, unit_labels).any()
):
continue
if exclude_labels.intersection(unit_labels):
if np.isin(exclude_labels, unit_labels).any():
continue
include_mask[ind] = True
return include_mask
Expand Down

0 comments on commit 411c312

Please sign in to comment.