Skip to content

Commit

Permalink
Merge pull request #155 from Janelia-Trackathon-2023/bugfix_frame_buffer
Browse files Browse the repository at this point in the history
Fix frame buffer bug where predecessors were not checked properly
  • Loading branch information
msschwartz21 authored Sep 16, 2024
2 parents 6f71910 + 7268c38 commit ba122c7
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 33 deletions.
14 changes: 12 additions & 2 deletions src/traccuracy/track_errors/divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
fp_divs = g_pred.get_nodes_with_flag(NodeFlag.FP_DIV)
fn_divs = g_gt.get_nodes_with_flag(NodeFlag.FN_DIV)

gt_to_pred_dict = dict(mapper)
pred_to_gt_dict = {pred: gt for (gt, pred) in mapper}

# Compare all pairs of fp and fn
for fp_node, fn_node in itertools.product(fp_divs, fn_divs):
correct = False
Expand All @@ -199,7 +202,10 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
for node in g_pred.graph.successors(fp_node)
]
fn_succ = g_gt.graph.successors(fn_node)
if Counter(fp_succ) != Counter(fn_succ):
fn_succ_mapped = [
gt_to_pred_dict[fn] for fn in fn_succ if fn in gt_to_pred_dict
]
if Counter(fp_succ) != Counter(fn_succ_mapped):
# Daughters don't match so division cannot match
continue

Expand All @@ -220,7 +226,11 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1):
for node in g_gt.graph.successors(fn_node)
]
fp_succ = g_pred.graph.successors(fp_node)
if Counter(fp_succ) != Counter(fn_succ):

fp_succ_mapped = [
pred_to_gt_dict[fp] for fp in fp_succ if fp in pred_to_gt_dict
]
if Counter(fp_succ_mapped) != Counter(fn_succ):
# Daughters don't match so division cannot match
continue

Expand Down
3 changes: 2 additions & 1 deletion tests/metrics/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@


def test_DivisionMetrics():
g_gt, g_pred, mapper = get_division_graphs()
g_gt, g_pred, map_gt, map_pred = get_division_graphs()
mapper = list(zip(map_gt, map_pred))
matched = Matched(
TrackingGraph(g_gt),
TrackingGraph(g_pred),
Expand Down
29 changes: 15 additions & 14 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def get_division_graphs():
1_0 -- 1_1 -- 1_2 -- 1_3 -<
3_4
G2
2_2 -- 2_3 -- 2_4
1_0 -- 1_1 -<
3_2 -- 3_3 -- 3_4
5_2 -- 5_3 -- 5_4
4_0 -- 4_1 -<
6_2 -- 6_3 -- 6_4
"""

G1 = nx.DiGraph()
Expand All @@ -134,21 +134,22 @@ def get_division_graphs():
nx.set_node_attributes(G1, attrs)

G2 = nx.DiGraph()
G2.add_edge("1_0", "1_1")
# Divide to generate 2 lineage
G2.add_edge("1_1", "2_2")
G2.add_edge("2_2", "2_3")
G2.add_edge("2_3", "2_4")
# Divide to generate 3 lineage
G2.add_edge("1_1", "3_2")
G2.add_edge("3_2", "3_3")
G2.add_edge("3_3", "3_4")
G2.add_edge("4_0", "4_1")
# Divide to generate 5 lineage
G2.add_edge("4_1", "5_2")
G2.add_edge("5_2", "5_3")
G2.add_edge("5_3", "5_4")
# Divide to generate 6 lineage
G2.add_edge("4_1", "6_2")
G2.add_edge("6_2", "6_3")
G2.add_edge("6_3", "6_4")

attrs = {}
for node in G2.nodes:
attrs[node] = {"t": int(node[-1:]), "x": 0, "y": 0}
nx.set_node_attributes(G2, attrs)

mapper = [("1_0", "1_0"), ("1_1", "1_1"), ("2_4", "2_4"), ("3_4", "3_4")]
mapped_g1 = ["1_0", "1_1", "2_4", "3_4"]
mapped_g2 = ["4_0", "4_1", "5_4", "6_4"]

return G1, G2, mapper
return G1, G2, mapped_g1, mapped_g2
36 changes: 20 additions & 16 deletions tests/track_errors/test_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,18 @@ def test__get_pred_by_t(straight_graph):


def test__get_succ_by_t():
_, g2, _ = get_division_graphs()
_, g2, _, _ = get_division_graphs()
g2 = TrackingGraph(g2)

# Find 2 frames forward correctly
start_node = "2_2"
start_node = "5_2"
delta_t = 2
end_node = "2_4"
end_node = "5_4"
node = _get_succ_by_t(g2, start_node, delta_t)
assert node == end_node

# 3 frames forward returns None
start_node = "2_2"
start_node = "5_2"
delta_t = 3
end_node = None
node = _get_succ_by_t(g2, start_node, delta_t)
Expand All @@ -159,8 +159,9 @@ def test__get_succ_by_t():
class Test_correct_shifted_divisions:
def test_no_change(self):
# Early division in gt
g_pred, g_gt, mapper = get_division_graphs()
g_gt.nodes["1_1"][NodeFlag.FN_DIV] = True
g_pred, g_gt, map_pred, map_gt = get_division_graphs()
mapper = list(zip(map_gt, map_pred))
g_gt.nodes["4_1"][NodeFlag.FN_DIV] = True
g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True

matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper)
Expand All @@ -171,13 +172,14 @@ def test_no_change(self):
ng_gt = new_matched.gt_graph

assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is True
assert ng_gt.nodes["1_1"][NodeFlag.FN_DIV] is True
assert ng_gt.nodes["4_1"][NodeFlag.FN_DIV] is True
assert len(ng_gt.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0

def test_fn_early(self):
# Early division in gt
g_pred, g_gt, mapper = get_division_graphs()
g_gt.nodes["1_1"][NodeFlag.FN_DIV] = True
g_pred, g_gt, map_pred, map_gt = get_division_graphs()
mapper = list(zip(map_gt, map_pred))
g_gt.nodes["4_1"][NodeFlag.FN_DIV] = True
g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True

matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper)
Expand All @@ -188,14 +190,15 @@ def test_fn_early(self):
ng_gt = new_matched.gt_graph

assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is False
assert ng_gt.nodes["1_1"][NodeFlag.FN_DIV] is False
assert ng_gt.nodes["4_1"][NodeFlag.FN_DIV] is False
assert ng_pred.nodes["1_3"][NodeFlag.TP_DIV] is True
assert ng_gt.nodes["1_1"][NodeFlag.TP_DIV] is True
assert ng_gt.nodes["4_1"][NodeFlag.TP_DIV] is True

def test_fp_early(self):
# Early division in pred
g_gt, g_pred, mapper = get_division_graphs()
g_pred.nodes["1_1"][NodeFlag.FP_DIV] = True
g_gt, g_pred, map_gt, map_pred = get_division_graphs()
mapper = list(zip(map_gt, map_pred))
g_pred.nodes["4_1"][NodeFlag.FP_DIV] = True
g_gt.nodes["1_3"][NodeFlag.FN_DIV] = True

matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper)
Expand All @@ -205,14 +208,15 @@ def test_fp_early(self):
ng_pred = new_matched.pred_graph
ng_gt = new_matched.gt_graph

assert ng_pred.nodes["1_1"][NodeFlag.FP_DIV] is False
assert ng_pred.nodes["4_1"][NodeFlag.FP_DIV] is False
assert ng_gt.nodes["1_3"][NodeFlag.FN_DIV] is False
assert ng_pred.nodes["1_1"][NodeFlag.TP_DIV] is True
assert ng_pred.nodes["4_1"][NodeFlag.TP_DIV] is True
assert ng_gt.nodes["1_3"][NodeFlag.TP_DIV] is True


def test_evaluate_division_events():
g_gt, g_pred, mapper = get_division_graphs()
g_gt, g_pred, map_gt, map_pred = get_division_graphs()
mapper = list(zip(map_gt, map_pred))
frame_buffer = 2

matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper)
Expand Down

0 comments on commit ba122c7

Please sign in to comment.