Skip to content

Commit

Permalink
Long distance restr fix (#1137)
Browse files Browse the repository at this point in the history
* Fix 1136 tests

* Update changelog
  • Loading branch information
CBroz1 authored Oct 1, 2024
1 parent 05444bb commit e8511a5
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 10 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
- Add docstrings to all public methods #1076
- Update DataJoint to 0.14.2 #1081
- Allow restriction based on parent keys in `Merge.fetch_nwb()` #1086, #1126
- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116
- Import `datajoint.dependencies.unite_master_parts` -> `topo_sort` #1116, #1137
- Fix bool settings imported from dj config file #1117
- Allow definition of tasks and new probe entries from config #1074, #1120
- Enforce match between ingested nwb probe geometry and existing table entry
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ignore-words-list = 'nevers'
minversion = "7.0"
addopts = [
# "-sv", # no capture, verbose output
"--sw", # stepwise: resume with next test after failure
# "--sw", # stepwise: resume with next test after failure
# "--pdb", # drop into debugger on failure
"-p no:warnings",
# "--no-teardown", # don't teardown the database after tests
Expand Down
2 changes: 1 addition & 1 deletion src/spyglass/utils/database_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"spikesorting",
"decoding",
"position",
"position_linearization",
"linearization",
"ripple",
"lfp",
"waveform",
Expand Down
7 changes: 3 additions & 4 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
all_simple_paths,
shortest_path,
)
from networkx.algorithms.dag import topological_sort
from tqdm import tqdm

from spyglass.utils import logger
Expand Down Expand Up @@ -478,7 +477,7 @@ def _topo_sort(
if not self._is_out(node, warn=False)
]
graph = self.graph.subgraph(nodes) if subgraph else self.graph
ordered = dj_topo_sort(list(topological_sort(graph)))
ordered = dj_topo_sort(graph)
if reverse:
ordered.reverse()
return [n for n in ordered if n in nodes]
Expand Down Expand Up @@ -869,10 +868,10 @@ def __init__(
self.direction = Direction.DOWN

self.leaf = None
if search_restr and not parent:
if search_restr and not self.parent: # using `parent` fails on empty
self.direction = Direction.UP
self.leaf = self.child
if search_restr and not child:
if search_restr and not self.child:
self.direction = Direction.DOWN
self.leaf = self.parent
if self.leaf:
Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pprint import pprint
from re import sub as re_sub
from time import time
from typing import Union, List
from typing import List, Union

import datajoint as dj
from datajoint.condition import make_condition
Expand Down Expand Up @@ -348,7 +348,7 @@ def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None:
)
key = keys[0]
if part & key:
print(f"Key already in part {part_name}: {key}")
logger.info(f"Key already in part {part_name}: {key}")
continue
master_sk = {cls()._reserved_sk: part_name}
uuid = dj.hash.key_hash(key | master_sk)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_restr_from_upstream(graph_tables, restr, expect_n, msg):
("PkAliasNode", "parent_attr > 17", 2, "pk pk alias"),
("SkAliasNode", "parent_attr > 18", 2, "sk sk alias"),
("MergeChild", "parent_attr > 18", 2, "merge child"),
("MergeChild", {"parent_attr": 18}, 1, "dict restr"),
("MergeChild", {"parent_attr": 19}, 1, "dict restr"),
],
)
def test_restr_from_downstream(graph_tables, table, restr, expect_n, msg):
Expand Down

0 comments on commit e8511a5

Please sign in to comment.