From e8511a52fe07cbc7ee2549ea3face3071ff7acd7 Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Tue, 1 Oct 2024 13:27:14 -0500 Subject: [PATCH] Long distance restr fix (#1137) * Fix 1136 tests * Update changelog --- CHANGELOG.md | 2 +- pyproject.toml | 2 +- src/spyglass/utils/database_settings.py | 2 +- src/spyglass/utils/dj_graph.py | 7 +++---- src/spyglass/utils/dj_merge_tables.py | 4 ++-- tests/utils/test_graph.py | 2 +- 6 files changed, 9 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 344aae634..9e87a2bd1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop() - Add docstrings to all public methods #1076 - Update DataJoint to 0.14.2 #1081 - Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086, #1126 -- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116 +- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116, #1137 - Fix bool settings imported from dj config file #1117 - Allow definition of tasks and new probe entries from config #1074, #1120 - Enforce match between ingested nwb probe geometry and existing table entry diff --git a/pyproject.toml b/pyproject.toml index e5f8dae5b..ed3a570c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,7 +126,7 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ # "-sv", # no capture, verbose output - "--sw", # stepwise: resume with next test after failure + # "--sw", # stepwise: resume with next test after failure # "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py index e7f36479e..56e2aa28f 100755 --- a/src/spyglass/utils/database_settings.py +++ b/src/spyglass/utils/database_settings.py @@ -15,7 +15,7 @@ "spikesorting", "decoding", "position", - "position_linearization", + "linearization", "ripple", "lfp", "waveform", diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index 26a944e20..48847f61b 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -22,7 +22,6 @@ all_simple_paths, shortest_path, ) -from networkx.algorithms.dag import topological_sort from tqdm import tqdm from spyglass.utils import logger @@ -478,7 +477,7 @@ def _topo_sort( if not self._is_out(node, warn=False) ] graph = self.graph.subgraph(nodes) if subgraph else self.graph - ordered = dj_topo_sort(list(topological_sort(graph))) + ordered = dj_topo_sort(graph) if reverse: ordered.reverse() return [n for n in ordered if n in nodes] @@ -869,10 +868,10 @@ def __init__( self.direction = Direction.DOWN self.leaf = None - if search_restr and not parent: + if search_restr and not self.parent: # using `parent` fails on empty self.direction = Direction.UP self.leaf = self.child - if search_restr and not child: + if search_restr and not self.child: self.direction = Direction.DOWN self.leaf = self.parent if self.leaf: diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index dbe626408..bf13aa254 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -3,7 +3,7 @@ from pprint import pprint from re import sub as re_sub from time import time -from typing import Union, List +from typing import List, Union import datajoint as dj from datajoint.condition import make_condition @@ -348,7 +348,7 @@ def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None: ) key = keys[0] if part & key: - print(f"Key already in part {part_name}: {key}") + logger.info(f"Key already in part {part_name}: {key}") continue master_sk = {cls()._reserved_sk: part_name} uuid = dj.hash.key_hash(key | master_sk) diff --git a/tests/utils/test_graph.py b/tests/utils/test_graph.py index c51427810..4acbc2b1d 100644 --- a/tests/utils/test_graph.py +++ b/tests/utils/test_graph.py @@ -157,7 +157,7 @@ def test_restr_from_upstream(graph_tables, restr, expect_n, msg): ("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"), ("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"), ("MergeChild", "parent_attr > 18", 2, "merge child"), - ("MergeChild", {"parent_attr": 18}, 1, "dict restr"), + ("MergeChild", {"parent_attr": 19}, 1, "dict restr"), ], ) def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg):