Misc fixes #1192

Merged · 11 commits · Dec 5, 2024
CHANGELOG.md (7 additions, 2 deletions)

@@ -43,6 +43,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Remove numpy version restriction #1169
- Merge table delete removes orphaned master entries #1164
- Edit `merge_fetch` to expect positional before keyword arguments #1181
+ - Allow part restriction in `SpyglassMixinPart.delete` #1192

### Pipelines

@@ -52,8 +53,11 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Improve electrodes import efficiency #1125
- Fix logger method call in `common_task` #1132
- Export fixes #1164
- - Allow `get_abs_path` to add selection entry.
- - Log restrictions and joins.
+ - Allow `get_abs_path` to add selection entry. #1164
+ - Log restrictions and joins. #1164
+ - Check if querying table inherits mixin in `fetch_nwb`. #1192
+ - Ensure externals entries exist before adding to export. #1192
+ - Improve error specificity in `LabMemberInfo` #1192

- Decoding

@@ -74,6 +78,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
`open-cv` #1168
- `VideoMaker` class to process frames in multithreaded batches #1168, #1174
- `TrodesPosVideo` updates for `matplotlib` processor #1174
+ - User prompt if ambiguous insert in `DLCModelSource` #1192

- Spike Sorting

pyproject.toml (7 additions, 1 deletion)

@@ -131,7 +131,7 @@ addopts = [
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
- # "--quiet-spy", # don't show logging from spyglass
+ "--quiet-spy", # don't show logging from spyglass
# "--no-dlc", # don't run DLC tests
"--show-capture=no",
"--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger
Expand All @@ -148,6 +148,12 @@ env = [
"TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs
"TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]
+ filterwarnings = [
+ "ignore::ResourceWarning:.*",
+ "ignore::DeprecationWarning:.*",
+ "ignore::UserWarning:.*",
+ "ignore::MissingRequiredBuildWarning:.*",
+ ]

[tool.coverage.run]
source = ["*/src/spyglass/*"]
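For reference, the `filterwarnings` entries added above use pytest's `action::Category:module_regex` syntax, the same grammar as `-W`/`PYTHONWARNINGS`. A rough runtime equivalent, as a sketch only (note that a non-builtin category such as `MissingRequiredBuildWarning` typically needs its fully qualified dotted path in the ini form; the import below mirrors the one added in `tests/conftest.py` later in this diff):

```python
import warnings

from hdmf.build.warnings import MissingRequiredBuildWarning

# Runtime equivalents of the four ini entries above.
warnings.filterwarnings("ignore", category=ResourceWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=MissingRequiredBuildWarning)
```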
src/spyglass/common/common_lab.py (4 additions, 2 deletions)

@@ -133,9 +133,11 @@ def get_djuser_name(cls, dj_user) -> str:
)

if len(query) != 1:
+ remedy = f"delete {len(query)-1}" if len(query) > 1 else "add one"
raise ValueError(
- f"Could not find name for datajoint user {dj_user}"
- + f" in common.LabMember.LabMemberInfo: {query}"
+ f"Could not find exactly 1 datajoint user {dj_user}"
+ + " in common.LabMember.LabMemberInfo. "
+ + f"Please {remedy}: {query}"
)

return query[0]
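The reworked error now reports the match count and suggests a concrete fix. A minimal sketch of the branch's remedy logic, with a hypothetical helper name:

```python
def remedy(n_matches: int) -> str:
    """Suggest a fix: extra rows should be deleted; zero rows means add one."""
    return f"delete {n_matches - 1}" if n_matches > 1 else "add one"

assert remedy(3) == "delete 2"
assert remedy(0) == "add one"
```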
src/spyglass/common/common_usage.py (21 additions, 19 deletions)

@@ -9,12 +9,10 @@
from typing import List, Union

import datajoint as dj
- from datajoint import FreeTable
- from datajoint import config as dj_config
from pynwb import NWBHDF5IO

from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
- from spyglass.settings import export_dir, test_mode
+ from spyglass.settings import test_mode
from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger
from spyglass.utils.dj_graph import RestrGraph
from spyglass.utils.dj_helper_fn import (
@@ -174,7 +172,6 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
Return as a list of dicts: [{'file_path': x}]. Default True.
If False, returns a list of strings without key.
"""
- file_table = self * self.File & key
unique_fp = {
*[
AnalysisNwbfile().get_abs_path(p)
@@ -210,21 +207,26 @@ def _add_externals_to_restr_graph(
restr_graph : RestrGraph
The updated RestrGraph
"""
- raw_tbl = self._externals["raw"]
- raw_name = raw_tbl.full_table_name
- raw_restr = (
- "filepath in ('" + "','".join(self._list_raw_files(key)) + "')"
- )
- restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
-
- analysis_tbl = self._externals["analysis"]
- analysis_name = analysis_tbl.full_table_name
- analysis_restr = ( # filepaths have analysis subdir. regexp substrings
- "filepath REGEXP '" + "|".join(self._list_analysis_files(key)) + "'"
- ) # regexp is slow, but we're only doing this once, and future-proof
- restr_graph.graph.add_node(
- analysis_name, ft=analysis_tbl, restr=analysis_restr
- )

+ if raw_files := self._list_raw_files(key):
+ raw_tbl = self._externals["raw"]
+ raw_name = raw_tbl.full_table_name
+ raw_restr = "filepath in ('" + "','".join(raw_files) + "')"
+ restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
+ restr_graph.visited.add(raw_name)
+
+ if analysis_files := self._list_analysis_files(key):
+ analysis_tbl = self._externals["analysis"]
+ analysis_name = analysis_tbl.full_table_name
+ # to avoid issues with analysis subdir, we use REGEXP
+ # this is slow, but we're only doing this once, and future-proof
+ analysis_restr = (
+ "filepath REGEXP '" + "|".join(analysis_files) + "'"
+ )
+ restr_graph.graph.add_node(
+ analysis_name, ft=analysis_tbl, restr=analysis_restr
+ )
+ restr_graph.visited.add(analysis_name)

- restr_graph.visited.update({raw_name, analysis_name})

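Both branches above share one guard: when no files match, the external-store node is skipped entirely rather than added with an empty restriction (the old code would have built a clause like `filepath in ('')` for an empty file list). The pattern in isolation, with hypothetical names:

```python
def file_restriction(files: list[str], regexp: bool = False) -> str | None:
    """Build the restriction used above; None means skip graph.add_node."""
    if not files:
        return None
    if regexp:  # analysis files sit under a subdirectory, so substring-match
        return "filepath REGEXP '" + "|".join(files) + "'"
    return "filepath in ('" + "','".join(files) + "')"

assert file_restriction([]) is None
assert file_restriction(["a.nwb", "b.nwb"]) == "filepath in ('a.nwb','b.nwb')"
```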
src/spyglass/decoding/v1/core.py (2 additions, 2 deletions)

@@ -126,8 +126,8 @@ def create_group(
}
if self & group_key:
logger.error( # Easier for pytests to not raise error on duplicate
- f"Group {nwb_file_name}: {group_name} already exists"
- + "please delete the group before creating a new one"
+ f"Group {nwb_file_name}: {group_name} already exists. "
+ + "Please delete the group before creating a new one"
)
return
self.insert1(
src/spyglass/position/v1/position_dlc_model.py (13 additions, 5 deletions)

@@ -98,16 +98,24 @@ def insert_entry(
dj.conn(), full_table_name=part_table.parents()[-1]
) & {"project_name": project_name}

- if cls._test_mode: # temporary fix for #1105
- project_path = table_query.fetch(limit=1)[0]
- else:
- project_path = table_query.fetch1("project_path")
+ n_found = len(table_query)
+ if n_found != 1:
+ logger.warning(
+ f"Found {len(table_query)} entries for project "
+ + f"{project_name}:\n{table_query}"
+ )
+
+ choice = "y"
+ if n_found > 1 and not cls._test_mode:
+ choice = dj.utils.user_choice("Use first entry?")[0]
+ if n_found == 0 or choice != "y":
+ return

part_table.insert1(
{
"dlc_model_name": dlc_model_name,
"project_name": project_name,
- "project_path": project_path,
+ "project_path": table_query.fetch("project_path", limit=1)[0],
**key,
},
**kwargs,
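`insert_entry` now warns whenever the project query does not resolve to exactly one row and, outside test mode, asks before using the first match. The decision logic in isolation (hypothetical helper name; `dj.utils.user_choice` is the prompt the diff itself calls):

```python
import datajoint as dj

def proceed_with_first(n_found: int, test_mode: bool) -> bool:
    """True when the insert should continue with the first matching project."""
    if n_found == 0:
        return False  # no project to attach the model to
    if n_found > 1 and not test_mode:
        # interactive prompt; an answer starting with "y" proceeds
        return dj.utils.user_choice("Use first entry?")[0] == "y"
    return True  # exactly one match, or prompting suppressed under tests
```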
src/spyglass/position/v1/position_dlc_project.py (2 deletions)

@@ -57,8 +57,6 @@ class DLCProject(SpyglassMixin, dj.Manual):
With ability to edit config, extract frames, label frames
"""

- # Add more parameters as secondary keys...
- # TODO: collapse params into blob dict
definition = """
project_name : varchar(100) # name of DLC project
---
src/spyglass/spikesorting/v1/curation.py (6 additions, 7 deletions)

@@ -9,15 +9,14 @@
import spikeinterface.extractors as se

from spyglass.common import BrainRegion, Electrode
- from spyglass.common.common_ephys import Raw
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.spikesorting.v1.recording import (
SortGroup,
SpikeSortingRecording,
SpikeSortingRecordingSelection,
)
from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection
- from spyglass.utils.dj_mixin import SpyglassMixin
+ from spyglass.utils import SpyglassMixin, logger

schema = dj.schema("spikesorting_v1_curation")

@@ -84,13 +83,13 @@ def insert_curation(

sort_query = cls & {"sorting_id": sorting_id}
parent_curation_id = max(parent_curation_id, -1)
- if parent_curation_id == -1:
+
+ parent_query = sort_query & {"curation_id": parent_curation_id}
+ if parent_curation_id == -1 and len(parent_query):
# check to see if this sorting with a parent of -1
# has already been inserted and if so, warn the user
- query = sort_query & {"parent_curation_id": -1}
- if query:
- Warning("Sorting has already been inserted.")
- return query.fetch("KEY")
+ logger.warning("Sorting has already been inserted.")
+ return parent_query.fetch("KEY")

# generate curation ID
existing_curation_ids = sort_query.fetch("curation_id")
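Besides restructuring the guard, this hunk fixes a silent bug: bare `Warning("...")` only constructs an exception object and discards it, so the old early-return path never surfaced a message. For contrast:

```python
import logging
import warnings

logger = logging.getLogger("spyglass")

Warning("never shown")                 # old code: builds an object, emits nothing
warnings.warn("emits a UserWarning")   # what the old line presumably intended
logger.warning("written to the log")   # what the code does after this PR
```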
src/spyglass/utils/dj_helper_fn.py (11 additions, 3 deletions)

@@ -223,6 +223,7 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
Function to get the absolute path to the NWB file.
"""
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
+ from spyglass.utils.dj_mixin import SpyglassMixin

kwargs["as_dict"] = True # force return as dictionary
attrs = attrs or query_expression.heading.names # if none, all
@@ -234,11 +235,18 @@ def get_nwb_table(query_expression, tbl, attr_name, *attrs, **kwargs):
}
file_name_str, file_path_fn = tbl_map[which]

+ # logging arg only if instanced table inherits Mixin
+ inst = ( # instancing may not be necessary
+ query_expression()
+ if isinstance(query_expression, type)
+ and issubclass(query_expression, dj.Table)
+ else query_expression
+ )
+ arg = dict(log_export=False) if isinstance(inst, SpyglassMixin) else dict()

# TODO: check that the query_expression restricts tbl - CBroz
nwb_files = (
- query_expression.join(
- tbl.proj(nwb2load_filepath=attr_name), log_export=False
- )
+ query_expression.join(tbl.proj(nwb2load_filepath=attr_name), **arg)
).fetch(file_name_str)

# Disabled #1024
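`get_nwb_table` can be handed either a table class or an instance, and the `log_export=False` keyword is only valid for `join` on tables that inherit `SpyglassMixin`, so the class is instanced before the `isinstance` check. The normalization in miniature (hypothetical function name):

```python
import datajoint as dj

def as_instance(query_expression):
    """Return an instance so mixin membership is testable with isinstance."""
    if isinstance(query_expression, type) and issubclass(
        query_expression, dj.Table
    ):
        return query_expression()
    return query_expression
```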
src/spyglass/utils/dj_mixin.py (8 additions, 3 deletions)

@@ -145,10 +145,10 @@ def _nwb_table_tuple(self) -> tuple:

Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb.
Implemented as a cached_property to avoid circular imports."""
- from spyglass.common.common_nwbfile import (
+ from spyglass.common.common_nwbfile import ( # noqa F401
AnalysisNwbfile,
Nwbfile,
- ) # noqa F401
+ )

table_dict = {
AnalysisNwbfile: "analysis_file_abs_path",
@@ -857,4 +857,9 @@ def delete(self, *args, **kwargs):
"""Delete master and part entries."""
restriction = self.restriction or True # for (tbl & restr).delete()

- (self.master & restriction).delete(*args, **kwargs)
+ try: # try restriction on master
+ restricted = self.master & restriction
+ except DataJointError: # if error, assume restr of self
+ restricted = self & restriction

+ restricted.delete(*args, **kwargs)
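The try/except lets a restriction that only the part table understands fall back to restricting the part itself, while still preferring the master so the delete cascades to orphaned master entries. A sketch of the fallback, assuming `DataJointError` is imported in this module as the except clause implies:

```python
from datajoint.errors import DataJointError

def restricted_for_delete(part_table, restriction):
    """Prefer the master for cascading deletes; fall back to the part."""
    try:
        return part_table.master & restriction
    except DataJointError:  # restriction not valid on the master
        return part_table & restriction
```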
tests/conftest.py (33 additions, 12 deletions)

@@ -17,19 +17,13 @@
import pynwb
import pytest
from datajoint.logging import logger as dj_logger
+ from hdmf.build.warnings import MissingRequiredBuildWarning
from numba import NumbaWarning
from pandas.errors import PerformanceWarning

from .container import DockerMySQLManager
from .data_downloader import DataDownloader

- warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
- warnings.filterwarnings("ignore", module="tensorflow")
- warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
- warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas")
- warnings.filterwarnings("ignore", category=NumbaWarning, module="numba")
- warnings.filterwarnings("ignore", category=ResourceWarning, module="datajoint")

# ------------------------------- TESTS CONFIG -------------------------------

# globals in pytest_configure:
@@ -114,13 +108,29 @@ def pytest_configure(config):
download_dlc=not NO_DLC,
)

+ warnings.filterwarnings("ignore", module="tensorflow")
+ warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
+ warnings.filterwarnings(
+ "ignore", category=MissingRequiredBuildWarning, module="hdmf"
+ )
+ warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
+ warnings.filterwarnings(
+ "ignore", category=PerformanceWarning, module="pandas"
+ )
+ warnings.filterwarnings("ignore", category=NumbaWarning, module="numba")
+ warnings.simplefilter("ignore", category=ResourceWarning)
+ warnings.simplefilter("ignore", category=DeprecationWarning)


def pytest_unconfigure(config):
from spyglass.utils.nwb_helper_fn import close_nwb_files

close_nwb_files()
if TEARDOWN:
SERVER.stop()
+ analysis_dir = BASE_DIR / "analysis"
+ for file in analysis_dir.glob("*.nwb"):
+ file.unlink()
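`pytest_unconfigure` now also clears analysis NWB files created during the run. The cleanup isolated as a sketch (`BASE_DIR` is the test base-directory global set up elsewhere in this conftest):

```python
from pathlib import Path

def clean_analysis_dir(base_dir: Path) -> int:
    """Unlink leftover .nwb files under <base_dir>/analysis; return the count."""
    removed = 0
    for file in (base_dir / "analysis").glob("*.nwb"):
        file.unlink()
        removed += 1
    return removed
```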


# ---------------------------- FIXTURES, TEST ENV ----------------------------
@@ -1357,6 +1367,8 @@ def sorter_dict():

@pytest.fixture(scope="session")
def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
+ pre = spike_v1.SpikeSorting().fetch("KEY", as_dict=True)

key = {
**mini_dict,
**sorter_dict,
@@ -1367,7 +1379,9 @@ def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
spike_v1.SpikeSortingSelection.insert_selection(key)
spike_v1.SpikeSorting.populate()

- yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0]
+ yield (spike_v1.SpikeSorting() - pre).fetch(
+ "KEY", as_dict=True, order_by="time_of_sort desc"
+ )[0]
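The fixture previously yielded `fetch("KEY", as_dict=True)[0]`, which on a shared database could return a row from an earlier run. Snapshotting keys before `populate()` and antijoining afterwards (`SpikeSorting() - pre`) isolates the entry this fixture created. The same idea in plain Python, for illustration:

```python
def new_rows(before: list[dict], after: list[dict]) -> list[dict]:
    """Rows present after populate() that were absent from the snapshot."""
    seen = {tuple(sorted(d.items())) for d in before}
    return [d for d in after if tuple(sorted(d.items())) not in seen]

assert new_rows([{"id": 1}], [{"id": 1}, {"id": 2}]) == [{"id": 2}]
```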


@pytest.fixture(scope="session")
@@ -1379,9 +1393,16 @@ def sorting_objs(spike_v1, pop_sort):

@pytest.fixture(scope="session")
def pop_curation(spike_v1, pop_sort):

+ parent_curation_id = -1
+ has_sort = spike_v1.CurationV1 & {"sorting_id": pop_sort["sorting_id"]}
+ if has_sort:
+ parent_curation_id = has_sort.fetch1("curation_id")

spike_v1.CurationV1.insert_curation(
sorting_id=pop_sort["sorting_id"],
description="testing sort",
+ parent_curation_id=parent_curation_id,
)

yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch(
@@ -1418,20 +1439,20 @@ def metric_objs(spike_v1, pop_metric):
@pytest.fixture(scope="session")
def pop_curation_metric(spike_v1, pop_metric, metric_objs):
labels, merge_groups, metrics = metric_objs
- parent_dict = {"parent_curation_id": 0}
+ desc_dict = dict(description="after metric curation")
spike_v1.CurationV1.insert_curation(
sorting_id=(
spike_v1.MetricCurationSelection
& {"metric_curation_id": pop_metric["metric_curation_id"]}
).fetch1("sorting_id"),
- **parent_dict,
+ parent_curation_id=0,
labels=labels,
merge_groups=merge_groups,
metrics=metrics,
- description="after metric curation",
+ **desc_dict,
)

- yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0]
+ yield (spike_v1.CurationV1 & desc_dict).fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")