From c4760450e8f77d3563694ea2b42152fd2ab1396c Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 4 Dec 2024 12:44:23 -0600 Subject: [PATCH] Fix fickle test --- pyproject.toml | 6 ++-- src/spyglass/decoding/v1/core.py | 4 +-- src/spyglass/spikesorting/v1/curation.py | 13 ++++--- tests/conftest.py | 45 +++++++++++++++++------- tests/spikesorting/test_curation.py | 1 - tests/utils/test_graph.py | 2 +- 6 files changed, 46 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 164d9fc85..0a1cd627f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,8 +149,10 @@ env = [ "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings ] filterwarnings = [ - "ignore::ResourceWarning", - "ignore::DeprecationWarning", + "ignore::ResourceWarning:.*", + "ignore::DeprecationWarning:.*", + "ignore::UserWarning:.*", + "ignore::MissingRequiredBuildWarning:.*", ] [tool.coverage.run] diff --git a/src/spyglass/decoding/v1/core.py b/src/spyglass/decoding/v1/core.py index d58af1643..177a87d22 100644 --- a/src/spyglass/decoding/v1/core.py +++ b/src/spyglass/decoding/v1/core.py @@ -126,8 +126,8 @@ def create_group( } if self & group_key: logger.error( # Easier for pytests to not raise error on duplicate - f"Group {nwb_file_name}: {group_name} already exists" - + "please delete the group before creating a new one" + f"Group {nwb_file_name}: {group_name} already exists. " + + "Please delete the group before creating a new one" ) return self.insert1( diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index 00b1ef81e..593d2c1de 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -9,7 +9,6 @@ import spikeinterface.extractors as se from spyglass.common import BrainRegion, Electrode -from spyglass.common.common_ephys import Raw from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.spikesorting.v1.recording import ( SortGroup, @@ -17,7 +16,7 @@ SpikeSortingRecordingSelection, ) from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection -from spyglass.utils.dj_mixin import SpyglassMixin +from spyglass.utils import SpyglassMixin, logger schema = dj.schema("spikesorting_v1_curation") @@ -84,13 +83,13 @@ def insert_curation( sort_query = cls & {"sorting_id": sorting_id} parent_curation_id = max(parent_curation_id, -1) - if parent_curation_id == -1: + + parent_query = sort_query & {"curation_id": parent_curation_id} + if parent_curation_id == -1 and len(parent_query): # check to see if this sorting with a parent of -1 # has already been inserted and if so, warn the user - query = sort_query & {"parent_curation_id": -1} - if query: - Warning("Sorting has already been inserted.") - return query.fetch("KEY") + logger.warning("Sorting has already been inserted.") + return parent_query.fetch("KEY") # generate curation ID existing_curation_ids = sort_query.fetch("curation_id") diff --git a/tests/conftest.py b/tests/conftest.py index 7c2bb0a33..1cb54cbd1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,19 +17,13 @@ import pynwb import pytest from datajoint.logging import logger as dj_logger +from hdmf.build.warnings import MissingRequiredBuildWarning from numba import NumbaWarning from pandas.errors import PerformanceWarning from .container import DockerMySQLManager from .data_downloader import DataDownloader -warnings.filterwarnings("ignore", module="tensorflow") -warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") -warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") -warnings.filterwarnings("ignore", category=PerformanceWarning, module="pandas") -warnings.filterwarnings("ignore", category=NumbaWarning, module="numba") - - # ------------------------------- TESTS CONFIG ------------------------------- # globals in pytest_configure: @@ -114,6 +108,19 @@ def pytest_configure(config): download_dlc=not NO_DLC, ) + warnings.filterwarnings("ignore", module="tensorflow") + warnings.filterwarnings("ignore", category=UserWarning, module="hdmf") + warnings.filterwarnings( + "ignore", category=MissingRequiredBuildWarning, module="hdmf" + ) + warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn") + warnings.filterwarnings( + "ignore", category=PerformanceWarning, module="pandas" + ) + warnings.filterwarnings("ignore", category=NumbaWarning, module="numba") + warnings.simplefilter("ignore", category=ResourceWarning) + warnings.simplefilter("ignore", category=DeprecationWarning) + def pytest_unconfigure(config): from spyglass.utils.nwb_helper_fn import close_nwb_files @@ -121,6 +128,9 @@ def pytest_unconfigure(config): close_nwb_files() if TEARDOWN: SERVER.stop() + analysis_dir = BASE_DIR / "analysis" + for file in analysis_dir.glob("*.nwb"): + file.unlink() # ---------------------------- FIXTURES, TEST ENV ---------------------------- @@ -1357,6 +1367,8 @@ def sorter_dict(): @pytest.fixture(scope="session") def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict): + pre = spike_v1.SpikeSorting().fetch("KEY", as_dict=True) + key = { **mini_dict, **sorter_dict, @@ -1367,7 +1379,9 @@ def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict): spike_v1.SpikeSortingSelection.insert_selection(key) spike_v1.SpikeSorting.populate() - yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0] + yield (spike_v1.SpikeSorting() - pre).fetch( + "KEY", as_dict=True, order_by="time_of_sort desc" + )[0] @pytest.fixture(scope="session") @@ -1379,9 +1393,16 @@ def sorting_objs(spike_v1, pop_sort): @pytest.fixture(scope="session") def pop_curation(spike_v1, pop_sort): + + parent_curation_id = -1 + has_sort = spike_v1.CurationV1 & {"sorting_id": pop_sort["sorting_id"]} + if has_sort: + parent_curation_id = has_sort.fetch1("curation_id") + spike_v1.CurationV1.insert_curation( sorting_id=pop_sort["sorting_id"], description="testing sort", + parent_curation_id=parent_curation_id, ) yield (spike_v1.CurationV1() & {"parent_curation_id": -1}).fetch( @@ -1418,20 +1439,20 @@ def metric_objs(spike_v1, pop_metric): @pytest.fixture(scope="session") def pop_curation_metric(spike_v1, pop_metric, metric_objs): labels, merge_groups, metrics = metric_objs - parent_dict = {"parent_curation_id": 0} + desc_dict = dict(description="after metric curation") spike_v1.CurationV1.insert_curation( sorting_id=( spike_v1.MetricCurationSelection & {"metric_curation_id": pop_metric["metric_curation_id"]} ).fetch1("sorting_id"), - **parent_dict, + parent_curation_id=0, labels=labels, merge_groups=merge_groups, metrics=metrics, - description="after metric curation", + **desc_dict, ) - yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0] + yield (spike_v1.CurationV1 & desc_dict).fetch("KEY", as_dict=True)[0] @pytest.fixture(scope="session") diff --git a/tests/spikesorting/test_curation.py b/tests/spikesorting/test_curation.py index eac00ab0e..dccff0f69 100644 --- a/tests/spikesorting/test_curation.py +++ b/tests/spikesorting/test_curation.py @@ -80,7 +80,6 @@ def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric): expected = { "bad_channel": "False", "contacts": "", - "curation_id": 1, "description": "after metric curation", "electrode_group_name": "0", "electrode_id": 0, diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index 4acbc2b1d..c51427810 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -157,7 +157,7 @@ def test_restr_from_upstream(graph_tables, restr, expect_n, msg): ("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"), ("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"), ("MergeChild", "parent_attr > 18", 2, "merge child"), - ("MergeChild", {"parent_attr": 19}, 1, "dict restr"), + ("MergeChild", {"parent_attr": 18}, 1, "dict restr"), ], ) def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg):