Skip to content

Commit

Permalink
Misc fixes (#1192)
Browse files Browse the repository at this point in the history
* #1175

* #1185

* #1183

* Fix circular import

* #1163

* #1105

* Fix failing tests, close download subprocesses

* WIP: fix decode changes spikesort tests

* Fix fickle test

* Revert typo
  • Loading branch information
CBroz1 authored Dec 5, 2024
1 parent 6faed4c commit f56aba0
Show file tree
Hide file tree
Showing 15 changed files with 162 additions and 84 deletions.
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
- Edit `merge_fetch` to expect positional before keyword arguments #1181
- Allow part restriction `SpyglassMixinPart.delete` #1192

### Pipelines

Expand All @@ -52,8 +53,11 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Improve electrodes import efficiency #1125
- Fix logger method call in `common_task` #1132
- Export fixes #1164
- Allow `get_abs_path` to add selection entry.
- Log restrictions and joins.
- Allow `get_abs_path` to add selection entry. #1164
- Log restrictions and joins. #1164
- Check if querying table inherits mixin in `fetch_nwb`. #1192
- Ensure externals entries before adding to export. #1192
- Error specificity in `LabMemberInfo` #1192

- Decoding

Expand All @@ -74,6 +78,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
`open-cv` #1168
- `VideoMaker` class to process frames in multithreaded batches #1168, #1174
- `TrodesPosVideo` updates for `matplotlib` processor #1174
- User prompt if ambiguous insert in `DLCModelSource` #1192

- Spike Sorting

Expand Down
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ addopts = [
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
# "--quiet-spy", # don't show logging from spyglass
"--quiet-spy", # don't show logging from spyglass
# "--no-dlc", # don't run DLC tests
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
Expand All @@ -148,6 +148,12 @@ env = [
"TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
"TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]
filterwarnings = [
"ignore::ResourceWarning:.*",
"ignore::DeprecationWarning:.*",
"ignore::UserWarning:.*",
"ignore::MissingRequiredBuildWarning:.*",
]

[tool.coverage.run]
source = ["*/src/spyglass/*"]
Expand Down
6 changes: 4 additions & 2 deletions src/spyglass/common/common_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,11 @@ def get_djuser_name(cls, dj_user) -> str:
)

if len(query) != 1:
remedy = f"delete {len(query)-1}" if len(query) > 1 else "add one"
raise ValueError(
f"Could not find name for datajoint user {dj_user}"
+ f" in common.LabMember.LabMemberInfo: {query}"
f"Could not find exactly 1 datajoint user {dj_user}"
+ " in common.LabMember.LabMemberInfo. "
+ f"Please {remedy}: {query}"
)

return query[0]
Expand Down
40 changes: 21 additions & 19 deletions src/spyglass/common/common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
from typing import List, Union

import datajoint as dj
from datajoint import FreeTable
from datajoint import config as dj_config
from pynwb import NWBHDF5IO

from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
from spyglass.settings import export_dir, test_mode
from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
from spyglass.utils.dj_graph import RestrGraph
from spyglass.utils.dj_helper_fn import (
Expand Down Expand Up @@ -174,7 +172,6 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
Return as a list of dicts: [{'file_path': x}]. Default True.
If False, returns a list of strings without key.
"""
file_table = self * self.File & key
unique_fp = {
*[
AnalysisNwbfile().get_abs_path(p)
Expand Down Expand Up @@ -210,21 +207,26 @@ def _add_externals_to_restr_graph(
restr_graph : RestrGraph
The updated RestrGraph
"""
raw_tbl = self._externals["raw"]
raw_name = raw_tbl.full_table_name
raw_restr = (
"filepath in ('" + "','".join(self._list_raw_files(key)) + "')"
)
restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)

analysis_tbl = self._externals["analysis"]
analysis_name = analysis_tbl.full_table_name
analysis_restr = ( # filepaths have analysis subdir. regexp substrings
"filepath REGEXP '" + "|".join(self._list_analysis_files(key)) + "'"
) # regexp is slow, but we're only doing this once, and future-proof
restr_graph.graph.add_node(
analysis_name, ft=analysis_tbl, restr=analysis_restr
)

if raw_files := self._list_raw_files(key):
raw_tbl = self._externals["raw"]
raw_name = raw_tbl.full_table_name
raw_restr = "filepath in ('" + "','".join(raw_files) + "')"
restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
restr_graph.visited.add(raw_name)

if analysis_files := self._list_analysis_files(key):
analysis_tbl = self._externals["analysis"]
analysis_name = analysis_tbl.full_table_name
# to avoid issues with analysis subdir, we use REGEXP
# this is slow, but we're only doing this once, and future-proof
analysis_restr = (
"filepath REGEXP '" + "|".join(analysis_files) + "'"
)
restr_graph.graph.add_node(
analysis_name, ft=analysis_tbl, restr=analysis_restr
)
restr_graph.visited.add(analysis_name)

restr_graph.visited.update({raw_name, analysis_name})

Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def create_group(
}
if self & group_key:
logger.error( # Easier for pytests to not raise error on duplicate
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one"
f"Group {nwb_file_name}: {group_name} already exists. "
+ "Please delete the group before creating a new one"
)
return
self.insert1(
Expand Down
18 changes: 13 additions & 5 deletions src/spyglass/position/v1/position_dlc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,16 +98,24 @@ def insert_entry(
dj.conn(), full_table_name=part_table.parents()[-1]
) & {"project_name": project_name}

if cls._test_mode: # temporary fix for #1105
project_path = table_query.fetch(limit=1)[0]
else:
project_path = table_query.fetch1("project_path")
n_found = len(table_query)
if n_found != 1:
logger.warning(
f"Found {len(table_query)} entries found for project "
+ f"{project_name}:\n{table_query}"
)

choice = "y"
if n_found > 1 and not cls._test_mode:
choice = dj.utils.user_choice("Use first entry?")[0]
if n_found == 0 or choice != "y":
return

part_table.insert1(
{
"dlc_model_name": dlc_model_name,
"project_name": project_name,
"project_path": project_path,
"project_path": table_query.fetch("project_path", limit=1)[0],
**key,
},
**kwargs,
Expand Down
2 changes: 0 additions & 2 deletions src/spyglass/position/v1/position_dlc_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ class DLCProject(SpyglassMixin, dj.Manual):
With ability to edit config, extract frames, label frames
"""

# Add more parameters as secondary keys...
# TODO: collapse params into blob dict
definition = """
project_name : varchar(100) # name of DLC project
---
Expand Down
13 changes: 6 additions & 7 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import spikeinterface.extractors as se

from spyglass.common import BrainRegion, Electrode
from spyglass.common.common_ephys import Raw
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.spikesorting.v1.recording import (
SortGroup,
SpikeSortingRecording,
SpikeSortingRecordingSelection,
)
from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils import SpyglassMixin, logger

schema = dj.schema("spikesorting_v1_curation")

Expand Down Expand Up @@ -84,13 +83,13 @@ def insert_curation(

sort_query = cls & {"sorting_id": sorting_id}
parent_curation_id = max(parent_curation_id, -1)
if parent_curation_id == -1:

parent_query = sort_query & {"curation_id": parent_curation_id}
if parent_curation_id == -1 and len(parent_query):
# check to see if this sorting with a parent of -1
# has already been inserted and if so, warn the user
query = sort_query & {"parent_curation_id": -1}
if query:
Warning("Sorting has already been inserted.")
return query.fetch("KEY")
logger.warning("Sorting has already been inserted.")
return parent_query.fetch("KEY")

# generate curation ID
existing_curation_ids = sort_query.fetch("curation_id")
Expand Down
14 changes: 11 additions & 3 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
Function to get the absolute path to the NWB file.
"""
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
from spyglass.utils.dj_mixin import SpyglassMixin

kwargs["as_dict"] = True # force return as dictionary
attrs = attrs or query_expression.heading.names # if none, all
Expand All @@ -234,11 +235,18 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
}
file_name_str, file_path_fn = tbl_map[which]

# logging arg only if instanced table inherits Mixin
inst = ( # instancing may not be necessary
query_expression()
if isinstance(query_expression, type)
and issubclass(query_expression, dj.Table)
else query_expression
)
arg = dict(log_export=False) if isinstance(inst, SpyglassMixin) else dict()

# TODO: check that the query_expression restricts tbl - CBroz
nwb_files = (
query_expression.join(
tbl.proj(nwb2load_filepath=attr_name), log_export=False
)
query_expression.join(tbl.proj(nwb2load_filepath=attr_name), **arg)
).fetch(file_name_str)

# Disabled #1024
Expand Down
11 changes: 8 additions & 3 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ def _nwb_table_tuple(self) -> tuple:
Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb.
Implemented as a cached_property to avoid circular imports."""
from spyglass.common.common_nwbfile import (
from spyglass.common.common_nwbfile import ( # noqa F401
AnalysisNwbfile,
Nwbfile,
) # noqa F401
)

table_dict = {
AnalysisNwbfile: "analysis_file_abs_path",
Expand Down Expand Up @@ -857,4 +857,9 @@ def delete(self, *args, **kwargs):
"""Delete master and part entries."""
restriction = self.restriction or True # for (tbl & restr).delete()

(self.master & restriction).delete(*args, **kwargs)
try: # try restriction on master
restricted = self.master & restriction
except DataJointError: # if error, assume restr of self
restricted = self & restriction

restricted.delete(*args, **kwargs)
45 changes: 33 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,13 @@
import pynwb
import pytest
from datajoint.logging import logger as dj_logger
from hdmf.build.warnings import MissingRequiredBuildWarning
from numba import NumbaWarning
from pandas.errors import PerformanceWarning

from .container import DockerMySQLManager
from .data_downloader import DataDownloader

warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
warnings.filterwarnings("ignore", module="tensorflow")
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas")
warnings.filterwarnings("ignore", category=NumbaWarning, module="numba")
warnings.filterwarnings("ignore", category=ResourceWarning, module="datajoint")

# ------------------------------- TESTS CONFIG -------------------------------

# globals in pytest_configure:
Expand Down Expand Up @@ -114,13 +108,29 @@ def pytest_configure(config):
download_dlc=not NO_DLC,
)

warnings.filterwarnings("ignore", module="tensorflow")
warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
warnings.filterwarnings(
"ignore", category=MissingRequiredBuildWarning, module="hdmf"
)
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
warnings.filterwarnings(
"ignore", category=PerformanceWarning, module="pandas"
)
warnings.filterwarnings("ignore", category=NumbaWarning, module="numba")
warnings.simplefilter("ignore", category=ResourceWarning)
warnings.simplefilter("ignore", category=DeprecationWarning)


def pytest_unconfigure(config):
from spyglass.utils.nwb_helper_fn import close_nwb_files

close_nwb_files()
if TEARDOWN:
SERVER.stop()
analysis_dir = BASE_DIR / "analysis"
for file in analysis_dir.glob("*.nwb"):
file.unlink()


# ---------------------------- FIXTURES, TEST ENV ----------------------------
Expand Down Expand Up @@ -1357,6 +1367,8 @@ def sorter_dict():

@pytest.fixture(scope="session")
def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
pre = spike_v1.SpikeSorting().fetch("KEY", as_dict=True)

key = {
**mini_dict,
**sorter_dict,
Expand All @@ -1367,7 +1379,9 @@ def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
spike_v1.SpikeSortingSelection.insert_selection(key)
spike_v1.SpikeSorting.populate()

yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0]
yield (spike_v1.SpikeSorting() - pre).fetch(
"KEY", as_dict=True, order_by="time_of_sort desc"
)[0]


@pytest.fixture(scope="session")
Expand All @@ -1379,9 +1393,16 @@ def sorting_objs(spike_v1, pop_sort):

@pytest.fixture(scope="session")
def pop_curation(spike_v1, pop_sort):

parent_curation_id = -1
has_sort = spike_v1.CurationV1 & {"sorting_id": pop_sort["sorting_id"]}
if has_sort:
parent_curation_id = has_sort.fetch1("curation_id")

spike_v1.CurationV1.insert_curation(
sorting_id=pop_sort["sorting_id"],
description="testing sort",
parent_curation_id=parent_curation_id,
)

yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch(
Expand Down Expand Up @@ -1418,20 +1439,20 @@ def metric_objs(spike_v1, pop_metric):
@pytest.fixture(scope="session")
def pop_curation_metric(spike_v1, pop_metric, metric_objs):
labels, merge_groups, metrics = metric_objs
parent_dict = {"parent_curation_id": 0}
desc_dict = dict(description="after metric curation")
spike_v1.CurationV1.insert_curation(
sorting_id=(
spike_v1.MetricCurationSelection
& {"metric_curation_id": pop_metric["metric_curation_id"]}
).fetch1("sorting_id"),
**parent_dict,
parent_curation_id=0,
labels=labels,
merge_groups=merge_groups,
metrics=metrics,
description="after metric curation",
**desc_dict,
)

yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0]
yield (spike_v1.CurationV1 & desc_dict).fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")
Expand Down
Loading

0 comments on commit f56aba0

Please sign in to comment.