From fa1114eb19edbb540bf4cbb30d1d002b7ba99748 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Fri, 8 Nov 2024 12:24:19 -0800 Subject: [PATCH] Fix failing tests (#1181) --- CHANGELOG.md | 5 ++++- src/spyglass/utils/dj_merge_tables.py | 2 +- src/spyglass/utils/mixins/export.py | 4 ++-- tests/common/conftest.py | 2 +- tests/common/test_position.py | 2 +- tests/common/test_usage.py | 2 +- tests/utils/conftest.py | 8 +++++--- 7 files changed, 15 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index db3a85074..6e91b59ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,8 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116, #1137, #1162 - Fix bool settings imported from dj config file #1117 -- Allow definition of tasks and new probe entries from config #1074, #1120, #1179 +- Allow definition of tasks and new probe entries from config #1074, #1120, + #1179 - Enforce match between ingested nwb probe geometry and existing table entry #1074 - Update DataJoint install and password instructions #1131 @@ -35,9 +36,11 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Remove mambaforge from tests #1153 - Remove debug statement #1164 - Add testing for python versions 3.9, 3.10, 3.11, 3.12 #1169 + - Initialize tables in pytests #1181 - Allow python \< 3.13 #1169 - Remove numpy version restriction #1169 - Merge table delete removes orphaned master entries #1164 +- Edit `merge_fetch` to expect positional before keyword arguments #1181 ### Pipelines diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index 58532fa9f..7d6bf46ba 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -774,7 +774,7 @@ def merge_restrict_class( return parent_class & parent_key def merge_fetch( - self, restriction: str = True, log_export=True, *attrs, **kwargs + self, *attrs, restriction: str = True, log_export=True, **kwargs ) -> list: """Perform a fetch across all parts. If >1 result, return as a list. diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py index ba05d14ce..222963ebb 100644 --- a/src/spyglass/utils/mixins/export.py +++ b/src/spyglass/utils/mixins/export.py @@ -7,7 +7,7 @@ from os import environ from re import match as re_match -from datajoint.condition import make_condition +from datajoint.condition import AndList, make_condition from datajoint.table import Table from packaging.version import parse as version_parse @@ -320,7 +320,7 @@ def restrict(self, restriction): log_export = "fetch_nwb" not in self._called_funcs() return self._run_with_log( super().restrict, - restriction=dj.AndList([restriction, self.restriction]), + restriction=AndList([restriction, self.restriction]), log_export=log_export, ) diff --git a/tests/common/conftest.py b/tests/common/conftest.py index 83c0e87b7..ebae0e004 100644 --- a/tests/common/conftest.py +++ b/tests/common/conftest.py @@ -56,4 +56,4 @@ def common_ephys(common): @pytest.fixture(scope="session") def pop_common_electrode_group(common_ephys): common_ephys.ElectrodeGroup.populate() - yield common_ephys.ElectrodeGroup + yield common_ephys.ElectrodeGroup() diff --git a/tests/common/test_position.py b/tests/common/test_position.py index 23db2091f..43a979c18 100644 --- a/tests/common/test_position.py +++ b/tests/common/test_position.py @@ -97,7 +97,7 @@ def position_video(common_position): def test_position_video(position_video, upsample_position): _ = position_video.populate() - assert len(position_video) == 1, "Failed to populate PositionVideo table." + assert len(position_video) == 2, "Failed to populate PositionVideo table." def test_convert_to_pixels(): diff --git a/tests/common/test_usage.py b/tests/common/test_usage.py index a3c7f6d70..8e50be14e 100644 --- a/tests/common/test_usage.py +++ b/tests/common/test_usage.py @@ -127,7 +127,7 @@ def test_export_populate(populate_export): table, file = populate_export assert len(file) == 4, "Export tables not captured correctly" - assert len(table) == 35, "Export files not captured correctly" + assert len(table) == 37, "Export files not captured correctly" def test_invalid_export_id(export_tbls): diff --git a/tests/utils/conftest.py b/tests/utils/conftest.py index 3723f191c..de5a80c4d 100644 --- a/tests/utils/conftest.py +++ b/tests/utils/conftest.py @@ -243,14 +243,16 @@ def graph_tables(dj_conn, graph_schema): # Merge inserts after declaring tables merge_keys = graph_schema["PkNode"].fetch("KEY", offset=1, as_dict=True) graph_schema["MergeOutput"].insert(merge_keys, skip_duplicates=True) - merge_child_keys = graph_schema["MergeOutput"].merge_fetch( - True, "merge_id", offset=1 + merge_child_keys = graph_schema["MergeOutput"]().merge_fetch( + "merge_id", restriction=True, offset=1 ) merge_child_inserts = [ (i, j, k + 10) for i, j, k in zip(merge_child_keys, range(4), range(10, 15)) ] - graph_schema["MergeChild"].insert(merge_child_inserts, skip_duplicates=True) + graph_schema["MergeChild"]().insert( + merge_child_inserts, skip_duplicates=True + ) yield graph_schema