Prune tracking graph API #111

Merged
merged 30 commits into main from prune-tracking-graph
Jan 10, 2024
Changes from all commits
30 commits
f85c2a9
Remove unused functions and arguments
cmalinmayor Nov 4, 2023
25ec94f
Use get_node/edge_attribute in ctc metrics
cmalinmayor Nov 4, 2023
1a79957
Use get_node/edge_attributes in metrics tests
cmalinmayor Nov 4, 2023
d3d0251
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Nov 4, 2023
17e1190
Change nodes and edges to properties
cmalinmayor Nov 28, 2023
cdd8009
Merge branch 'main' into prune-tracking-graph
cmalinmayor Nov 28, 2023
f4d173e
Refactor IOU matcher
cmalinmayor Nov 28, 2023
baf79d4
Use get_nodes_with_flag in division metrics
cmalinmayor Nov 28, 2023
3b7c3bc
Remove 'get_nodes_with_attribute' from TrackingGraph
cmalinmayor Nov 28, 2023
c0d5d04
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Nov 28, 2023
2b1b963
Merge branch 'main' into prune-tracking-graph
cmalinmayor Nov 29, 2023
e1f8545
Separate setting flag on one node/edge from all nodes/edges
cmalinmayor Nov 29, 2023
0886696
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Nov 29, 2023
1e9aa7e
Simplify IOU dictionary naming
cmalinmayor Nov 29, 2023
df06ed6
Merge remote-tracking branch 'origin/prune-tracking-graph' into prune…
cmalinmayor Nov 29, 2023
e34b44b
Use dict syntax to get node and edge flags
cmalinmayor Nov 29, 2023
7426439
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Nov 29, 2023
72d143e
Fix ruff and mypy complaints
cmalinmayor Nov 29, 2023
e88f0ed
style(pre-commit.ci): auto fixes [...]
pre-commit-ci[bot] Nov 29, 2023
8e2c0e3
Actually fix mypy typing issue
cmalinmayor Nov 29, 2023
a5e8d62
Actually actually fix mypy typing errors
cmalinmayor Nov 29, 2023
6fc9015
Change from Node/EdgeAttr to Node/EdgeFlag
cmalinmayor Nov 29, 2023
99933ed
Add typing annotations to TrackingGraph
cmalinmayor Nov 30, 2023
b68e3a8
Return set from TrackingGraph node/edge_by_flag
cmalinmayor Dec 5, 2023
f8e4973
Separate out and test helper function in iou matcher
cmalinmayor Jan 8, 2024
87dda14
Test nodes/edges_by_flag dict when updating flags
cmalinmayor Jan 8, 2024
a49865a
Merge branch 'main' into prune-tracking-graph
cmalinmayor Jan 8, 2024
528323b
Remove get from test names for flag setting
cmalinmayor Jan 8, 2024
f036f28
Remove get_preds and get_succs
cmalinmayor Jan 10, 2024
f50eca2
Merge branch 'main' into prune-tracking-graph
cmalinmayor Jan 10, 2024
4 changes: 2 additions & 2 deletions src/traccuracy/__init__.py
@@ -7,6 +7,6 @@
__version__ = "uninstalled"

from ._run_metrics import run_metrics
from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph
from ._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph

__all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"]
__all__ = ["TrackingGraph", "run_metrics", "NodeFlag", "EdgeFlag"]
484 changes: 161 additions & 323 deletions src/traccuracy/_tracking_graph.py

Large diffs are not rendered by default.
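
Because this large diff is collapsed, here is a hedged sketch of the reworked TrackingGraph surface as it is exercised by the other files in this PR. The method and property names come from those hunks; the wrapper function below, its arguments, and its use are invented for illustration and are not copied from _tracking_graph.py.

# Hedged sketch only; `graph` is assumed to be an existing, non-empty TrackingGraph.
from traccuracy import EdgeFlag, NodeFlag, TrackingGraph

def annotate_and_query(graph: TrackingGraph):
    # nodes and edges are now properties rather than methods
    n_nodes, n_edges = len(graph.nodes), len(graph.edges)

    # flags can be set in bulk or on a single node/edge
    graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, True)
    graph.set_flag_on_node(next(iter(graph.nodes)), NodeFlag.TRUE_POS, True)
    graph.set_flag_on_all_edges(EdgeFlag.FALSE_NEG, False)

    # flag queries return sets of node (or edge) ids
    return graph.get_nodes_with_flag(NodeFlag.TRUE_POS), n_nodes, n_edges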

4 changes: 2 additions & 2 deletions src/traccuracy/matchers/_base.py
@@ -50,9 +50,9 @@ def compute_mapping(
matched.matcher_info = self.info

# Report matching performance
total_gt = len(matched.gt_graph.nodes())
total_gt = len(matched.gt_graph.nodes)
matched_gt = len({m[0] for m in matched.mapping})
total_pred = len(matched.pred_graph.nodes())
total_pred = len(matched.pred_graph.nodes)
matched_pred = len({m[1] for m in matched.mapping})
logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.")
logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.")
5 changes: 4 additions & 1 deletion src/traccuracy/matchers/_ctc.py
@@ -37,7 +37,7 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
traccuracy.matchers.Matched: Matched data object containing the CTC mapping

Raises:
ValueError: GT and pred segmentations must be the same shape
ValueError: if GT and pred segmentations are None or are not the same shape
"""
gt = gt_graph
pred = pred_graph
@@ -46,6 +46,9 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph):
G_gt, mask_gt = gt, gt.segmentation
G_pred, mask_pred = pred, pred.segmentation

if mask_gt is None or mask_pred is None:
raise ValueError("Segmentation is None, cannot perform matching")

if mask_gt.shape != mask_pred.shape:
raise ValueError("Segmentation shapes must match between gt and pred")

50 changes: 38 additions & 12 deletions src/traccuracy/matchers/_iou.py
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Hashable

import numpy as np
from tqdm import tqdm

@@ -46,6 +48,35 @@ def _match_nodes(gt, res, threshold=1):
return gtcells, rescells


def _construct_time_to_seg_id_map(
graph: TrackingGraph,
) -> dict[int, dict[Hashable, Hashable]]:
"""For each time frame in the graph, create a mapping from segmentation ids
(the ids in the segmentation array, stored in graph.label_key) to the
node ids (the ids of the TrackingGraph nodes).

Args:
graph(TrackingGraph): a tracking graph with a label_key on each node

Returns:
dict[int, dict[Hashable, Hashable]]: a dictionary from {time: {segmentation_id: node_id}}

Raises:
AssertionError: If two nodes in a time frame have the same segmentation_id
"""
time_to_seg_id_map: dict[int, dict[Hashable, Hashable]] = {}
for node_id, data in graph.nodes(data=True):
time = data[graph.frame_key]
seg_id = data[graph.label_key]
seg_id_to_node_id_map = time_to_seg_id_map.get(time, {})
assert (
seg_id not in seg_id_to_node_id_map
), f"Segmentation ID {seg_id} occurred twice in frame {time}."
seg_id_to_node_id_map[seg_id] = node_id
time_to_seg_id_map[time] = seg_id_to_node_id_map
return time_to_seg_id_map


def match_iou(gt, pred, threshold=0.6):
"""Identifies pairs of cells between gt and pred that have iou > threshold

@@ -78,24 +109,19 @@ def match_iou(gt, pred, threshold=0.6):
# Get overlaps for each frame
frame_range = range(gt.start_frame, gt.end_frame)
total = len(list(frame_range))

gt_time_to_seg_id_map = _construct_time_to_seg_id_map(gt)
pred_time_to_seg_id_map = _construct_time_to_seg_id_map(pred)

for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total):
matches = _match_nodes(
gt.segmentation[i], pred.segmentation[i], threshold=threshold
)

# Construct node id tuple for each match
for gt_id, pred_id in zip(*matches):
for gt_seg_id, pred_seg_id in zip(*matches):
# Find node id based on time and segmentation label
gt_node = gt.get_nodes_with_attribute(
gt.label_key,
criterion=lambda x: x == gt_id, # noqa
limit_to=gt.get_nodes_in_frame(t),
)[0]
pred_node = pred.get_nodes_with_attribute(
pred.label_key,
criterion=lambda x: x == pred_id, # noqa
limit_to=pred.get_nodes_in_frame(t),
)[0]
gt_node = gt_time_to_seg_id_map[t][gt_seg_id]
pred_node = pred_time_to_seg_id_map[t][pred_seg_id]
mapper.append((gt_node, pred_node))
return mapper

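As a self-contained illustration of the lookup structure the new helper builds, here is the same mapping logic re-created over plain dicts. This is not the library function; the node ids, labels, and the default key names "t" and "segmentation_id" are assumptions for the example.

# Standalone re-creation of the mapping logic, not the library code itself.
def build_time_to_seg_id_map(node_data, frame_key="t", label_key="segmentation_id"):
    out = {}
    for node_id, data in node_data.items():
        frame_map = out.setdefault(data[frame_key], {})
        assert data[label_key] not in frame_map, "duplicate segmentation id in a frame"
        frame_map[data[label_key]] = node_id
    return out

nodes = {
    "0_1": {"t": 0, "segmentation_id": 1},
    "0_2": {"t": 0, "segmentation_id": 2},
    "1_1": {"t": 1, "segmentation_id": 1},
}
mapping = build_time_to_seg_id_map(nodes)
# mapping == {0: {1: "0_1", 2: "0_2"}, 1: {1: "1_1"}}
# match_iou can then resolve a matched segmentation label in frame t with
# mapping[t][seg_id] instead of scanning node attributes for every match.
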
14 changes: 7 additions & 7 deletions src/traccuracy/metrics/_ctc.py
@@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import EdgeAttr, NodeAttr
from traccuracy._tracking_graph import EdgeFlag, NodeFlag
from traccuracy.track_errors._ctc import evaluate_ctc_events

from ._base import Metric
@@ -38,14 +38,14 @@ def _compute(self, data: Matched):
evaluate_ctc_events(data)

vertex_error_counts = {
"ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)),
"fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)),
"fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)),
"ns": len(data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)),
"fp": len(data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)),
"fn": len(data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)),
}
edge_error_counts = {
"ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)),
"fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)),
"fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)),
"ws": len(data.pred_graph.get_edges_with_flag(EdgeFlag.WRONG_SEMANTIC)),
"fp": len(data.pred_graph.get_edges_with_flag(EdgeFlag.FALSE_POS)),
"fn": len(data.gt_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)),
}
error_sum = get_weighted_error_sum(
vertex_error_counts,
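
The hunk is cut off mid-call to get_weighted_error_sum. As a hedged sketch of the general idea only (the real function's signature and weights are not shown in this hunk), the score is a weighted sum of the per-class counts collected above:

# Hedged sketch, not the library's get_weighted_error_sum; weights below are made up.
def weighted_error_sum_sketch(vertex_errors, edge_errors, vertex_weights, edge_weights):
    total = sum(vertex_weights[k] * n for k, n in vertex_errors.items())
    total += sum(edge_weights[k] * n for k, n in edge_errors.items())
    return total

# e.g. with arbitrary example weights:
# weighted_error_sum_sketch({"ns": 1, "fp": 2, "fn": 3}, {"ws": 0, "fp": 1, "fn": 2},
#                           {"ns": 1.0, "fp": 1.0, "fn": 1.0}, {"ws": 1.0, "fp": 1.0, "fn": 1.0})
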
14 changes: 4 additions & 10 deletions src/traccuracy/metrics/_divisions.py
@@ -36,7 +36,7 @@

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import NodeAttr
from traccuracy._tracking_graph import NodeFlag
from traccuracy.track_errors.divisions import _evaluate_division_events

from ._base import Metric
@@ -90,15 +90,9 @@ def _compute(self, data: Matched):
}

def _calculate_metrics(self, g_gt, g_pred):
tp_division_count = len(
g_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x)
)
fn_division_count = len(
g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x)
)
fp_division_count = len(
g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x)
)
tp_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.TP_DIV))
fn_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.FN_DIV))
fp_division_count = len(g_pred.get_nodes_with_flag(NodeFlag.FP_DIV))

try:
recall = tp_division_count / (tp_division_count + fn_division_count)
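
The collapsed context cuts off after the recall line. For reference, a hedged sketch of the standard precision/recall/F1 arithmetic these counts feed into; only recall is visible in the hunk, and the rest is the textbook definition, which may not match the file's exact zero-division handling:

# Hedged sketch, not the library code.
def division_scores_sketch(tp, fn, fp):
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"recall": recall, "precision": precision, "f1": f1}
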
67 changes: 32 additions & 35 deletions src/traccuracy/track_errors/_ctc.py
@@ -6,7 +6,7 @@

from tqdm import tqdm

from traccuracy._tracking_graph import EdgeAttr, NodeAttr
from traccuracy._tracking_graph import EdgeFlag, NodeFlag

if TYPE_CHECKING:
from traccuracy.matchers import Matched
@@ -39,12 +39,12 @@ def get_vertex_errors(matched_data: Matched):
logger.info("Node errors already calculated. Skipping graph annotation")
return

comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.TRUE_POS, False)
comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.NON_SPLIT, False)
comp_graph.set_flag_on_all_nodes(NodeFlag.TRUE_POS, False)
comp_graph.set_flag_on_all_nodes(NodeFlag.NON_SPLIT, False)

# will flip this when we come across the vertex in the mapping
comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.FALSE_POS, True)
gt_graph.set_node_attribute(list(gt_graph.nodes()), NodeAttr.FALSE_NEG, True)
comp_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, True)
gt_graph.set_flag_on_all_nodes(NodeFlag.FALSE_NEG, True)

# we need to know how many computed vertices are "non-split", so we make
# a mapping of gt vertices to their matched comp vertices
@@ -57,15 +57,16 @@
gt_ids = dict_mapping[pred_id]
if len(gt_ids) == 1:
gid = gt_ids[0]
comp_graph.set_node_attribute(pred_id, NodeAttr.TRUE_POS, True)
comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False)
gt_graph.set_node_attribute(gid, NodeAttr.FALSE_NEG, False)
comp_graph.set_flag_on_node(pred_id, NodeFlag.TRUE_POS, True)
comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False)
gt_graph.set_flag_on_node(gid, NodeFlag.FALSE_NEG, False)
elif len(gt_ids) > 1:
comp_graph.set_node_attribute(pred_id, NodeAttr.NON_SPLIT, True)
comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False)
comp_graph.set_flag_on_node(pred_id, NodeFlag.NON_SPLIT, True)
comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False)
# number of split operations that would be required to correct the vertices
ns_count += len(gt_ids) - 1
gt_graph.set_node_attribute(gt_ids, NodeAttr.FALSE_NEG, False)
for gt_id in gt_ids:
gt_graph.set_flag_on_node(gt_id, NodeFlag.FALSE_NEG, False)

# Record presence of annotations on the TrackingGraph
comp_graph.node_errors = True
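
A small worked example of the non-split accounting above (ids are made up): when one predicted node is matched to several ground-truth nodes, it is flagged NON_SPLIT instead of TRUE_POS, ns_count grows by the number of splits needed to correct it, and each matched ground-truth node is cleared of FALSE_NEG.

# Worked example mirroring the elif branch above (made-up ids):
gt_ids = ["g1", "g2", "g3"]   # three gt nodes all matched to pred node "p7"
ns_count = 0
ns_count += len(gt_ids) - 1   # 2 split operations would be required
# "p7" gets NodeFlag.NON_SPLIT=True and NodeFlag.FALSE_POS=False;
# "g1", "g2", "g3" each get NodeFlag.FALSE_NEG=False.
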
@@ -87,34 +88,30 @@ def get_edge_errors(matched_data: Matched):
get_vertex_errors(matched_data)

induced_graph = comp_graph.get_subgraph(
comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)
comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS)
).graph

comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False)
comp_graph.set_edge_attribute(
list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False
)
gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False)
comp_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, False)
comp_graph.set_flag_on_all_edges(EdgeFlag.WRONG_SEMANTIC, False)
gt_graph.set_flag_on_all_edges(EdgeFlag.FALSE_NEG, False)

gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph}
comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph}

# intertrack edges = connection between parent and daughter
for graph in [comp_graph, gt_graph]:
# Set to False by default
graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False)
graph.set_flag_on_all_edges(EdgeFlag.INTERTRACK_EDGE, False)

for parent in graph.get_divisions():
for daughter in graph.get_succs(parent):
graph.set_edge_attribute(
(parent, daughter), EdgeAttr.INTERTRACK_EDGE, True
for daughter in graph.graph.successors(parent):
graph.set_flag_on_edge(
(parent, daughter), EdgeFlag.INTERTRACK_EDGE, True
)

for merge in graph.get_merges():
for parent in graph.get_preds(merge):
graph.set_edge_attribute(
(parent, merge), EdgeAttr.INTERTRACK_EDGE, True
)
for parent in graph.graph.predecessors(merge):
graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True)

# fp edges - edges in induced_graph that aren't in gt_graph
for edge in tqdm(induced_graph.edges, "Evaluating FP edges"):
@@ -124,32 +121,32 @@
target_gt_id = comp_gt_mapping[target]

expected_gt_edge = (source_gt_id, target_gt_id)
if expected_gt_edge not in gt_graph.edges():
comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True)
if expected_gt_edge not in gt_graph.edges:
comp_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_POS, True)
else:
# check if semantics are correct
is_parent_gt = gt_graph.edges()[expected_gt_edge][EdgeAttr.INTERTRACK_EDGE]
is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE]
is_parent_gt = gt_graph.edges[expected_gt_edge][EdgeFlag.INTERTRACK_EDGE]
is_parent_comp = comp_graph.edges[edge][EdgeFlag.INTERTRACK_EDGE]
if is_parent_gt != is_parent_comp:
comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True)
comp_graph.set_flag_on_edge(edge, EdgeFlag.WRONG_SEMANTIC, True)

# fn edges - edges in gt_graph that aren't in induced graph
for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"):
for edge in tqdm(gt_graph.edges, "Evaluating FN edges"):
source, target = edge[0], edge[1]
# this edge is adjacent to a node we didn't detect, so it definitely is an fn
if (
gt_graph.nodes()[source][NodeAttr.FALSE_NEG]
or gt_graph.nodes()[target][NodeAttr.FALSE_NEG]
gt_graph.nodes[source][NodeFlag.FALSE_NEG]
or gt_graph.nodes[target][NodeFlag.FALSE_NEG]
):
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)
continue

source_comp_id = gt_comp_mapping[source]
target_comp_id = gt_comp_mapping[target]

expected_comp_edge = (source_comp_id, target_comp_id)
if expected_comp_edge not in induced_graph.edges:
gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True)
gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True)

gt_graph.edge_errors = True
comp_graph.edge_errors = True
Loading