Skip to content

Commit

Permalink
Fix fickle test
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Dec 4, 2024
1 parent 5f2231f commit c476045
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 25 deletions.
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@ env = [
"TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings
]
filterwarnings = [
"ignore::ResourceWarning",
"ignore::DeprecationWarning",
"ignore::ResourceWarning:.*",
"ignore::DeprecationWarning:.*",
"ignore::UserWarning:.*",
"ignore::MissingRequiredBuildWarning:.*",
]

[tool.coverage.run]
Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/decoding/v1/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def create_group(
}
if self & group_key:
logger.error( # Easier for pytests to not raise error on duplicate
f"Group {nwb_file_name}: {group_name} already exists"
+ "please delete the group before creating a new one"
f"Group {nwb_file_name}: {group_name} already exists. "
+ "Please delete the group before creating a new one"
)
return
self.insert1(
Expand Down
13 changes: 6 additions & 7 deletions src/spyglass/spikesorting/v1/curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
import spikeinterface.extractors as se

from spyglass.common import BrainRegion, Electrode
from spyglass.common.common_ephys import Raw
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.spikesorting.v1.recording import (
SortGroup,
SpikeSortingRecording,
SpikeSortingRecordingSelection,
)
from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection
from spyglass.utils.dj_mixin import SpyglassMixin
from spyglass.utils import SpyglassMixin, logger

schema = dj.schema("spikesorting_v1_curation")

Expand Down Expand Up @@ -84,13 +83,13 @@ def insert_curation(

sort_query = cls & {"sorting_id": sorting_id}
parent_curation_id = max(parent_curation_id, -1)
if parent_curation_id == -1:

parent_query = sort_query & {"curation_id": parent_curation_id}
if parent_curation_id == -1 and len(parent_query):
# check to see if this sorting with a parent of -1
# has already been inserted and if so, warn the user
query = sort_query & {"parent_curation_id": -1}
if query:
Warning("Sorting has already been inserted.")
return query.fetch("KEY")
logger.warning("Sorting has already been inserted.")
return parent_query.fetch("KEY")

# generate curation ID
existing_curation_ids = sort_query.fetch("curation_id")
Expand Down
45 changes: 33 additions & 12 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,13 @@
import pynwb
import pytest
from datajoint.logging import logger as dj_logger
from hdmf.build.warnings import MissingRequiredBuildWarning
from numba import NumbaWarning
from pandas.errors import PerformanceWarning

from .container import DockerMySQLManager
from .data_downloader import DataDownloader

warnings.filterwarnings("ignore", module="tensorflow")
warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas")
warnings.filterwarnings("ignore", category=NumbaWarning, module="numba")


# ------------------------------- TESTS CONFIG -------------------------------

# globals in pytest_configure:
Expand Down Expand Up @@ -114,13 +108,29 @@ def pytest_configure(config):
download_dlc=not NO_DLC,
)

warnings.filterwarnings("ignore", module="tensorflow")
warnings.filterwarnings("ignore", category=UserWarning, module="hdmf")
warnings.filterwarnings(
"ignore", category=MissingRequiredBuildWarning, module="hdmf"
)
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
warnings.filterwarnings(
"ignore", category=PerformanceWarning, module="pandas"
)
warnings.filterwarnings("ignore", category=NumbaWarning, module="numba")
warnings.simplefilter("ignore", category=ResourceWarning)
warnings.simplefilter("ignore", category=DeprecationWarning)


def pytest_unconfigure(config):
from spyglass.utils.nwb_helper_fn import close_nwb_files

close_nwb_files()
if TEARDOWN:
SERVER.stop()
analysis_dir = BASE_DIR / "analysis"
for file in analysis_dir.glob("*.nwb"):
file.unlink()


# ---------------------------- FIXTURES, TEST ENV ----------------------------
Expand Down Expand Up @@ -1357,6 +1367,8 @@ def sorter_dict():

@pytest.fixture(scope="session")
def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
pre = spike_v1.SpikeSorting().fetch("KEY", as_dict=True)

key = {
**mini_dict,
**sorter_dict,
Expand All @@ -1367,7 +1379,9 @@ def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
spike_v1.SpikeSortingSelection.insert_selection(key)
spike_v1.SpikeSorting.populate()

yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0]
yield (spike_v1.SpikeSorting() - pre).fetch(
"KEY", as_dict=True, order_by="time_of_sort desc"
)[0]


@pytest.fixture(scope="session")
Expand All @@ -1379,9 +1393,16 @@ def sorting_objs(spike_v1, pop_sort):

@pytest.fixture(scope="session")
def pop_curation(spike_v1, pop_sort):

parent_curation_id = -1
has_sort = spike_v1.CurationV1 & {"sorting_id": pop_sort["sorting_id"]}
if has_sort:
parent_curation_id = has_sort.fetch1("curation_id")

spike_v1.CurationV1.insert_curation(
sorting_id=pop_sort["sorting_id"],
description="testing sort",
parent_curation_id=parent_curation_id,
)

yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch(
Expand Down Expand Up @@ -1418,20 +1439,20 @@ def metric_objs(spike_v1, pop_metric):
@pytest.fixture(scope="session")
def pop_curation_metric(spike_v1, pop_metric, metric_objs):
labels, merge_groups, metrics = metric_objs
parent_dict = {"parent_curation_id": 0}
desc_dict = dict(description="after metric curation")
spike_v1.CurationV1.insert_curation(
sorting_id=(
spike_v1.MetricCurationSelection
& {"metric_curation_id": pop_metric["metric_curation_id"]}
).fetch1("sorting_id"),
**parent_dict,
parent_curation_id=0,
labels=labels,
merge_groups=merge_groups,
metrics=metrics,
description="after metric curation",
**desc_dict,
)

yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0]
yield (spike_v1.CurationV1 & desc_dict).fetch("KEY", as_dict=True)[0]


@pytest.fixture(scope="session")
Expand Down
1 change: 0 additions & 1 deletion tests/spikesorting/test_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric):
expected = {
"bad_channel": "False",
"contacts": "",
"curation_id": 1,
"description": "after metric curation",
"electrode_group_name": "0",
"electrode_id": 0,
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_restr_from_upstream(graph_tables, restr, expect_n, msg):
("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"),
("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"),
("MergeChild", "parent_attr > 18", 2, "merge child"),
("MergeChild", {"parent_attr": 19}, 1, "dict restr"),
("MergeChild", {"parent_attr": 18}, 1, "dict restr"),
],
)
def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg):
Expand Down

0 comments on commit c476045

Please sign in to comment.