Merge branch 'master' into nightly_orphan_cleanup

samuelbray32 authored Dec 5, 2024
2 parents a9261f9 + f56aba0 commit 41f9964
Showing 32 changed files with 1,151 additions and 336 deletions.
12 changes: 10 additions & 2 deletions CHANGELOG.md
@@ -38,12 +38,15 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add testing for python versions 3.9, 3.10, 3.11, 3.12 #1169
- Initialize tables in pytests #1181
- Download test data without credentials, trigger on approved PRs #1180
- Add coverage of decoding pipeline to pytests #1155
- Allow python \< 3.13 #1169
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
- Edit `merge_fetch` to expect positional before keyword arguments #1181
- Allow part restriction `SpyglassMixinPart.delete` #1192
- Move cleanup of `IntervalList` orphan entries to nightly cleanup #1195

### Pipelines

- Common
@@ -52,8 +55,11 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Improve electrodes import efficiency #1125
- Fix logger method call in `common_task` #1132
- Export fixes #1164
- Allow `get_abs_path` to add selection entry.
- Log restrictions and joins.
- Allow `get_abs_path` to add selection entry. #1164
- Log restrictions and joins. #1164
- Check if querying table inherits mixin in `fetch_nwb`. #1192
- Ensure externals entries before adding to export. #1192
- Error specificity in `LabMemberInfo` #1192

- Decoding

@@ -74,13 +80,15 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
`open-cv` #1168
- `VideoMaker` class to process frames in multithreaded batches #1168, #1174
- `TrodesPosVideo` updates for `matplotlib` processor #1174
- User prompt if ambiguous insert in `DLCModelSource` #1192

- Spike Sorting

- Fix bug in `get_group_by_shank` #1096
- Fix bug in `_compute_metric` #1099
- Fix bug in `insert_curation` returned key #1114
- Fix handling of waveform extraction sparse parameter #1132
- Limit Artifact detection intervals to valid times #1196

## [0.5.3] (August 27, 2024)

11 changes: 9 additions & 2 deletions pyproject.toml
@@ -131,7 +131,7 @@ addopts = [
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--quiet-spy", # don't show logging from spyglass
# "--no-dlc", # don't run DLC tests
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
@@ -148,6 +148,12 @@ env = [
"TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
"TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]
filterwarnings = [
"ignore::ResourceWarning:.*",
"ignore::DeprecationWarning:.*",
"ignore::UserWarning:.*",
"ignore::MissingRequiredBuildWarning:.*",
]

[tool.coverage.run]
source = ["*/src/spyglass/*"]
Expand All @@ -157,7 +163,8 @@ omit = [ # which submodules have no tests
"*/cli/*",
# "*/common/*",
"*/data_import/*",
"*/decoding/*",
"*/decoding/v0/*",
# "*/decoding/*",
"*/figurl_views/*",
# "*/lfp/*",
# "*/linearization/*",
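The new `filterwarnings` entries silence known-noisy warning categories during test runs. A rough Python equivalent, as a sketch (`MissingRequiredBuildWarning` is not a builtin; it comes from the NWB/HDMF stack, so it is only noted in a comment):

```python
import warnings

# Approximate programmatic equivalent of the new filterwarnings entries;
# pytest normally applies these itself from pyproject.toml at session start.
warnings.filterwarnings("ignore", category=ResourceWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# MissingRequiredBuildWarning would be imported from the NWB/HDMF stack
# before being registered the same way.
```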
6 changes: 4 additions & 2 deletions src/spyglass/common/common_lab.py
@@ -133,9 +133,11 @@ def get_djuser_name(cls, dj_user) -> str:
)

if len(query) != 1:
remedy = f"delete {len(query)-1}" if len(query) > 1 else "add one"
raise ValueError(
f"Could not find name for datajoint user {dj_user}"
+ f" in common.LabMember.LabMemberInfo: {query}"
f"Could not find exactly 1 datajoint user {dj_user}"
+ " in common.LabMember.LabMemberInfo. "
+ f"Please {remedy}: {query}"
)

return query[0]
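A minimal sketch of how the sharper error message reads for the two failure modes, using stand-in query results rather than real `LabMemberInfo` rows:

```python
# Stand-in query results, not real LabMemberInfo rows.
for query in ([], ["user_a", "user_b"]):
    remedy = f"delete {len(query) - 1}" if len(query) > 1 else "add one"
    print(
        "Could not find exactly 1 datajoint user"
        f" in common.LabMember.LabMemberInfo. Please {remedy}: {query}"
    )
# -> "... Please add one: []"
# -> "... Please delete 1: ['user_a', 'user_b']"
```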
40 changes: 21 additions & 19 deletions src/spyglass/common/common_usage.py
@@ -9,12 +9,10 @@
from typing import List, Union

import datajoint as dj
from datajoint import FreeTable
from datajoint import config as dj_config
from pynwb import NWBHDF5IO

from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
from spyglass.settings import export_dir, test_mode
from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
from spyglass.utils.dj_graph import RestrGraph
from spyglass.utils.dj_helper_fn import (
@@ -174,7 +172,6 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
Return as a list of dicts: [{'file_path': x}]. Default True.
If False, returns a list of strings without key.
"""
file_table = self * self.File & key
unique_fp = {
*[
AnalysisNwbfile().get_abs_path(p)
@@ -210,21 +207,26 @@ def _add_externals_to_restr_graph(
restr_graph : RestrGraph
The updated RestrGraph
"""
raw_tbl = self._externals["raw"]
raw_name = raw_tbl.full_table_name
raw_restr = (
"filepath in ('" + "','".join(self._list_raw_files(key)) + "')"
)
restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)

analysis_tbl = self._externals["analysis"]
analysis_name = analysis_tbl.full_table_name
analysis_restr = ( # filepaths have analysis subdir. regexp substrings
"filepath REGEXP '" + "|".join(self._list_analysis_files(key)) + "'"
) # regexp is slow, but we're only doing this once, and future-proof
restr_graph.graph.add_node(
analysis_name, ft=analysis_tbl, restr=analysis_restr
)

if raw_files := self._list_raw_files(key):
raw_tbl = self._externals["raw"]
raw_name = raw_tbl.full_table_name
raw_restr = "filepath in ('" + "','".join(raw_files) + "')"
restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
restr_graph.visited.add(raw_name)

if analysis_files := self._list_analysis_files(key):
analysis_tbl = self._externals["analysis"]
analysis_name = analysis_tbl.full_table_name
# to avoid issues with analysis subdir, we use REGEXP
# this is slow, but we're only doing this once, and future-proof
analysis_restr = (
"filepath REGEXP '" + "|".join(analysis_files) + "'"
)
restr_graph.graph.add_node(
analysis_name, ft=analysis_tbl, restr=analysis_restr
)
restr_graph.visited.add(analysis_name)

restr_graph.visited.update({raw_name, analysis_name})

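The guarded version only adds an external-store node when its file list is non-empty, and restricts the analysis store with REGEXP because stored filepaths include an analysis subdirectory. A self-contained sketch of the two restriction strings, with illustrative file lists standing in for `_list_raw_files` / `_list_analysis_files` output:

```python
# Illustrative file lists, not real Spyglass data.
raw_files = ["subject1_raw.nwb", "subject2_raw.nwb"]
analysis_files = ["subject1_analysis_0001.nwb"]

if raw_files:  # skip the node entirely when there is nothing to export
    raw_restr = "filepath in ('" + "','".join(raw_files) + "')"
    # -> "filepath in ('subject1_raw.nwb','subject2_raw.nwb')"

if analysis_files:
    # REGEXP matches substrings, so a stored path like
    # 'analysis/subject1/subject1_analysis_0001.nwb' still matches.
    analysis_restr = "filepath REGEXP '" + "|".join(analysis_files) + "'"
    # -> "filepath REGEXP 'subject1_analysis_0001.nwb'"
```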
48 changes: 16 additions & 32 deletions src/spyglass/decoding/decoding_merge.py
@@ -85,53 +85,41 @@ def cleanup(self, dry_run=False):
@classmethod
def fetch_results(cls, key):
"""Fetch the decoding results for a given key."""
return cls().merge_get_parent_class(key).fetch_results()
return cls().merge_restrict_class(key).fetch_results()

@classmethod
def fetch_model(cls, key):
"""Fetch the decoding model for a given key."""
return cls().merge_get_parent_class(key).fetch_model()
return cls().merge_restrict_class(key).fetch_model()

@classmethod
def fetch_environments(cls, key):
"""Fetch the decoding environments for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_environments(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_environments(decoding_selection_key)

@classmethod
def fetch_position_info(cls, key):
"""Fetch the decoding position info for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_position_info(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_position_info(decoding_selection_key)

@classmethod
def fetch_linear_position_info(cls, key):
"""Fetch the decoding linear position info for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_linear_position_info(decoding_selection_key)
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_linear_position_info(decoding_selection_key)

@classmethod
def fetch_spike_data(cls, key, filter_by_interval=True):
"""Fetch the decoding spike data for a given key."""
decoding_selection_key = cls.merge_get_parent(key).fetch1("KEY")
return (
cls()
.merge_get_parent_class(key)
.fetch_linear_position_info(
decoding_selection_key, filter_by_interval=filter_by_interval
)
restr_parent = cls().merge_restrict_class(key)
decoding_selection_key = restr_parent.fetch1("KEY")
return restr_parent.fetch_spike_data(
decoding_selection_key, filter_by_interval=filter_by_interval
)

@classmethod
@@ -167,11 +155,7 @@ def create_decoding_view(cls, key, head_direction_name="head_orientation"):
head_dir=position_info[head_direction_name],
)
else:
(
position_info,
position_variable_names,
) = cls.fetch_linear_position_info(key)
return create_1D_decode_view(
posterior=posterior,
linear_position=position_info["linear_position"],
linear_position=cls.fetch_linear_position_info(key),
)
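All of the accessors now share one shape: restrict the parent class once with `merge_restrict_class`, fetch the selection key from it, and call through. A hypothetical usage sketch, assuming this module's merge table `DecodingOutput` and an illustrative merge_id:

```python
from spyglass.decoding.decoding_merge import DecodingOutput

key = {"merge_id": "0f3bb01e-0ef6-4de6-b7b9-57eed22495e2"}  # illustrative

results = DecodingOutput.fetch_results(key)
model = DecodingOutput.fetch_model(key)
# fetch_position_info returns a tuple (see the clusterless.py note below)
position_info, position_variable_names = DecodingOutput.fetch_position_info(key)
```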
10 changes: 6 additions & 4 deletions src/spyglass/decoding/v1/clusterless.py
@@ -59,10 +59,11 @@ def create_group(
"waveform_features_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error(  # No error on duplicate helps with pytests
f"Group {nwb_file_name}: {group_name} already exists. "
+ "Please delete the group before creating a new one",
)
return
self.insert1(
group_key,
skip_duplicates=True,
Expand Down Expand Up @@ -586,7 +587,8 @@ def get_ahead_behind_distance(self, track_graph=None, time_slice=None):
classifier.environments[0].track_graph, *traj_data
)
else:
position_info = self.fetch_position_info(self.fetch1("KEY")).loc[
# `fetch_position_info` returns a tuple
position_info = self.fetch_position_info(self.fetch1("KEY"))[0].loc[
time_slice
]
map_position = analysis.maximum_a_posteriori_estimate(posterior)
18 changes: 10 additions & 8 deletions src/spyglass/decoding/v1/core.py
@@ -15,7 +15,7 @@
restore_classes,
)
from spyglass.position.position_merge import PositionOutput # noqa: F401
from spyglass.utils import SpyglassMixin, SpyglassMixinPart
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger

schema = dj.schema("decoding_core_v1")

@@ -56,14 +56,15 @@ class DecodingParameters(SpyglassMixin, dj.Lookup):
@classmethod
def insert_default(cls):
"""Insert default decoding parameters"""
cls.insert(cls.contents, skip_duplicates=True)
cls.super().insert(cls.contents, skip_duplicates=True)

def insert(self, rows, *args, **kwargs):
"""Override insert to convert classes to dict before inserting"""
for row in rows:
row["decoding_params"] = convert_classes_to_dict(
vars(row["decoding_params"])
)
params = row["decoding_params"]
if hasattr(params, "__dict__"):
params = vars(params)
row["decoding_params"] = convert_classes_to_dict(params)
super().insert(rows, *args, **kwargs)

def fetch(self, *args, **kwargs):
@@ -124,10 +125,11 @@ def create_group(
"position_group_name": group_name,
}
if self & group_key:
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
logger.error( # Easier for pytests to not raise error on duplicate
f"Group {nwb_file_name}: {group_name} already exists. "
+ "Please delete the group before creating a new one"
)
return
self.insert1(
{
**group_key,
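The `hasattr` guard lets the overridden `insert` accept both parameter-class instances and rows that already carry plain dicts. A minimal sketch with a stand-in class (real rows hold decoding parameter objects, and Spyglass then passes the dict through `convert_classes_to_dict`):

```python
class FakeParams:  # stand-in for a decoding parameter class instance
    def __init__(self):
        self.environments = ["env_a"]

rows = [
    {"decoding_param_name": "demo_obj", "decoding_params": FakeParams()},
    {"decoding_param_name": "demo_dict", "decoding_params": {"environments": []}},
]

for row in rows:
    params = row["decoding_params"]
    if hasattr(params, "__dict__"):  # class instance -> plain attribute dict
        params = vars(params)
    row["decoding_params"] = params  # convert_classes_to_dict(params) in Spyglass
# both rows now hold plain dicts, whichever form they started in
```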
18 changes: 13 additions & 5 deletions src/spyglass/position/v1/position_dlc_model.py
@@ -98,16 +98,24 @@ def insert_entry(
dj.conn(), full_table_name=part_table.parents()[-1]
) & {"project_name": project_name}

if cls._test_mode: # temporary fix for #1105
project_path = table_query.fetch(limit=1)[0]
else:
project_path = table_query.fetch1("project_path")
n_found = len(table_query)
if n_found != 1:
logger.warning(
f"Found {n_found} entries for project "
+ f"{project_name}:\n{table_query}"
)

choice = "y"
if n_found > 1 and not cls._test_mode:
choice = dj.utils.user_choice("Use first entry?")[0]
if n_found == 0 or choice != "y":
return

part_table.insert1(
{
"dlc_model_name": dlc_model_name,
"project_name": project_name,
"project_path": project_path,
"project_path": table_query.fetch("project_path", limit=1)[0],
**key,
},
**kwargs,
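Extracted, the new ambiguity guard looks like this (a sketch; `dj.utils.user_choice` is DataJoint's interactive prompt, and `n_found` / `test_mode` stand in for the query length and `cls._test_mode`):

```python
import datajoint as dj

n_found = 2        # stand-in for len(table_query)
test_mode = False  # stand-in for cls._test_mode

choice = "y"
if n_found > 1 and not test_mode:
    # user_choice returns the full answer; [0] keeps the leading 'y'/'n'
    choice = dj.utils.user_choice("Use first entry?")[0]
if n_found == 0 or choice != "y":
    raise SystemExit("No usable project entry; aborting insert")
```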
2 changes: 0 additions & 2 deletions src/spyglass/position/v1/position_dlc_project.py
@@ -57,8 +57,6 @@ class DLCProject(SpyglassMixin, dj.Manual):
With ability to edit config, extract frames, label frames
"""

# Add more parameters as secondary keys...
# TODO: collapse params into blob dict
definition = """
project_name : varchar(100) # name of DLC project
---
2 changes: 1 addition & 1 deletion src/spyglass/spikesorting/analysis/v1/unit_annotation.py
@@ -71,7 +71,7 @@ def add_annotation(self, key, **kwargs):
).fetch_nwb()[0]
nwb_field_name = _get_spike_obj_name(nwb_file)
spikes = nwb_file[nwb_field_name]["spike_times"].to_list()
if key["unit_id"] > len(spikes):
if key["unit_id"] > len(spikes) and not self._test_mode:
raise ValueError(
f"unit_id {key['unit_id']} is greater than ",
f"the number of units in {key['spikesorting_merge_id']}",
5 changes: 4 additions & 1 deletion src/spyglass/spikesorting/v0/spikesorting_artifact.py
@@ -276,7 +276,10 @@ def _get_artifact_times(
for interval_idx, interval in enumerate(artifact_intervals):
artifact_intervals_s[interval_idx] = [
valid_timestamps[interval[0]] - half_removal_window_s,
valid_timestamps[interval[1]] + half_removal_window_s,
np.minimum(
valid_timestamps[interval[1]] + half_removal_window_s,
valid_timestamps[-1],
),
]
# make the artifact intervals disjoint
if len(artifact_intervals_s) > 1:
5 changes: 4 additions & 1 deletion src/spyglass/spikesorting/v1/artifact.py
@@ -308,7 +308,10 @@ def _get_artifact_times(
),
np.searchsorted(
valid_timestamps,
valid_timestamps[interval[1]] + half_removal_window_s,
np.minimum(
valid_timestamps[interval[1]] + half_removal_window_s,
valid_timestamps[-1],
),
),
]
artifact_intervals_s[interval_idx] = [
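Both the v0 and v1 fixes clamp the artifact-removal window so it cannot run past the recording ("Limit Artifact detection intervals to valid times #1196" in the changelog). The clamp in isolation, with illustrative timestamps:

```python
import numpy as np

valid_timestamps = np.arange(0.0, 10.0, 0.5)  # illustrative, 0.0 .. 9.5 s
interval = (17, 19)          # index pair into valid_timestamps
half_removal_window_s = 1.0

# Without the clamp the interval would end at 10.5 s, past the recording;
# np.minimum caps it at the last valid timestamp.
end_s = np.minimum(
    valid_timestamps[interval[1]] + half_removal_window_s,
    valid_timestamps[-1],
)
assert end_s == 9.5
```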