From f85c2a9278c47b2fd2b30dd7bfd35a1b5dbd11ec Mon Sep 17 00:00:00 2001 From: cmalinmayor Date: Fri, 3 Nov 2023 20:53:50 -0400 Subject: [PATCH 01/25] Remove unused functions and arguments I used VSCode "Find all references" to inspect all the calls to all the TrackingGraph functions and removed the following unused elements. - `limit_to` args in nodes and edges. These now literally call networkx.nodes and networkx.edges and could be removed. - `get_nodes_by_roi` function. This is probably unnecessary for metrics computation. - `get_edges_with_attribute` function. This is never used, although `get_nodes_with_attribute` is used in the iou matcher. I suggest refactoring the iou matcher to do the actual task more efficiently and getting rid of both of these functions. This same part of the iou matcher is the only place we use TrackingGraph.get_nodes_in_frame as well. Exceptions include: - `get_locations` function. It was never called and is probably not necessary for computing metrics after matching is over. However, it will be used in the point based matcher. We can revisit after that matcher is implemented. - `get_connected_components` and `get_tracklets` functions. They were never called but will likely be needed for metrics such as Cell Cycle Accuracy. We can revisit after that metric is implemented. - `get_node_attribute` and `get_edge_attribute`. These were just implemented and I will refactor the metrics to use them in the next commit. --- src/traccuracy/_tracking_graph.py | 118 +----------------------------- tests/test_tracking_graph.py | 20 +---- 2 files changed, 5 insertions(+), 133 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 1153ea6e..8c7027f5 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -201,44 +201,22 @@ def __init__( self.node_errors = False self.edge_errors = False - def nodes(self, limit_to=None): + def nodes(self): """Get all the nodes in the graph, along with their attributes. - Args: - limit_to (list[hashable], optional): Limit returned dictionary - to nodes with the provided ids. Defaults to None. - Will raise KeyError if any of these node_ids are not present. - Returns: NodeView: Provides set-like operations on the nodes as well as node attribute lookup. """ - if limit_to is None: - return self.graph.nodes - else: - for node in limit_to: - if not self.graph.has_node(node): - raise KeyError(f"Queried node {node} not present in graph.") - return self.graph.subgraph(limit_to).nodes + return self.graph.nodes - def edges(self, limit_to=None): + def edges(self): """Get all the edges in the graph, along with their attributes. - Args: - limit_to (list[tuple[hashable]], optional): Limit returned dictionary - to edges with the provided ids. Defaults to None. - Will raise KeyError if any of these edge ids are not present. - Returns: OutEdgeView: Provides set-like operations on the edge-tuples as well as edge attribute lookup. """ - if limit_to is None: - return self.graph.edges - else: - for edge in limit_to: - if not self.graph.has_edge(*edge): - raise KeyError(f"Queried edge {edge} not present in graph.") - return self.graph.edge_subgraph(limit_to).edges + return self.graph.edges def get_nodes_in_frame(self, frame): """Get the node ids of all nodes in the given frame. @@ -295,62 +273,6 @@ def get_edges_with_flag(self, attr): raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.") return list(self.edges_by_flag[attr]) - def get_nodes_by_roi(self, **kwargs): - """Gets the nodes in a given region of interest (ROI). The ROI is - defined by keyword arguments that correspond to the frame key and - location keys, where each argument should be a (start, end) tuple - (the end is exclusive). Dimensions that are not passed as arguments - are unbounded. None can be passed as an element of the tuple to - signify an unbounded ROI on that side. - - For example, if frame_key='t' and location_keys=('x', 'y'): - `graph.get_nodes_by_roi(t=(10, None), x=(0, 100))` - would return all nodes with time >= 10, and 0 <= x < 100, with no limit - on the y values. - - Returns: - list of hashable: A list of node_ids for all nodes in the ROI. - """ - frames = None - dimensions = [] - for dim, limit in kwargs.items(): - if not (dim == self.frame_key or dim in self.location_keys): - raise ValueError( - f"Provided argument {dim} is neither the frame key" - f" {self.frame_key} or one of the location keys" - f" {self.location_keys}." - ) - if dim == self.frame_key: - frames = list(limit) - else: - dimensions.append((dim, limit[0], limit[1])) - nodes = [] - if frames: - if frames[0] is None: - frames[0] = self.start_frame - if frames[1] is None: - frames[1] = self.end_frame - possible_nodes = [] - for frame in range(frames[0], frames[1]): - if frame in self.nodes_by_frame: - possible_nodes.extend(self.nodes_by_frame[frame]) - else: - possible_nodes = self.graph.nodes() - - for node in possible_nodes: - attrs = self.graph.nodes[node] - inside = True - for dim, start, end in dimensions: - if start is not None and attrs[dim] < start: - inside = False - break - if end is not None and attrs[dim] >= end: - inside = False - break - if inside: - nodes.append(node) - return nodes - def get_nodes_with_attribute(self, attr, criterion=None, limit_to=None): """Get the node_ids of all nodes who have an attribute, optionally limiting to nodes whose value at that attribute meet a given criteria. @@ -384,38 +306,6 @@ def get_nodes_with_attribute(self, attr, criterion=None, limit_to=None): nodes.append(node) return nodes - def get_edges_with_attribute(self, attr, criterion=None, limit_to=None): - """Get the edge_ids of all edges who have an attribute, optionally - limiting to edges whose value at that attribute meet a given criteria. - - For example, get all edges that have an attribute called "fp", - or where the value for "fp" == True. - - Args: - attr (str): the name of the attribute to search for in the edge metadata - criterion ((any)->bool, optional): A function that takes a value and returns - a boolean. If provided, edges will only be returned if the value at - edge[attr] meets this criterion. Defaults to None. - limit_to (list[hashable], optional): If provided the function will only - return edge ids in this list. Will raise KeyError if ids provided here - are not present. - - Returns: - list of hashable: A list of edge_ids which have the given attribute - (and optionally have values at that attribute that meet the given criterion, - and/or are in the list of edge ids.) - """ - if not limit_to: - limit_to = self.graph.edges.keys() - - edges = [] - for edge in limit_to: - attributes = self.graph.edges[edge] - if attr in attributes.keys(): - if criterion is None or criterion(attributes[attr]): - edges.append(edge) - return edges - def get_divisions(self): """Get all nodes that have at least two edges pointing to the next time frame diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 2ee56e01..bcd841a6 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -2,6 +2,7 @@ import networkx as nx import pytest + from traccuracy import EdgeAttr, NodeAttr, TrackingGraph @@ -140,25 +141,6 @@ def test_get_cells_by_frame(simple_graph): assert simple_graph.get_nodes_in_frame(5) == [] -def test_get_nodes_by_roi(simple_graph): - assert simple_graph.get_nodes_by_roi(t=(0, 1)) == ["1_0"] - assert Counter(simple_graph.get_nodes_by_roi(x=(1, None))) == Counter( - ["1_0", "1_1", "1_3", "1_4"] - ) - assert Counter(simple_graph.get_nodes_by_roi(x=(None, 2), t=(1, None))) == Counter( - ["1_1", "1_2"] - ) - - -def test_get_location(nx_comp1): - graph1 = TrackingGraph(nx_comp1, location_keys=["x", "y"]) - assert graph1.get_location("1_2") == [0, 1] - assert graph1.get_location("1_4") == [2, 1] - graph2 = TrackingGraph(nx_comp1, location_keys=["y", "x"]) - assert graph2.get_location("1_2") == [1, 0] - assert graph2.get_location("1_4") == [1, 2] - - def test_get_nodes_with_flag(simple_graph): assert simple_graph.get_nodes_with_flag(NodeAttr.TP_DIV) == ["1_1"] assert simple_graph.get_nodes_with_flag(NodeAttr.FP_DIV) == [] From 25ec94fd693d9666a0fd2a72313e1b7711d14c18 Mon Sep 17 00:00:00 2001 From: cmalinmayor Date: Fri, 3 Nov 2023 21:35:19 -0400 Subject: [PATCH 02/25] Use get_node/edge_attribute in ctc metrics The advantage of using this over the prior approach is that these functions assume that if the attribute is not present it is False. Therefore, I also removed all instances where we set an attribute to False for all nodes/edges before flipping some of them to True. --- src/traccuracy/track_errors/_ctc.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 642399df..c468f061 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -37,9 +37,6 @@ def get_vertex_errors(matched_data: "Matched"): logger.info("Node errors already calculated. Skipping graph annotation") return - comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.TRUE_POS, False) - comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.NON_SPLIT, False) - # will flip this when we come across the vertex in the mapping comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.FALSE_POS, True) gt_graph.set_node_attribute(list(gt_graph.nodes()), NodeAttr.FALSE_NEG, True) @@ -88,19 +85,11 @@ def get_edge_errors(matched_data: "Matched"): comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS) ).graph - comp_graph.set_edge_attribute(list(comp_graph.edges()), EdgeAttr.FALSE_POS, False) - comp_graph.set_edge_attribute( - list(comp_graph.edges()), EdgeAttr.WRONG_SEMANTIC, False - ) - gt_graph.set_edge_attribute(list(gt_graph.edges()), EdgeAttr.FALSE_NEG, False) - gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph} comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph} # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: - # Set to False by default - graph.set_edge_attribute(list(graph.edges()), EdgeAttr.INTERTRACK_EDGE, False) for parent in graph.get_divisions(): for daughter in graph.get_succs(parent): @@ -126,8 +115,8 @@ def get_edge_errors(matched_data: "Matched"): comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True) else: # check if semantics are correct - is_parent_gt = gt_graph.edges()[expected_gt_edge][EdgeAttr.INTERTRACK_EDGE] - is_parent_comp = comp_graph.edges()[edge][EdgeAttr.INTERTRACK_EDGE] + is_parent_gt = gt_graph.get_edge_attribute(expected_gt_edge, EdgeAttr.INTERTRACK_EDGE) + is_parent_comp = comp_graph.get_edge_attribute(edge, EdgeAttr.INTERTRACK_EDGE) if is_parent_gt != is_parent_comp: comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) @@ -136,8 +125,8 @@ def get_edge_errors(matched_data: "Matched"): source, target = edge[0], edge[1] # this edge is adjacent to an edge we didn't detect, so it definitely is an fn if ( - gt_graph.nodes()[source][NodeAttr.FALSE_NEG] - or gt_graph.nodes()[target][NodeAttr.FALSE_NEG] + gt_graph.get_node_attribute(source, NodeAttr.FALSE_NEG) + or gt_graph.get_node_attribute(target, NodeAttr.FALSE_NEG) ): gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) continue From 1a799573e54c76d4484ef49734b60f14ace0a064 Mon Sep 17 00:00:00 2001 From: cmalinmayor Date: Fri, 3 Nov 2023 22:00:54 -0400 Subject: [PATCH 03/25] Use get_node/edge_attributes in metrics tests This accompanies the prior commit updating the metrics computations, and additionally updates the tests, since they can no longer assume that False attributes are explicitly annotated. --- tests/track_errors/test_ctc_errors.py | 27 ++++++++++++--------- tests/track_errors/test_divisions.py | 35 ++++++++++++--------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index ceba4989..5646efc4 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,5 +1,6 @@ import networkx as nx import numpy as np + from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph from traccuracy.matchers._matched import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors @@ -49,17 +50,17 @@ def test_get_vertex_errors(): assert len(G_comp.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 assert len(G_gt.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 - assert gt_g.nodes[15][NodeAttr.FALSE_NEG] - assert not gt_g.nodes[17][NodeAttr.FALSE_NEG] + assert G_gt.get_node_attribute(15, NodeAttr.FALSE_NEG) + assert not G_gt.get_node_attribute(17, NodeAttr.FALSE_NEG) - assert comp_g.nodes[3][NodeAttr.NON_SPLIT] - assert not comp_g.nodes[7][NodeAttr.NON_SPLIT] + assert G_comp.get_node_attribute(3, NodeAttr.NON_SPLIT) + assert not G_comp.get_node_attribute(7, NodeAttr.NON_SPLIT) - assert comp_g.nodes[7][NodeAttr.TRUE_POS] - assert not comp_g.nodes[3][NodeAttr.TRUE_POS] + assert G_comp.get_node_attribute(7, NodeAttr.TRUE_POS) + assert not G_comp.get_node_attribute(3, NodeAttr.TRUE_POS) - assert comp_g.nodes[10][NodeAttr.FALSE_POS] - assert not comp_g.nodes[7][NodeAttr.FALSE_POS] + assert G_comp.get_node_attribute(10, NodeAttr.FALSE_POS) + assert not G_comp.get_node_attribute(7, NodeAttr.FALSE_POS) def test_assign_edge_errors(): @@ -99,9 +100,10 @@ def test_assign_edge_errors(): matched_data.mapping = mapping get_edge_errors(matched_data) - - assert comp_g.edges[(7, 8)][EdgeAttr.FALSE_POS] - assert gt_g.edges[(17, 18)][EdgeAttr.FALSE_NEG] + matched_comp = matched_data.pred_graph + matched_gt = matched_data.gt_graph + assert matched_comp.get_edge_attribute((7, 8), EdgeAttr.FALSE_POS) + assert matched_gt.get_edge_attribute((17, 18), EdgeAttr.FALSE_NEG) def test_assign_edge_errors_semantics(): @@ -138,7 +140,8 @@ def test_assign_edge_errors_semantics(): matched_data = DummyMatched(TrackingGraph(gt), TrackingGraph(comp)) matched_data.mapping = mapping + matched_comp = matched_data.pred_graph get_edge_errors(matched_data) - assert comp.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] + assert matched_comp.get_edge_attribute(("1_2", "1_3"), EdgeAttr.WRONG_SEMANTIC) diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 0a10bf66..2f3e759d 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -1,17 +1,14 @@ import networkx as nx import numpy as np import pytest -from traccuracy import NodeAttr, TrackingGraph -from traccuracy.matchers._matched import Matched -from traccuracy.track_errors.divisions import ( - _classify_divisions, - _correct_shifted_divisions, - _evaluate_division_events, - _get_pred_by_t, - _get_succ_by_t, -) from tests.test_utils import get_division_graphs +from traccuracy import NodeAttr, TrackingGraph +from traccuracy.matchers._matched import Matched +from traccuracy.track_errors.divisions import (_classify_divisions, + _correct_shifted_divisions, + _evaluate_division_events, + _get_pred_by_t, _get_succ_by_t) class DummyMatched(Matched): @@ -185,8 +182,8 @@ def test_no_change(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes()["1_3"][NodeAttr.FP_DIV] is True - assert ng_gt.nodes()["1_1"][NodeAttr.FN_DIV] is True + assert ng_pred.get_node_attribute("1_3", NodeAttr.FP_DIV) is True + assert ng_gt.get_node_attribute("1_1", NodeAttr.FN_DIV) is True assert len(ng_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 def test_fn_early(self): @@ -203,10 +200,10 @@ def test_fn_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes()["1_3"][NodeAttr.FP_DIV] is False - assert ng_gt.nodes()["1_1"][NodeAttr.FN_DIV] is False - assert ng_pred.nodes()["1_3"][NodeAttr.TP_DIV] is True - assert ng_gt.nodes()["1_1"][NodeAttr.TP_DIV] is True + assert ng_pred.get_node_attribute("1_3", NodeAttr.FP_DIV) is False + assert ng_gt.get_node_attribute("1_1", NodeAttr.FN_DIV) is False + assert ng_pred.get_node_attribute("1_3", NodeAttr.TP_DIV) is True + assert ng_gt.get_node_attribute("1_1", NodeAttr.TP_DIV) is True def test_fp_early(self): # Early division in pred @@ -222,10 +219,10 @@ def test_fp_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes()["1_1"][NodeAttr.FP_DIV] is False - assert ng_gt.nodes()["1_3"][NodeAttr.FN_DIV] is False - assert ng_pred.nodes()["1_1"][NodeAttr.TP_DIV] is True - assert ng_gt.nodes()["1_3"][NodeAttr.TP_DIV] is True + assert ng_pred.get_node_attribute("1_1", NodeAttr.FP_DIV) is False + assert ng_gt.get_node_attribute("1_3", NodeAttr.FN_DIV) is False + assert ng_pred.get_node_attribute("1_1", NodeAttr.TP_DIV) is True + assert ng_gt.get_node_attribute("1_3", NodeAttr.TP_DIV) is True def test_evaluate_division_events(): From d3d025156716fc742acef235270a3a480df45116 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Nov 2023 02:43:57 +0000 Subject: [PATCH 04/25] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/track_errors/_ctc.py | 16 +++++++++------- tests/test_tracking_graph.py | 1 - tests/track_errors/test_ctc_errors.py | 1 - tests/track_errors/test_divisions.py | 15 +++++++++------ 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index c468f061..ac726ff7 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -90,7 +90,6 @@ def get_edge_errors(matched_data: "Matched"): # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: - for parent in graph.get_divisions(): for daughter in graph.get_succs(parent): graph.set_edge_attribute( @@ -115,8 +114,12 @@ def get_edge_errors(matched_data: "Matched"): comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True) else: # check if semantics are correct - is_parent_gt = gt_graph.get_edge_attribute(expected_gt_edge, EdgeAttr.INTERTRACK_EDGE) - is_parent_comp = comp_graph.get_edge_attribute(edge, EdgeAttr.INTERTRACK_EDGE) + is_parent_gt = gt_graph.get_edge_attribute( + expected_gt_edge, EdgeAttr.INTERTRACK_EDGE + ) + is_parent_comp = comp_graph.get_edge_attribute( + edge, EdgeAttr.INTERTRACK_EDGE + ) if is_parent_gt != is_parent_comp: comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) @@ -124,10 +127,9 @@ def get_edge_errors(matched_data: "Matched"): for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"): source, target = edge[0], edge[1] # this edge is adjacent to an edge we didn't detect, so it definitely is an fn - if ( - gt_graph.get_node_attribute(source, NodeAttr.FALSE_NEG) - or gt_graph.get_node_attribute(target, NodeAttr.FALSE_NEG) - ): + if gt_graph.get_node_attribute( + source, NodeAttr.FALSE_NEG + ) or gt_graph.get_node_attribute(target, NodeAttr.FALSE_NEG): gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) continue diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index bcd841a6..f00c01fd 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -2,7 +2,6 @@ import networkx as nx import pytest - from traccuracy import EdgeAttr, NodeAttr, TrackingGraph diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 5646efc4..2622b42f 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,6 +1,5 @@ import networkx as nx import numpy as np - from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph from traccuracy.matchers._matched import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 2f3e759d..fc985225 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -1,14 +1,17 @@ import networkx as nx import numpy as np import pytest - -from tests.test_utils import get_division_graphs from traccuracy import NodeAttr, TrackingGraph from traccuracy.matchers._matched import Matched -from traccuracy.track_errors.divisions import (_classify_divisions, - _correct_shifted_divisions, - _evaluate_division_events, - _get_pred_by_t, _get_succ_by_t) +from traccuracy.track_errors.divisions import ( + _classify_divisions, + _correct_shifted_divisions, + _evaluate_division_events, + _get_pred_by_t, + _get_succ_by_t, +) + +from tests.test_utils import get_division_graphs class DummyMatched(Matched): From 17e1190602b2e9f5351a030189a1fc573b329042 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 28 Nov 2023 15:02:29 -0500 Subject: [PATCH 05/25] Change nodes and edges to properties --- src/traccuracy/_tracking_graph.py | 2 ++ src/traccuracy/matchers/_matched.py | 4 ++-- src/traccuracy/track_errors/_ctc.py | 8 ++++---- tests/track_errors/test_divisions.py | 12 ++++++------ 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 8c7027f5..49ff5c2d 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -201,6 +201,7 @@ def __init__( self.node_errors = False self.edge_errors = False + @property def nodes(self): """Get all the nodes in the graph, along with their attributes. @@ -209,6 +210,7 @@ def nodes(self): """ return self.graph.nodes + @property def edges(self): """Get all the edges in the graph, along with their attributes. diff --git a/src/traccuracy/matchers/_matched.py b/src/traccuracy/matchers/_matched.py index 85399848..40637dad 100644 --- a/src/traccuracy/matchers/_matched.py +++ b/src/traccuracy/matchers/_matched.py @@ -26,9 +26,9 @@ def __init__(self, gt_graph: "TrackingGraph", pred_graph: "TrackingGraph"): self.mapping = self.compute_mapping() # Report matching performance - total_gt = len(self.gt_graph.nodes()) + total_gt = len(self.gt_graph.nodes) matched_gt = len({m[0] for m in self.mapping}) - total_pred = len(self.pred_graph.nodes()) + total_pred = len(self.pred_graph.nodes) matched_pred = len({m[1] for m in self.mapping}) logger.info(f"Matched {matched_gt} out of {total_gt} ground truth nodes.") logger.info(f"Matched {matched_pred} out of {total_pred} predicted nodes.") diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index ac726ff7..5f17e0ba 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -38,8 +38,8 @@ def get_vertex_errors(matched_data: "Matched"): return # will flip this when we come across the vertex in the mapping - comp_graph.set_node_attribute(list(comp_graph.nodes()), NodeAttr.FALSE_POS, True) - gt_graph.set_node_attribute(list(gt_graph.nodes()), NodeAttr.FALSE_NEG, True) + comp_graph.set_node_attribute(list(comp_graph.nodes), NodeAttr.FALSE_POS, True) + gt_graph.set_node_attribute(list(gt_graph.nodes), NodeAttr.FALSE_NEG, True) # we need to know how many computed vertices are "non-split", so we make # a mapping of gt vertices to their matched comp vertices @@ -110,7 +110,7 @@ def get_edge_errors(matched_data: "Matched"): target_gt_id = comp_gt_mapping[target] expected_gt_edge = (source_gt_id, target_gt_id) - if expected_gt_edge not in gt_graph.edges(): + if expected_gt_edge not in gt_graph.edges: comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True) else: # check if semantics are correct @@ -124,7 +124,7 @@ def get_edge_errors(matched_data: "Matched"): comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) # fn edges - edges in gt_graph that aren't in induced graph - for edge in tqdm(gt_graph.edges(), "Evaluating FN edges"): + for edge in tqdm(gt_graph.edges, "Evaluating FN edges"): source, target = edge[0], edge[1] # this edge is adjacent to an edge we didn't detect, so it definitely is an fn if gt_graph.get_node_attribute( diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index fc985225..fb6562df 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -61,8 +61,8 @@ def test_classify_divisions_tp(g): assert len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 assert len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] - assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + assert NodeAttr.TP_DIV in g_gt.nodes["2_2"] + assert NodeAttr.TP_DIV in g_pred.nodes["2_2"] # Check division flag assert g_gt.division_annotations @@ -92,9 +92,9 @@ def test_classify_divisions_fp(g): _classify_divisions(matched_data) assert len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert NodeAttr.FP_DIV in g_pred.nodes()["1_2"] - assert NodeAttr.TP_DIV in g_gt.nodes()["2_2"] - assert NodeAttr.TP_DIV in g_pred.nodes()["2_2"] + assert NodeAttr.FP_DIV in g_pred.nodes["1_2"] + assert NodeAttr.TP_DIV in g_gt.nodes["2_2"] + assert NodeAttr.TP_DIV in g_pred.nodes["2_2"] def test_classify_divisions_fn(g): @@ -116,7 +116,7 @@ def test_classify_divisions_fn(g): assert len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 assert len(g_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 - assert NodeAttr.FN_DIV in g_gt.nodes()["2_2"] + assert NodeAttr.FN_DIV in g_gt.nodes["2_2"] @pytest.fixture From f4d173e3bd5169c9c760075df3d720302535916e Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 28 Nov 2023 15:51:01 -0500 Subject: [PATCH 06/25] Refactor IOU matcher Improve efficiency by creating a pre-computed dictionary of time to segmentation_id to node_id. This avoids calling TrackingGraph functions to find each node with the given label_key in the given frame when constructing the matching tuples, which under the hood loops over all nodes. --- src/traccuracy/matchers/_iou.py | 38 +++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 189c4e6d..63fde0f8 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -76,24 +76,40 @@ def match_iou(gt, pred, threshold=0.6): # Get overlaps for each frame frame_range = range(gt.start_frame, gt.end_frame) total = len(list(frame_range)) + + def construct_time_to_seg_id_map(graph): + """ + Args: + graph(TrackingGraph) + + Returns a dictionary {time: {segmentation_id: node_id}} + """ + time_to_seg_id_map = {} + for node_id, data in graph.nodes(data=True): + time = data[graph.frame_key] + seg_id = data[graph.label_key] + seg_id_to_node_id_map = time_to_seg_id_map.get(time, {}) + assert seg_id not in seg_id_to_node_id_map,\ + f"Segmentation ID {seg_id} occurred twice in frame {time}." + seg_id_to_node_id_map[seg_id] = node_id + time_to_seg_id_map[time] = seg_id_to_node_id_map + return time_to_seg_id_map + + gt_time_to_seg_id_map = construct_time_to_seg_id_map(gt) + pred_time_to_seg_id_map = construct_time_to_seg_id_map(pred) + for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total): matches = _match_nodes( gt.segmentation[i], pred.segmentation[i], threshold=threshold ) + gt_seg_to_node_map = gt_time_to_seg_id_map[t] + pred_seg_to_node_map = pred_time_to_seg_id_map[t] # Construct node id tuple for each match - for gt_id, pred_id in zip(*matches): + for gt_seg_id, pred_seg_id in zip(*matches): # Find node id based on time and segmentation label - gt_node = gt.get_nodes_with_attribute( - gt.label_key, - criterion=lambda x: x == gt_id, # noqa - limit_to=gt.get_nodes_in_frame(t), - )[0] - pred_node = pred.get_nodes_with_attribute( - pred.label_key, - criterion=lambda x: x == pred_id, # noqa - limit_to=pred.get_nodes_in_frame(t), - )[0] + gt_node = gt_seg_to_node_map[gt_seg_id] + pred_node = pred_seg_to_node_map[pred_seg_id] mapper.append((gt_node, pred_node)) return mapper From baf79d42361f880c1ede13fe016dac8e8ebaf825 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 28 Nov 2023 15:57:22 -0500 Subject: [PATCH 07/25] Use get_nodes_with_flag in division metrics --- src/traccuracy/metrics/_divisions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 4bd494fa..587a7fc5 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -91,13 +91,13 @@ def compute(self, data: Matched): def _calculate_metrics(self, g_gt, g_pred): tp_division_count = len( - g_gt.get_nodes_with_attribute(NodeAttr.TP_DIV, lambda x: x) + g_gt.get_nodes_with_flag(NodeAttr.TP_DIV) ) fn_division_count = len( - g_gt.get_nodes_with_attribute(NodeAttr.FN_DIV, lambda x: x) + g_gt.get_nodes_with_flag(NodeAttr.FN_DIV) ) fp_division_count = len( - g_pred.get_nodes_with_attribute(NodeAttr.FP_DIV, lambda x: x) + g_pred.get_nodes_with_flag(NodeAttr.FP_DIV) ) try: From 3b7c3bc67d3fce4f43cf4f5165d9e726191f9d05 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 28 Nov 2023 15:57:39 -0500 Subject: [PATCH 08/25] Remove 'get_nodes_with_attribute' from TrackingGraph This function was overly general and thus inefficient. After refactoring the IOU matcher is was no longer used and could be removed safely. --- src/traccuracy/_tracking_graph.py | 32 ------------------------------- tests/test_tracking_graph.py | 31 ------------------------------ 2 files changed, 63 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 7b6b2112..496af545 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -275,38 +275,6 @@ def get_edges_with_flag(self, attr): raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.") return list(self.edges_by_flag[attr]) - def get_nodes_with_attribute(self, attr, criterion=None, limit_to=None): - """Get the node_ids of all nodes who have an attribute, optionally - limiting to nodes whose value at that attribute meet a given criteria. - - For example, get all nodes that have an attribute called "division", - or where the value for "division" == True. - This also works on location keys, for example to get all nodes with y > 100. - - Args: - attr (str): the name of the attribute to search for in the node metadata - criterion ((any)->bool, optional): A function that takes a value and returns - a boolean. If provided, nodes will only be returned if the value at - node[attr] meets this criterion. Defaults to None. - limit_to (list[hashable], optional): If provided the function will only - return node ids in this list. Will raise KeyError if ids provided here - are not present. - - Returns: - list of hashable: A list of node_ids which have the given attribute - (and optionally have values at that attribute that meet the given criterion, - and/or are in the list of node ids.) - """ - if not limit_to: - limit_to = self.graph.nodes.keys() - - nodes = [] - for node in limit_to: - attributes = self.graph.nodes[node] - if attr in attributes.keys(): - if criterion is None or criterion(attributes[attr]): - nodes.append(node) - return nodes def get_divisions(self): """Get all nodes that have at least two edges pointing to the next time frame diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index f00c01fd..a5d9a1c1 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -154,37 +154,6 @@ def test_get_edges_with_flag(simple_graph): assert simple_graph.get_nodes_with_flag("is_tp") -def test_get_nodes_with_attribute(simple_graph): - assert simple_graph.get_nodes_with_attribute("is_tp_division") == ["1_1"] - assert simple_graph.get_nodes_with_attribute("null") == [] - assert simple_graph.get_nodes_with_attribute( - "is_tp_division", criterion=lambda x: x - ) == ["1_1"] - assert ( - simple_graph.get_nodes_with_attribute( - "is_tp_division", criterion=lambda x: not x - ) - == [] - ) - assert simple_graph.get_nodes_with_attribute("x", criterion=lambda x: x > 1) == [ - "1_3", - "1_4", - ] - assert simple_graph.get_nodes_with_attribute( - "x", criterion=lambda x: x > 1, limit_to=["1_3"] - ) == [ - "1_3", - ] - assert ( - simple_graph.get_nodes_with_attribute( - "x", criterion=lambda x: x > 1, limit_to=["1_0"] - ) - == [] - ) - with pytest.raises(KeyError): - simple_graph.get_nodes_with_attribute("x", limit_to=["5"]) - - def test_get_divisions(complex_graph): assert complex_graph.get_divisions() == ["1_1", "2_2"] From c0d5d04e04a6290c303959a524698db15131a331 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Nov 2023 20:59:48 +0000 Subject: [PATCH 09/25] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/_tracking_graph.py | 1 - src/traccuracy/matchers/_iou.py | 13 +++++++------ src/traccuracy/metrics/_divisions.py | 12 +++--------- tests/track_errors/test_ctc_errors.py | 4 +++- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 496af545..6aadcb03 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -275,7 +275,6 @@ def get_edges_with_flag(self, attr): raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.") return list(self.edges_by_flag[attr]) - def get_divisions(self): """Get all nodes that have at least two edges pointing to the next time frame diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index 63fde0f8..01bcb9eb 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -78,10 +78,10 @@ def match_iou(gt, pred, threshold=0.6): total = len(list(frame_range)) def construct_time_to_seg_id_map(graph): - """ + """ Args: graph(TrackingGraph) - + Returns a dictionary {time: {segmentation_id: node_id}} """ time_to_seg_id_map = {} @@ -89,15 +89,16 @@ def construct_time_to_seg_id_map(graph): time = data[graph.frame_key] seg_id = data[graph.label_key] seg_id_to_node_id_map = time_to_seg_id_map.get(time, {}) - assert seg_id not in seg_id_to_node_id_map,\ - f"Segmentation ID {seg_id} occurred twice in frame {time}." + assert ( + seg_id not in seg_id_to_node_id_map + ), f"Segmentation ID {seg_id} occurred twice in frame {time}." seg_id_to_node_id_map[seg_id] = node_id time_to_seg_id_map[time] = seg_id_to_node_id_map return time_to_seg_id_map - + gt_time_to_seg_id_map = construct_time_to_seg_id_map(gt) pred_time_to_seg_id_map = construct_time_to_seg_id_map(pred) - + for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total): matches = _match_nodes( gt.segmentation[i], pred.segmentation[i], threshold=threshold diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 587a7fc5..4f3f4b6f 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -90,15 +90,9 @@ def compute(self, data: Matched): } def _calculate_metrics(self, g_gt, g_pred): - tp_division_count = len( - g_gt.get_nodes_with_flag(NodeAttr.TP_DIV) - ) - fn_division_count = len( - g_gt.get_nodes_with_flag(NodeAttr.FN_DIV) - ) - fp_division_count = len( - g_pred.get_nodes_with_flag(NodeAttr.FP_DIV) - ) + tp_division_count = len(g_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) + fn_division_count = len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) + fp_division_count = len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) try: recall = tp_division_count / (tp_division_count + fn_division_count) diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 3524ca43..3305ef3b 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -133,4 +133,6 @@ def test_assign_edge_errors_semantics(): get_edge_errors(matched_data) - assert matched_data.pred_graph.get_edge_attribute(("1_2", "1_3"), EdgeAttr.WRONG_SEMANTIC) + assert matched_data.pred_graph.get_edge_attribute( + ("1_2", "1_3"), EdgeAttr.WRONG_SEMANTIC + ) From e1f8545b951c1e4de2ed1dd92439fe0cb7c63993 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 14:35:33 -0500 Subject: [PATCH 10/25] Separate setting flag on one node/edge from all nodes/edges Having one function was clunky (had to cast ids specifically to a list to detect one vs many) and inefficient (couldn't leverage networkx functions for setting all attributes). This commit also renames the functions to the format `set_flag_on_node` or `set_flag_on_all_nodes` for maximum clarity. This naming is clear that we are setting one flag on one or many nodes or edges, and gets a head start on using `flag` instead of `attribute` for our custom flags. --- src/traccuracy/_tracking_graph.py | 124 ++++++++++++++++------- src/traccuracy/track_errors/_ctc.py | 29 +++--- src/traccuracy/track_errors/divisions.py | 19 ++-- tests/test_tracking_graph.py | 30 ++++-- 4 files changed, 138 insertions(+), 64 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 6aadcb03..633bd60a 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import copy import enum import logging import networkx as nx +from typing import Hashable logger = logging.getLogger(__name__) @@ -367,63 +370,114 @@ def get_subgraph(self, nodes): return new_trackgraph - def set_node_attribute(self, ids, attr, value=True): - """Set an attribute flag for a set of nodes specified by - ids. If an id is not found in the graph, a KeyError will be raised. - If the key already exists, the existing value will be overwritten. + def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool=True): + """Set an attribute flag for a single node. + If the id is not found in the graph, a KeyError will be raised. + If the flag already exists, the existing value will be overwritten. Args: - ids (hashable | list[hashable]): The node id or list of node ids - to set the attribute for. - attr (traccuracy.NodeAttr): The node attribute to set. Must be + _id (Hashable): The node id on which to set the flag. + flag (traccuracy.NodeAttr): The node flag to set. Must be of type NodeAttr - you may not not pass strings, even if they are included in the NodeAttr enum values. value (bool, optional): Attributes are flags and can only be set to True or False. Defaults to True. + + Raises: + KeyError if the provided id is not in the graph. + ValueError if the provided flag is not a NodeAttr """ - if not isinstance(ids, list): - ids = [ids] - if not isinstance(attr, NodeAttr): + if not isinstance(flag, NodeAttr): raise ValueError( - f"Provided attribute {attr} is not of type NodeAttr. " + f"Provided flag {flag} is not of type NodeAttr. " "Please use the enum instead of passing string values, " "and add new attributes to the class to avoid key collision." ) - for _id in ids: - self.graph.nodes[_id][attr] = value - if value: - self.nodes_by_flag[attr].add(_id) - else: - self.nodes_by_flag[attr].discard(_id) + self.graph.nodes[_id][flag] = value + if value: + self.nodes_by_flag[flag].add(_id) + else: + self.nodes_by_flag[flag].discard(_id) + + def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool=True): + """Set an attribute flag for all nodes in the graph. + If the flag already exists, the existing values will be overwritten. + + Args: + flag (traccuracy.NodeAttr): The node flag to set. Must be + of type NodeAttr - you may not not pass strings, even if they + are included in the NodeAttr enum values. + value (bool, optional): Flags can only be set to True or False. + Defaults to True. + + Raises: + ValueError if the provided flag is not a NodeAttr. + """ + if not isinstance(flag, NodeAttr): + raise ValueError( + f"Provided flag {flag} is not of type NodeAttr. " + "Please use the enum instead of passing string values, " + "and add new attributes to the class to avoid key collision." + ) + nx.set_node_attributes(self.graph, value, name=flag) + if value: + self.nodes_by_flag[flag].update(self.graph.nodes) + else: + self.nodes_by_flag[flag] = set() - def set_edge_attribute(self, ids, attr, value=True): - """Set an attribute flag for a set of edges specified by - ids. If an edge is not found in the graph, a KeyError will be raised. - If the key already exists, the existing value will be overwritten. + def set_flag_on_edge(self, _id: tuple(Hashable), flag: EdgeAttr, value: bool=True): + """Set an attribute flag for an edge. + If the flag already exists, the existing value will be overwritten. Args: - ids (tuple(hashable) | list[tuple(hashable)]): The edge id or list of edge ids + ids (tuple(Hashable)): The edge id or list of edge ids to set the attribute for. Edge ids are a 2-tuple of node ids. - attr (traccuracy.EdgeAttr): The edge attribute to set. Must be + flag (traccuracy.EdgeAttr): The edge flag to set. Must be of type EdgeAttr - you may not pass strings, even if they are included in the EdgeAttr enum values. - value (bool): Attributes are flags and can only be set to - True or False. Defaults to True. + value (bool): Flags can only be set to True or False. + Defaults to True. + + Raises: + KeyError if edge with _id not in graph. """ - if not isinstance(ids, list): - ids = [ids] - if not isinstance(attr, EdgeAttr): + if not isinstance(flag, EdgeAttr): raise ValueError( - f"Provided attribute {attr} is not of type EdgeAttr. " + f"Provided attribute {flag} is not of type EdgeAttr. " "Please use the enum instead of passing string values, " "and add new attributes to the class to avoid key collision." ) - for _id in ids: - self.graph.edges[_id][attr] = value - if value: - self.edges_by_flag[attr].add(_id) - else: - self.edges_by_flag[attr].discard(_id) + self.graph.edges[_id][flag] = value + if value: + self.edges_by_flag[flag].add(_id) + else: + self.edges_by_flag[flag].discard(_id) + + def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool=True): + """Set an attribute flag for all edges in the graph. + If the flag already exists, the existing values will be overwritten. + + Args: + flag (traccuracy.EdgeAttr): The edge flag to set. Must be + of type EdgeAttr - you may not not pass strings, even if they + are included in the EdgeAttr enum values. + value (bool, optional): Flags can only be set to True or False. + Defaults to True. + + Raises: + ValueError if the provided flag is not an EdgeAttr. + """ + if not isinstance(flag, EdgeAttr): + raise ValueError( + f"Provided flag {flag} is not of type EdgeAttr. " + "Please use the enum instead of passing string values, " + "and add new attributes to the class to avoid key collision." + ) + nx.set_edge_attributes(self.graph, value, name=flag) + if value: + self.edges_by_flag[flag].update(self.graph.edges) + else: + self.edges_by_flag[flag] = set() def get_node_attribute(self, _id, attr): """Get the boolean value of a given attribute for a given node. diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index a9c2b8ab..640e140a 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -40,8 +40,8 @@ def get_vertex_errors(matched_data: Matched): return # will flip this when we come across the vertex in the mapping - comp_graph.set_node_attribute(list(comp_graph.nodes), NodeAttr.FALSE_POS, True) - gt_graph.set_node_attribute(list(gt_graph.nodes), NodeAttr.FALSE_NEG, True) + comp_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, True) + gt_graph.set_flag_on_all_nodes(NodeAttr.FALSE_NEG, True) # we need to know how many computed vertices are "non-split", so we make # a mapping of gt vertices to their matched comp vertices @@ -54,15 +54,16 @@ def get_vertex_errors(matched_data: Matched): gt_ids = dict_mapping[pred_id] if len(gt_ids) == 1: gid = gt_ids[0] - comp_graph.set_node_attribute(pred_id, NodeAttr.TRUE_POS, True) - comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False) - gt_graph.set_node_attribute(gid, NodeAttr.FALSE_NEG, False) + comp_graph.set_flag_on_node(pred_id, NodeAttr.TRUE_POS, True) + comp_graph.set_flag_on_node(pred_id, NodeAttr.FALSE_POS, False) + gt_graph.set_flag_on_node(gid, NodeAttr.FALSE_NEG, False) elif len(gt_ids) > 1: - comp_graph.set_node_attribute(pred_id, NodeAttr.NON_SPLIT, True) - comp_graph.set_node_attribute(pred_id, NodeAttr.FALSE_POS, False) + comp_graph.set_flag_on_node(pred_id, NodeAttr.NON_SPLIT, True) + comp_graph.set_flag_on_node(pred_id, NodeAttr.FALSE_POS, False) # number of split operations that would be required to correct the vertices ns_count += len(gt_ids) - 1 - gt_graph.set_node_attribute(gt_ids, NodeAttr.FALSE_NEG, False) + for gt_id in gt_ids: + gt_graph.set_flag_on_node(gt_id, NodeAttr.FALSE_NEG, False) # Record presence of annotations on the TrackingGraph comp_graph.node_errors = True @@ -94,13 +95,13 @@ def get_edge_errors(matched_data: Matched): for graph in [comp_graph, gt_graph]: for parent in graph.get_divisions(): for daughter in graph.get_succs(parent): - graph.set_edge_attribute( + graph.set_flag_on_edge( (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True ) for merge in graph.get_merges(): for parent in graph.get_preds(merge): - graph.set_edge_attribute( + graph.set_flag_on_edge( (parent, merge), EdgeAttr.INTERTRACK_EDGE, True ) @@ -113,7 +114,7 @@ def get_edge_errors(matched_data: Matched): expected_gt_edge = (source_gt_id, target_gt_id) if expected_gt_edge not in gt_graph.edges: - comp_graph.set_edge_attribute(edge, EdgeAttr.FALSE_POS, True) + comp_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_POS, True) else: # check if semantics are correct is_parent_gt = gt_graph.get_edge_attribute( @@ -123,7 +124,7 @@ def get_edge_errors(matched_data: Matched): edge, EdgeAttr.INTERTRACK_EDGE ) if is_parent_gt != is_parent_comp: - comp_graph.set_edge_attribute(edge, EdgeAttr.WRONG_SEMANTIC, True) + comp_graph.set_flag_on_edge(edge, EdgeAttr.WRONG_SEMANTIC, True) # fn edges - edges in gt_graph that aren't in induced graph for edge in tqdm(gt_graph.edges, "Evaluating FN edges"): @@ -132,7 +133,7 @@ def get_edge_errors(matched_data: Matched): if gt_graph.get_node_attribute( source, NodeAttr.FALSE_NEG ) or gt_graph.get_node_attribute(target, NodeAttr.FALSE_NEG): - gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) + gt_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_NEG, True) continue source_comp_id = gt_comp_mapping[source] @@ -140,7 +141,7 @@ def get_edge_errors(matched_data: Matched): expected_comp_edge = (source_comp_id, target_comp_id) if expected_comp_edge not in induced_graph.edges: - gt_graph.set_edge_attribute(edge, EdgeAttr.FALSE_NEG, True) + gt_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_NEG, True) gt_graph.edge_errors = True comp_graph.edge_errors = True diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index 7ab3fc6b..43626f29 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -62,7 +62,7 @@ def _find_pred_node_matches(pred_node): pred_node = _find_gt_node_matches(gt_node) # No matching node so division missed if pred_node is None: - g_gt.set_node_attribute(gt_node, NodeAttr.FN_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeAttr.FN_DIV, True) # Check if the division has the correct daughters else: succ_gt = g_gt.get_succs(gt_node) @@ -73,18 +73,19 @@ def _find_pred_node_matches(pred_node): # If daughters are same, division is correct if Counter(succ_gt) == Counter(succ_pred): - g_gt.set_node_attribute(gt_node, NodeAttr.TP_DIV, True) - g_pred.set_node_attribute(pred_node, NodeAttr.TP_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeAttr.TP_DIV, True) + g_pred.set_flag_on_node(pred_node, NodeAttr.TP_DIV, True) # If daughters are at all mismatched, division is false negative else: - g_gt.set_node_attribute(gt_node, NodeAttr.FN_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeAttr.FN_DIV, True) # Remove res division to record that we have classified it if pred_node in div_pred: div_pred.remove(pred_node) # Any remaining pred divisions are false positives - g_pred.set_node_attribute(div_pred, NodeAttr.FP_DIV, True) + for fp_div in div_pred: + g_pred.set_flag_on_node(fp_div, NodeAttr.FP_DIV, True) # Set division annotation flag g_gt.division_annotations = True @@ -228,12 +229,12 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): if correct: # Remove error annotations from pred graph - g_pred.set_node_attribute(fp_node, NodeAttr.FP_DIV, False) - g_gt.set_node_attribute(fn_node, NodeAttr.FN_DIV, False) + g_pred.set_flag_on_node(fp_node, NodeAttr.FP_DIV, False) + g_gt.set_flag_on_node(fn_node, NodeAttr.FN_DIV, False) # Add the tp divisions annotations - g_gt.set_node_attribute(fn_node, NodeAttr.TP_DIV, True) - g_pred.set_node_attribute(fp_node, NodeAttr.TP_DIV, True) + g_gt.set_flag_on_node(fn_node, NodeAttr.TP_DIV, True) + g_pred.set_flag_on_node(fp_node, NodeAttr.TP_DIV, True) return new_matched diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index a5d9a1c1..400314ab 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -194,7 +194,7 @@ def test_get_connected_components(complex_graph, nx_comp1, nx_comp2): assert track2.graph.edges == nx_comp2.edges -def test_get_and_set_node_attributes(simple_graph): +def test_get_and_set_flag_on_node(simple_graph): assert simple_graph.nodes()["1_0"] == {"id": "1_0", "t": 0, "y": 1, "x": 1} assert simple_graph.nodes()["1_1"] == { "id": "1_1", @@ -204,7 +204,7 @@ def test_get_and_set_node_attributes(simple_graph): "is_tp_division": True, } - simple_graph.set_node_attribute("1_0", NodeAttr.FALSE_POS, value=False) + simple_graph.set_flag_on_node("1_0", NodeAttr.FALSE_POS, value=False) assert simple_graph.nodes()["1_0"] == { "id": "1_0", "t": 0, @@ -212,18 +212,36 @@ def test_get_and_set_node_attributes(simple_graph): "x": 1, NodeAttr.FALSE_POS: False, } + + simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=False) + for node in simple_graph.nodes: + assert simple_graph.get_node_attribute(node, NodeAttr.FALSE_POS) == False + + simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=True) + for node in simple_graph.nodes: + assert simple_graph.get_node_attribute(node, NodeAttr.FALSE_POS) == True + with pytest.raises(ValueError): - simple_graph.set_node_attribute("1_0", "x", 2) + simple_graph.set_flag_on_node("1_0", "x", 2) -def test_get_and_set_edge_attributes(simple_graph): +def test_get_and_set_flag_on_edge(simple_graph): print(simple_graph.edges()) assert EdgeAttr.TRUE_POS not in simple_graph.edges()[("1_1", "1_3")] - simple_graph.set_edge_attribute(("1_1", "1_3"), EdgeAttr.TRUE_POS, value=False) + simple_graph.set_flag_on_edge(("1_1", "1_3"), EdgeAttr.TRUE_POS, value=False) assert simple_graph.edges()[("1_1", "1_3")][EdgeAttr.TRUE_POS] is False + + simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=False) + for edge in simple_graph.edges: + assert simple_graph.get_edge_attribute(edge, EdgeAttr.FALSE_POS) == False + + simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=True) + for edge in simple_graph.edges: + assert simple_graph.get_edge_attribute(edge, EdgeAttr.FALSE_POS) == True + with pytest.raises(ValueError): - simple_graph.set_edge_attribute(("1_1", "1_3"), "x", 2) + simple_graph.set_flag_on_edge(("1_1", "1_3"), "x", 2) def test_get_tracklets(simple_graph): From 08866965e4172052094ee8e5baf9d462e4ca147e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Nov 2023 19:41:18 +0000 Subject: [PATCH 11/25] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/_tracking_graph.py | 32 +++++++++++++++-------------- src/traccuracy/track_errors/_ctc.py | 4 +--- tests/test_tracking_graph.py | 6 +++--- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 633bd60a..9fa2772b 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -3,9 +3,9 @@ import copy import enum import logging +from typing import Hashable import networkx as nx -from typing import Hashable logger = logging.getLogger(__name__) @@ -370,8 +370,8 @@ def get_subgraph(self, nodes): return new_trackgraph - def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool=True): - """Set an attribute flag for a single node. + def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool = True): + """Set an attribute flag for a single node. If the id is not found in the graph, a KeyError will be raised. If the flag already exists, the existing value will be overwritten. @@ -382,7 +382,7 @@ def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool=True): are included in the NodeAttr enum values. value (bool, optional): Attributes are flags and can only be set to True or False. Defaults to True. - + Raises: KeyError if the provided id is not in the graph. ValueError if the provided flag is not a NodeAttr @@ -398,9 +398,9 @@ def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool=True): self.nodes_by_flag[flag].add(_id) else: self.nodes_by_flag[flag].discard(_id) - - def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool=True): - """Set an attribute flag for all nodes in the graph. + + def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): + """Set an attribute flag for all nodes in the graph. If the flag already exists, the existing values will be overwritten. Args: @@ -409,7 +409,7 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool=True): are included in the NodeAttr enum values. value (bool, optional): Flags can only be set to True or False. Defaults to True. - + Raises: ValueError if the provided flag is not a NodeAttr. """ @@ -425,8 +425,10 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool=True): else: self.nodes_by_flag[flag] = set() - def set_flag_on_edge(self, _id: tuple(Hashable), flag: EdgeAttr, value: bool=True): - """Set an attribute flag for an edge. + def set_flag_on_edge( + self, _id: tuple(Hashable), flag: EdgeAttr, value: bool = True + ): + """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. Args: @@ -435,7 +437,7 @@ def set_flag_on_edge(self, _id: tuple(Hashable), flag: EdgeAttr, value: bool=Tru flag (traccuracy.EdgeAttr): The edge flag to set. Must be of type EdgeAttr - you may not pass strings, even if they are included in the EdgeAttr enum values. - value (bool): Flags can only be set to True or False. + value (bool): Flags can only be set to True or False. Defaults to True. Raises: @@ -452,9 +454,9 @@ def set_flag_on_edge(self, _id: tuple(Hashable), flag: EdgeAttr, value: bool=Tru self.edges_by_flag[flag].add(_id) else: self.edges_by_flag[flag].discard(_id) - - def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool=True): - """Set an attribute flag for all edges in the graph. + + def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool = True): + """Set an attribute flag for all edges in the graph. If the flag already exists, the existing values will be overwritten. Args: @@ -463,7 +465,7 @@ def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool=True): are included in the EdgeAttr enum values. value (bool, optional): Flags can only be set to True or False. Defaults to True. - + Raises: ValueError if the provided flag is not an EdgeAttr. """ diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 640e140a..e4cd4ba6 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -101,9 +101,7 @@ def get_edge_errors(matched_data: Matched): for merge in graph.get_merges(): for parent in graph.get_preds(merge): - graph.set_flag_on_edge( - (parent, merge), EdgeAttr.INTERTRACK_EDGE, True - ) + graph.set_flag_on_edge((parent, merge), EdgeAttr.INTERTRACK_EDGE, True) # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 400314ab..b61ae037 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -216,11 +216,11 @@ def test_get_and_set_flag_on_node(simple_graph): simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=False) for node in simple_graph.nodes: assert simple_graph.get_node_attribute(node, NodeAttr.FALSE_POS) == False - + simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=True) for node in simple_graph.nodes: assert simple_graph.get_node_attribute(node, NodeAttr.FALSE_POS) == True - + with pytest.raises(ValueError): simple_graph.set_flag_on_node("1_0", "x", 2) @@ -235,7 +235,7 @@ def test_get_and_set_flag_on_edge(simple_graph): simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=False) for edge in simple_graph.edges: assert simple_graph.get_edge_attribute(edge, EdgeAttr.FALSE_POS) == False - + simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=True) for edge in simple_graph.edges: assert simple_graph.get_edge_attribute(edge, EdgeAttr.FALSE_POS) == True From 1e9aa7e976a9648adb4c140abe69ace4da5b5537 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 14:59:21 -0500 Subject: [PATCH 12/25] Simplify IOU dictionary naming --- src/traccuracy/matchers/_iou.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index fa3e8da3..f2153055 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -105,14 +105,11 @@ def construct_time_to_seg_id_map(graph): matches = _match_nodes( gt.segmentation[i], pred.segmentation[i], threshold=threshold ) - gt_seg_to_node_map = gt_time_to_seg_id_map[t] - pred_seg_to_node_map = pred_time_to_seg_id_map[t] - # Construct node id tuple for each match for gt_seg_id, pred_seg_id in zip(*matches): # Find node id based on time and segmentation label - gt_node = gt_seg_to_node_map[gt_seg_id] - pred_node = pred_seg_to_node_map[pred_seg_id] + gt_node = gt_time_to_seg_id_map[t][gt_seg_id] + pred_node = pred_time_to_seg_id_map[t][pred_seg_id] mapper.append((gt_node, pred_node)) return mapper From e34b44b8a9bd18980f55b4442253862ce4913bc6 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 15:15:37 -0500 Subject: [PATCH 13/25] Use dict syntax to get node and edge flags We decided that we did not want to assume that missing flags are false, and instead explicitly annotate all flags on all nodes/edges. The only additional functionality the getters provided was to check for missing values and return False. Since we do not want that functionality, we can revert to networkx style access of attributes. --- src/traccuracy/_tracking_graph.py | 49 --------------------------- src/traccuracy/track_errors/_ctc.py | 25 +++++++++----- tests/test_tracking_graph.py | 8 ++--- tests/track_errors/test_ctc_errors.py | 24 ++++++------- tests/track_errors/test_divisions.py | 20 +++++------ 5 files changed, 41 insertions(+), 85 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 9fa2772b..c6831300 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -481,55 +481,6 @@ def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool = True): else: self.edges_by_flag[flag] = set() - def get_node_attribute(self, _id, attr): - """Get the boolean value of a given attribute for a given node. - - Args: - _id (hashable): node id - attr (NodeAttr): Node attribute to fetch the value of - - Raises: - ValueError: if attr is not a NodeAttr - - Returns: - bool: The value of the attribute for that node. If the attribute - is not present on the graph, the value is presumed False. - """ - if not isinstance(attr, NodeAttr): - raise ValueError( - f"Provided attribute {attr} is not of type NodeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." - ) - - if attr not in self.graph.nodes[_id]: - return False - return self.graph.nodes[_id][attr] - - def get_edge_attribute(self, _id, attr): - """Get the boolean value of a given attribute for a given edge. - - Args: - _id (hashable): node id - attr (EdgeAttr): Edge attribute to fetch the value of - - Raises: - ValueError: if attr is not a EdgeAttr - - Returns: - bool: The value of the attribute for that edge. If the attribute - is not present on the graph, the value is presumed False. - """ - if not isinstance(attr, EdgeAttr): - raise ValueError( - f"Provided attribute {attr} is not of type EdgeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." - ) - - if attr not in self.graph.edges[_id]: - return False - return self.graph.edges[_id][attr] def get_tracklets(self, include_division_edges: bool = False): """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index e4cd4ba6..a7dedd0e 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -38,6 +38,9 @@ def get_vertex_errors(matched_data: Matched): if comp_graph.node_errors and gt_graph.node_errors: logger.info("Node errors already calculated. Skipping graph annotation") return + + comp_graph.set_flag_on_all_nodes(NodeAttr.TRUE_POS, False) + comp_graph.set_flag_on_all_nodes(NodeAttr.NON_SPLIT, False) # will flip this when we come across the vertex in the mapping comp_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, True) @@ -88,11 +91,18 @@ def get_edge_errors(matched_data: Matched): comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS) ).graph + comp_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, False) + comp_graph.set_flag_on_all_edges(EdgeAttr.WRONG_SEMANTIC, False) + gt_graph.set_flag_on_all_edges(EdgeAttr.FALSE_NEG, False) + gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph} comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph} # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: + # Set to False by default + graph.set_flag_on_all_edges(EdgeAttr.INTERTRACK_EDGE, False) + for parent in graph.get_divisions(): for daughter in graph.get_succs(parent): graph.set_flag_on_edge( @@ -115,12 +125,8 @@ def get_edge_errors(matched_data: Matched): comp_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_POS, True) else: # check if semantics are correct - is_parent_gt = gt_graph.get_edge_attribute( - expected_gt_edge, EdgeAttr.INTERTRACK_EDGE - ) - is_parent_comp = comp_graph.get_edge_attribute( - edge, EdgeAttr.INTERTRACK_EDGE - ) + is_parent_gt = gt_graph.edges[expected_gt_edge][EdgeAttr.INTERTRACK_EDGE] + is_parent_comp = comp_graph.edges[edge][EdgeAttr.INTERTRACK_EDGE] if is_parent_gt != is_parent_comp: comp_graph.set_flag_on_edge(edge, EdgeAttr.WRONG_SEMANTIC, True) @@ -128,9 +134,10 @@ def get_edge_errors(matched_data: Matched): for edge in tqdm(gt_graph.edges, "Evaluating FN edges"): source, target = edge[0], edge[1] # this edge is adjacent to an edge we didn't detect, so it definitely is an fn - if gt_graph.get_node_attribute( - source, NodeAttr.FALSE_NEG - ) or gt_graph.get_node_attribute(target, NodeAttr.FALSE_NEG): + if ( + gt_graph.nodes[source][NodeAttr.FALSE_NEG] + or gt_graph.nodes[target][NodeAttr.FALSE_NEG] + ): gt_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_NEG, True) continue diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index b61ae037..b4ab2691 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -215,11 +215,11 @@ def test_get_and_set_flag_on_node(simple_graph): simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=False) for node in simple_graph.nodes: - assert simple_graph.get_node_attribute(node, NodeAttr.FALSE_POS) == False + assert simple_graph.nodes[node][NodeAttr.FALSE_POS] == False simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=True) for node in simple_graph.nodes: - assert simple_graph.get_node_attribute(node, NodeAttr.FALSE_POS) == True + assert simple_graph.nodes[node][NodeAttr.FALSE_POS] == True with pytest.raises(ValueError): simple_graph.set_flag_on_node("1_0", "x", 2) @@ -234,11 +234,11 @@ def test_get_and_set_flag_on_edge(simple_graph): simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=False) for edge in simple_graph.edges: - assert simple_graph.get_edge_attribute(edge, EdgeAttr.FALSE_POS) == False + assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] == False simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=True) for edge in simple_graph.edges: - assert simple_graph.get_edge_attribute(edge, EdgeAttr.FALSE_POS) == True + assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] == True with pytest.raises(ValueError): simple_graph.set_flag_on_edge(("1_1", "1_3"), "x", 2) diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index 3305ef3b..e8735c74 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -43,17 +43,17 @@ def test_get_vertex_errors(): assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 - assert matched_data.gt_graph.get_node_attribute(15, NodeAttr.FALSE_NEG) - assert not matched_data.gt_graph.get_node_attribute(17, NodeAttr.FALSE_NEG) + assert matched_data.gt_graph.nodes[15][NodeAttr.FALSE_NEG] + assert not matched_data.gt_graph.nodes[17][NodeAttr.FALSE_NEG] - assert matched_data.pred_graph.get_node_attribute(3, NodeAttr.NON_SPLIT) - assert not matched_data.pred_graph.get_node_attribute(7, NodeAttr.NON_SPLIT) + assert matched_data.pred_graph.nodes[3][NodeAttr.NON_SPLIT] + assert not matched_data.pred_graph.nodes[7][NodeAttr.NON_SPLIT] - assert matched_data.pred_graph.get_node_attribute(7, NodeAttr.TRUE_POS) - assert not matched_data.pred_graph.get_node_attribute(3, NodeAttr.TRUE_POS) + assert matched_data.pred_graph.nodes[7][NodeAttr.TRUE_POS] + assert not matched_data.pred_graph.nodes[3][NodeAttr.TRUE_POS] - assert matched_data.pred_graph.get_node_attribute(10, NodeAttr.FALSE_POS) - assert not matched_data.pred_graph.get_node_attribute(7, NodeAttr.FALSE_POS) + assert matched_data.pred_graph.nodes[10][NodeAttr.FALSE_POS] + assert not matched_data.pred_graph.nodes[7][NodeAttr.FALSE_POS] def test_assign_edge_errors(): @@ -93,8 +93,8 @@ def test_assign_edge_errors(): get_edge_errors(matched_data) - assert matched_data.pred_graph.get_edge_attribute((7, 8), EdgeAttr.FALSE_POS) - assert matched_data.gt_graph.get_edge_attribute((17, 18), EdgeAttr.FALSE_NEG) + assert matched_data.pred_graph.edges[(7, 8)][EdgeAttr.FALSE_POS] + assert matched_data.gt_graph.edges[(17, 18)][EdgeAttr.FALSE_NEG] def test_assign_edge_errors_semantics(): @@ -133,6 +133,4 @@ def test_assign_edge_errors_semantics(): get_edge_errors(matched_data) - assert matched_data.pred_graph.get_edge_attribute( - ("1_2", "1_3"), EdgeAttr.WRONG_SEMANTIC - ) + assert matched_data.pred_graph.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index 6c669f6d..bbcb1955 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -170,8 +170,8 @@ def test_no_change(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.get_node_attribute("1_3", NodeAttr.FP_DIV) is True - assert ng_gt.get_node_attribute("1_1", NodeAttr.FN_DIV) is True + assert ng_pred.nodes["1_3"][NodeAttr.FP_DIV] is True + assert ng_gt.nodes["1_1"][NodeAttr.FN_DIV] is True assert len(ng_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 def test_fn_early(self): @@ -187,10 +187,10 @@ def test_fn_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.get_node_attribute("1_3", NodeAttr.FP_DIV) is False - assert ng_gt.get_node_attribute("1_1", NodeAttr.FN_DIV) is False - assert ng_pred.get_node_attribute("1_3", NodeAttr.TP_DIV) is True - assert ng_gt.get_node_attribute("1_1", NodeAttr.TP_DIV) is True + assert ng_pred.nodes["1_3"][NodeAttr.FP_DIV] is False + assert ng_gt.nodes["1_1"][NodeAttr.FN_DIV] is False + assert ng_pred.nodes["1_3"][NodeAttr.TP_DIV] is True + assert ng_gt.nodes["1_1"][NodeAttr.TP_DIV] is True def test_fp_early(self): # Early division in pred @@ -205,10 +205,10 @@ def test_fp_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.get_node_attribute("1_1", NodeAttr.FP_DIV) is False - assert ng_gt.get_node_attribute("1_3", NodeAttr.FN_DIV) is False - assert ng_pred.get_node_attribute("1_1", NodeAttr.TP_DIV) is True - assert ng_gt.get_node_attribute("1_3", NodeAttr.TP_DIV) is True + assert ng_pred.nodes["1_1"][NodeAttr.FP_DIV] is False + assert ng_gt.nodes["1_3"][NodeAttr.FN_DIV] is False + assert ng_pred.nodes["1_1"][NodeAttr.TP_DIV] is True + assert ng_gt.nodes["1_3"][NodeAttr.TP_DIV] is True def test_evaluate_division_events(): From 74264397f9781769486fe93e4780fcd08cd165e3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Nov 2023 20:17:46 +0000 Subject: [PATCH 14/25] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/_tracking_graph.py | 1 - src/traccuracy/track_errors/_ctc.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index c6831300..0341a5fc 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -481,7 +481,6 @@ def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool = True): else: self.edges_by_flag[flag] = set() - def get_tracklets(self, include_division_edges: bool = False): """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index a7dedd0e..71839998 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -38,7 +38,7 @@ def get_vertex_errors(matched_data: Matched): if comp_graph.node_errors and gt_graph.node_errors: logger.info("Node errors already calculated. Skipping graph annotation") return - + comp_graph.set_flag_on_all_nodes(NodeAttr.TRUE_POS, False) comp_graph.set_flag_on_all_nodes(NodeAttr.NON_SPLIT, False) @@ -100,7 +100,7 @@ def get_edge_errors(matched_data: Matched): # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: - # Set to False by default + # Set to False by default graph.set_flag_on_all_edges(EdgeAttr.INTERTRACK_EDGE, False) for parent in graph.get_divisions(): From 72d143ebd9a1d6d38887952d8519ac18fa818619 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 15:40:33 -0500 Subject: [PATCH 15/25] Fix ruff and mypy complaints --- src/traccuracy/_tracking_graph.py | 4 ++-- tests/test_tracking_graph.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 0341a5fc..14fe2ba7 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -3,7 +3,7 @@ import copy import enum import logging -from typing import Hashable +from typing import Hashable, Tuple import networkx as nx @@ -426,7 +426,7 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): self.nodes_by_flag[flag] = set() def set_flag_on_edge( - self, _id: tuple(Hashable), flag: EdgeAttr, value: bool = True + self, _id: Tuple(Hashable), flag: EdgeAttr, value: bool = True ): """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index b4ab2691..8f3a0907 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -215,11 +215,11 @@ def test_get_and_set_flag_on_node(simple_graph): simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=False) for node in simple_graph.nodes: - assert simple_graph.nodes[node][NodeAttr.FALSE_POS] == False + assert simple_graph.nodes[node][NodeAttr.FALSE_POS] is False simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=True) for node in simple_graph.nodes: - assert simple_graph.nodes[node][NodeAttr.FALSE_POS] == True + assert simple_graph.nodes[node][NodeAttr.FALSE_POS] is True with pytest.raises(ValueError): simple_graph.set_flag_on_node("1_0", "x", 2) @@ -234,11 +234,11 @@ def test_get_and_set_flag_on_edge(simple_graph): simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=False) for edge in simple_graph.edges: - assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] == False + assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] is False simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=True) for edge in simple_graph.edges: - assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] == True + assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] is True with pytest.raises(ValueError): simple_graph.set_flag_on_edge(("1_1", "1_3"), "x", 2) From e88f0ed657e61e0d71fdb366e5f3c6fe6a196138 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Nov 2023 20:40:43 +0000 Subject: [PATCH 16/25] style(pre-commit.ci): auto fixes [...] --- src/traccuracy/_tracking_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 14fe2ba7..0341a5fc 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -3,7 +3,7 @@ import copy import enum import logging -from typing import Hashable, Tuple +from typing import Hashable import networkx as nx @@ -426,7 +426,7 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): self.nodes_by_flag[flag] = set() def set_flag_on_edge( - self, _id: Tuple(Hashable), flag: EdgeAttr, value: bool = True + self, _id: tuple(Hashable), flag: EdgeAttr, value: bool = True ): """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. From 8e2c0e3627dd356f872458e665f2d9d0621d881d Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 15:43:41 -0500 Subject: [PATCH 17/25] Actually fix mypy typing issue --- src/traccuracy/_tracking_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 0341a5fc..aac71157 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -426,13 +426,13 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): self.nodes_by_flag[flag] = set() def set_flag_on_edge( - self, _id: tuple(Hashable), flag: EdgeAttr, value: bool = True + self, _id: tuple[Hashable], flag: EdgeAttr, value: bool = True ): """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. Args: - ids (tuple(Hashable)): The edge id or list of edge ids + ids (tuple[Hashable]): The edge id or list of edge ids to set the attribute for. Edge ids are a 2-tuple of node ids. flag (traccuracy.EdgeAttr): The edge flag to set. Must be of type EdgeAttr - you may not pass strings, even if they are From a5e8d62d1895a4459aaa2bd2946a6278af2e253c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 15:51:58 -0500 Subject: [PATCH 18/25] Actually actually fix mypy typing errors --- src/traccuracy/_tracking_graph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index aac71157..c1586d53 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -426,7 +426,7 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): self.nodes_by_flag[flag] = set() def set_flag_on_edge( - self, _id: tuple[Hashable], flag: EdgeAttr, value: bool = True + self, _id: tuple[Hashable, Hashable], flag: EdgeAttr, value: bool = True ): """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. From 6fc9015fd4aba675de0c79faea051b77ff21228c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 29 Nov 2023 17:44:00 -0500 Subject: [PATCH 19/25] Change from Node/EdgeAttr to Node/EdgeFlag --- src/traccuracy/__init__.py | 4 +- src/traccuracy/_tracking_graph.py | 129 +++++++++++------------ src/traccuracy/metrics/_ctc.py | 14 +-- src/traccuracy/metrics/_divisions.py | 8 +- src/traccuracy/track_errors/_ctc.py | 52 ++++----- src/traccuracy/track_errors/divisions.py | 24 ++--- tests/test_tracking_graph.py | 40 +++---- tests/track_errors/test_ctc_errors.py | 36 +++---- tests/track_errors/test_divisions.py | 58 +++++----- 9 files changed, 179 insertions(+), 186 deletions(-) diff --git a/src/traccuracy/__init__.py b/src/traccuracy/__init__.py index 6fbbeeb8..662c5f31 100644 --- a/src/traccuracy/__init__.py +++ b/src/traccuracy/__init__.py @@ -7,6 +7,6 @@ __version__ = "uninstalled" from ._run_metrics import run_metrics -from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph +from ._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph -__all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"] +__all__ = ["TrackingGraph", "run_metrics", "NodeFlag", "EdgeFlag"] diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index c1586d53..2065a191 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -11,14 +11,12 @@ @enum.unique -class NodeAttr(str, enum.Enum): - """An enum containing all valid attributes that can be used to - annotate the nodes of a TrackingGraph. If new metrics require new - annotations, they should be added here to ensure strings do not overlap and - are standardized. Note that the user specified frame and location +class NodeFlag(str, enum.Enum): + """An enum containing standard flags that are used to annotate the nodes + of a TrackingGraph. Note that the user specified frame and location attributes are also valid node attributes that will be stored on the graph and should not overlap with these values. Additionally, if a graph already - has annotations using these strings before becoming a TrackGraph, + has annotations using these strings before becoming a TrackingGraph, this will likely ruin metrics computation! """ @@ -55,12 +53,10 @@ def has_value(cls, value): @enum.unique -class EdgeAttr(str, enum.Enum): - """An enum containing all valid attributes that can be used to - annotate the edges of a TrackingGraph. If new metrics require new - annotations, they should be added here to ensure strings do not overlap and - are standardized. Additionally, if a graph already - has annotations using these strings before becoming a TrackGraph, +class EdgeFlag(str, enum.Enum): + """An enum containing standard flags that are used to + annotate the edges of a TrackingGraph. If a graph already has + annotations using these strings before becoming a TrackingGraph, this will likely ruin metrics computation! """ @@ -118,14 +114,14 @@ def __init__( forward in time. If the provided graph already has annotations that are strings - included in NodeAttrs or EdgeAttrs, this will likely ruin + included in NodeFlags or EdgeFlags, this will likely ruin metric computation! Args: graph (networkx.DiGraph): A directed graph representing a tracking solution where edges go forward in time. If the graph already - has annotations that are strings included in NodeAttrs or - EdgeAttrs, this will likely ruin metrics computation! + has annotations that are strings included in NodeFlags or + EdgeFlags, this will likely ruin metrics computation! segmentation (numpy-like array, optional): A numpy-like array of segmentations. The location of each node in tracking_graph is assumed to be inside the area of the corresponding segmentation. Defaults to None. @@ -142,20 +138,20 @@ def __init__( Defaults to ('x', 'y'). """ self.segmentation = segmentation - if NodeAttr.has_value(frame_key): + if NodeFlag.has_value(frame_key): raise ValueError( f"Specified frame key {frame_key} is reserved for graph " "annotation. Please change the frame key." ) self.frame_key = frame_key - if label_key is not None and NodeAttr.has_value(label_key): + if label_key is not None and NodeFlag.has_value(label_key): raise ValueError( f"Specified label key {label_key} is reserved for graph" "annotation. Please change the label key." ) self.label_key = label_key for loc_key in location_keys: - if NodeAttr.has_value(loc_key): + if NodeFlag.has_value(loc_key): raise ValueError( f"Specified location key {loc_key} is reserved for graph" "annotation. Please change the location key." @@ -166,8 +162,8 @@ def __init__( # construct dictionaries from attributes to nodes/edges for easy lookup self.nodes_by_frame = {} - self.nodes_by_flag = {flag: set() for flag in NodeAttr} - self.edges_by_flag = {flag: set() for flag in EdgeAttr} + self.nodes_by_flag = {flag: set() for flag in NodeFlag} + self.edges_by_flag = {flag: set() for flag in EdgeFlag} for node, attrs in self.graph.nodes.items(): # check that every node has the time frame and location specified assert ( @@ -185,13 +181,13 @@ def __init__( else: self.nodes_by_frame[frame].add(node) # store node id in nodes_by_flag mapping - for flag in NodeAttr: + for flag in NodeFlag: if flag in attrs and attrs[flag]: self.nodes_by_flag[flag].add(node) # store edge id in edges_by_flag for edge, attrs in self.graph.edges.items(): - for flag in EdgeAttr: + for flag in EdgeFlag: if flag in attrs and attrs[flag]: self.edges_by_flag[flag].add(edge) @@ -251,31 +247,31 @@ def get_location(self, node_id): return [self.graph.nodes[node_id][key] for key in self.location_keys] def get_nodes_with_flag(self, attr): - """Get all nodes with specified NodeAttr set to True. + """Get all nodes with specified NodeFlag set to True. Args: - attr (traccuracy.NodeAttr): the node attribute to query for + attr (traccuracy.NodeFlag): the node attribute to query for Returns: (List(hashable)): A list of node_ids which have the given attribute and the value is True. """ - if not isinstance(attr, NodeAttr): - raise ValueError(f"Function takes NodeAttr arguments, not {type(attr)}.") + if not isinstance(attr, NodeFlag): + raise ValueError(f"Function takes NodeFlag arguments, not {type(attr)}.") return list(self.nodes_by_flag[attr]) def get_edges_with_flag(self, attr): - """Get all edges with specified EdgeAttr set to True. + """Get all edges with specified EdgeFlag set to True. Args: - attr (traccuracy.EdgeAttr): the edge attribute to query for + attr (traccuracy.EdgeFlag): the edge attribute to query for Returns: (List(hashable)): A list of edge ids which have the given attribute and the value is True. """ - if not isinstance(attr, EdgeAttr): - raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.") + if not isinstance(attr, EdgeFlag): + raise ValueError(f"Function takes EdgeFlag arguments, not {type(attr)}.") return list(self.edges_by_flag[attr]) def get_divisions(self): @@ -356,12 +352,12 @@ def get_subgraph(self, nodes): else: del new_trackgraph.nodes_by_frame[frame] - for attr in NodeAttr: - new_trackgraph.nodes_by_flag[attr] = self.nodes_by_flag[attr].intersection( + for flag in NodeFlag: + new_trackgraph.nodes_by_flag[flag] = self.nodes_by_flag[flag].intersection( nodes ) - for attr in EdgeAttr: - new_trackgraph.edges_by_flag[attr] = self.edges_by_flag[attr].intersection( + for flag in EdgeFlag: + new_trackgraph.edges_by_flag[flag] = self.edges_by_flag[flag].intersection( nodes ) @@ -370,28 +366,27 @@ def get_subgraph(self, nodes): return new_trackgraph - def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool = True): + def set_flag_on_node(self, _id: Hashable, flag: NodeFlag, value: bool = True): """Set an attribute flag for a single node. If the id is not found in the graph, a KeyError will be raised. If the flag already exists, the existing value will be overwritten. Args: _id (Hashable): The node id on which to set the flag. - flag (traccuracy.NodeAttr): The node flag to set. Must be - of type NodeAttr - you may not not pass strings, even if they - are included in the NodeAttr enum values. - value (bool, optional): Attributes are flags and can only be set to + flag (traccuracy.NodeFlag): The node flag to set. Must be + of type NodeFlag - you may not not pass strings, even if they + are included in the NodeFlag enum values. + value (bool, optional): Flags can only be set to True or False. Defaults to True. Raises: KeyError if the provided id is not in the graph. - ValueError if the provided flag is not a NodeAttr + ValueError if the provided flag is not a NodeFlag """ - if not isinstance(flag, NodeAttr): + if not isinstance(flag, NodeFlag): raise ValueError( - f"Provided flag {flag} is not of type NodeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." + f"Provided flag {flag} is not of type NodeFlag. " + "Please use the enum instead of passing string values." ) self.graph.nodes[_id][flag] = value if value: @@ -399,25 +394,24 @@ def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool = True): else: self.nodes_by_flag[flag].discard(_id) - def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): + def set_flag_on_all_nodes(self, flag: NodeFlag, value: bool = True): """Set an attribute flag for all nodes in the graph. If the flag already exists, the existing values will be overwritten. Args: - flag (traccuracy.NodeAttr): The node flag to set. Must be - of type NodeAttr - you may not not pass strings, even if they - are included in the NodeAttr enum values. + flag (traccuracy.NodeFlag): The node flag to set. Must be + of type NodeFlag - you may not not pass strings, even if they + are included in the NodeFlag enum values. value (bool, optional): Flags can only be set to True or False. Defaults to True. Raises: - ValueError if the provided flag is not a NodeAttr. + ValueError if the provided flag is not a NodeFlag. """ - if not isinstance(flag, NodeAttr): + if not isinstance(flag, NodeFlag): raise ValueError( - f"Provided flag {flag} is not of type NodeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." + f"Provided flag {flag} is not of type NodeFlag. " + "Please use the enum instead of passing string values." ) nx.set_node_attributes(self.graph, value, name=flag) if value: @@ -426,7 +420,7 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True): self.nodes_by_flag[flag] = set() def set_flag_on_edge( - self, _id: tuple[Hashable, Hashable], flag: EdgeAttr, value: bool = True + self, _id: tuple[Hashable, Hashable], flag: EdgeFlag, value: bool = True ): """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. @@ -434,20 +428,19 @@ def set_flag_on_edge( Args: ids (tuple[Hashable]): The edge id or list of edge ids to set the attribute for. Edge ids are a 2-tuple of node ids. - flag (traccuracy.EdgeAttr): The edge flag to set. Must be - of type EdgeAttr - you may not pass strings, even if they are - included in the EdgeAttr enum values. + flag (traccuracy.EdgeFlag): The edge flag to set. Must be + of type EdgeFlag - you may not pass strings, even if they are + included in the EdgeFlag enum values. value (bool): Flags can only be set to True or False. Defaults to True. Raises: KeyError if edge with _id not in graph. """ - if not isinstance(flag, EdgeAttr): + if not isinstance(flag, EdgeFlag): raise ValueError( - f"Provided attribute {flag} is not of type EdgeAttr. " - "Please use the enum instead of passing string values, " - "and add new attributes to the class to avoid key collision." + f"Provided attribute {flag} is not of type EdgeFlag. " + "Please use the enum instead of passing string values." ) self.graph.edges[_id][flag] = value if value: @@ -455,23 +448,23 @@ def set_flag_on_edge( else: self.edges_by_flag[flag].discard(_id) - def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool = True): + def set_flag_on_all_edges(self, flag: EdgeFlag, value: bool = True): """Set an attribute flag for all edges in the graph. If the flag already exists, the existing values will be overwritten. Args: - flag (traccuracy.EdgeAttr): The edge flag to set. Must be - of type EdgeAttr - you may not not pass strings, even if they - are included in the EdgeAttr enum values. + flag (traccuracy.EdgeFlag): The edge flag to set. Must be + of type EdgeFlag - you may not not pass strings, even if they + are included in the EdgeFlag enum values. value (bool, optional): Flags can only be set to True or False. Defaults to True. Raises: - ValueError if the provided flag is not an EdgeAttr. + ValueError if the provided flag is not an EdgeFlag. """ - if not isinstance(flag, EdgeAttr): + if not isinstance(flag, EdgeFlag): raise ValueError( - f"Provided flag {flag} is not of type EdgeAttr. " + f"Provided flag {flag} is not of type EdgeFlag. " "Please use the enum instead of passing string values, " "and add new attributes to the class to avoid key collision." ) diff --git a/src/traccuracy/metrics/_ctc.py b/src/traccuracy/metrics/_ctc.py index 28b6ad11..d9257b82 100644 --- a/src/traccuracy/metrics/_ctc.py +++ b/src/traccuracy/metrics/_ctc.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING -from traccuracy._tracking_graph import EdgeAttr, NodeAttr +from traccuracy._tracking_graph import EdgeFlag, NodeFlag from traccuracy.track_errors._ctc import evaluate_ctc_events from ._base import Metric @@ -38,14 +38,14 @@ def compute(self, data: Matched): evaluate_ctc_events(data) vertex_error_counts = { - "ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)), - "fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)), - "fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)), + "ns": len(data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)), + "fp": len(data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)), + "fn": len(data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)), } edge_error_counts = { - "ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)), - "fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)), - "fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)), + "ws": len(data.pred_graph.get_edges_with_flag(EdgeFlag.WRONG_SEMANTIC)), + "fp": len(data.pred_graph.get_edges_with_flag(EdgeFlag.FALSE_POS)), + "fn": len(data.gt_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)), } error_sum = get_weighted_error_sum( vertex_error_counts, diff --git a/src/traccuracy/metrics/_divisions.py b/src/traccuracy/metrics/_divisions.py index 4f3f4b6f..a7af81b2 100644 --- a/src/traccuracy/metrics/_divisions.py +++ b/src/traccuracy/metrics/_divisions.py @@ -36,7 +36,7 @@ from typing import TYPE_CHECKING -from traccuracy._tracking_graph import NodeAttr +from traccuracy._tracking_graph import NodeFlag from traccuracy.track_errors.divisions import _evaluate_division_events from ._base import Metric @@ -90,9 +90,9 @@ def compute(self, data: Matched): } def _calculate_metrics(self, g_gt, g_pred): - tp_division_count = len(g_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) - fn_division_count = len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV)) - fp_division_count = len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV)) + tp_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.TP_DIV)) + fn_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.FN_DIV)) + fp_division_count = len(g_pred.get_nodes_with_flag(NodeFlag.FP_DIV)) try: recall = tp_division_count / (tp_division_count + fn_division_count) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index 71839998..cf8a3c6c 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -6,7 +6,7 @@ from tqdm import tqdm -from traccuracy._tracking_graph import EdgeAttr, NodeAttr +from traccuracy._tracking_graph import EdgeFlag, NodeFlag if TYPE_CHECKING: from traccuracy.matchers import Matched @@ -39,12 +39,12 @@ def get_vertex_errors(matched_data: Matched): logger.info("Node errors already calculated. Skipping graph annotation") return - comp_graph.set_flag_on_all_nodes(NodeAttr.TRUE_POS, False) - comp_graph.set_flag_on_all_nodes(NodeAttr.NON_SPLIT, False) + comp_graph.set_flag_on_all_nodes(NodeFlag.TRUE_POS, False) + comp_graph.set_flag_on_all_nodes(NodeFlag.NON_SPLIT, False) # will flip this when we come across the vertex in the mapping - comp_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, True) - gt_graph.set_flag_on_all_nodes(NodeAttr.FALSE_NEG, True) + comp_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, True) + gt_graph.set_flag_on_all_nodes(NodeFlag.FALSE_NEG, True) # we need to know how many computed vertices are "non-split", so we make # a mapping of gt vertices to their matched comp vertices @@ -57,16 +57,16 @@ def get_vertex_errors(matched_data: Matched): gt_ids = dict_mapping[pred_id] if len(gt_ids) == 1: gid = gt_ids[0] - comp_graph.set_flag_on_node(pred_id, NodeAttr.TRUE_POS, True) - comp_graph.set_flag_on_node(pred_id, NodeAttr.FALSE_POS, False) - gt_graph.set_flag_on_node(gid, NodeAttr.FALSE_NEG, False) + comp_graph.set_flag_on_node(pred_id, NodeFlag.TRUE_POS, True) + comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False) + gt_graph.set_flag_on_node(gid, NodeFlag.FALSE_NEG, False) elif len(gt_ids) > 1: - comp_graph.set_flag_on_node(pred_id, NodeAttr.NON_SPLIT, True) - comp_graph.set_flag_on_node(pred_id, NodeAttr.FALSE_POS, False) + comp_graph.set_flag_on_node(pred_id, NodeFlag.NON_SPLIT, True) + comp_graph.set_flag_on_node(pred_id, NodeFlag.FALSE_POS, False) # number of split operations that would be required to correct the vertices ns_count += len(gt_ids) - 1 for gt_id in gt_ids: - gt_graph.set_flag_on_node(gt_id, NodeAttr.FALSE_NEG, False) + gt_graph.set_flag_on_node(gt_id, NodeFlag.FALSE_NEG, False) # Record presence of annotations on the TrackingGraph comp_graph.node_errors = True @@ -88,12 +88,12 @@ def get_edge_errors(matched_data: Matched): get_vertex_errors(matched_data) induced_graph = comp_graph.get_subgraph( - comp_graph.get_nodes_with_flag(NodeAttr.TRUE_POS) + comp_graph.get_nodes_with_flag(NodeFlag.TRUE_POS) ).graph - comp_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, False) - comp_graph.set_flag_on_all_edges(EdgeAttr.WRONG_SEMANTIC, False) - gt_graph.set_flag_on_all_edges(EdgeAttr.FALSE_NEG, False) + comp_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, False) + comp_graph.set_flag_on_all_edges(EdgeFlag.WRONG_SEMANTIC, False) + gt_graph.set_flag_on_all_edges(EdgeFlag.FALSE_NEG, False) gt_comp_mapping = {gt: comp for gt, comp in node_mapping if comp in induced_graph} comp_gt_mapping = {comp: gt for gt, comp in node_mapping if comp in induced_graph} @@ -101,17 +101,17 @@ def get_edge_errors(matched_data: Matched): # intertrack edges = connection between parent and daughter for graph in [comp_graph, gt_graph]: # Set to False by default - graph.set_flag_on_all_edges(EdgeAttr.INTERTRACK_EDGE, False) + graph.set_flag_on_all_edges(EdgeFlag.INTERTRACK_EDGE, False) for parent in graph.get_divisions(): for daughter in graph.get_succs(parent): graph.set_flag_on_edge( - (parent, daughter), EdgeAttr.INTERTRACK_EDGE, True + (parent, daughter), EdgeFlag.INTERTRACK_EDGE, True ) for merge in graph.get_merges(): for parent in graph.get_preds(merge): - graph.set_flag_on_edge((parent, merge), EdgeAttr.INTERTRACK_EDGE, True) + graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True) # fp edges - edges in induced_graph that aren't in gt_graph for edge in tqdm(induced_graph.edges, "Evaluating FP edges"): @@ -122,23 +122,23 @@ def get_edge_errors(matched_data: Matched): expected_gt_edge = (source_gt_id, target_gt_id) if expected_gt_edge not in gt_graph.edges: - comp_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_POS, True) + comp_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_POS, True) else: # check if semantics are correct - is_parent_gt = gt_graph.edges[expected_gt_edge][EdgeAttr.INTERTRACK_EDGE] - is_parent_comp = comp_graph.edges[edge][EdgeAttr.INTERTRACK_EDGE] + is_parent_gt = gt_graph.edges[expected_gt_edge][EdgeFlag.INTERTRACK_EDGE] + is_parent_comp = comp_graph.edges[edge][EdgeFlag.INTERTRACK_EDGE] if is_parent_gt != is_parent_comp: - comp_graph.set_flag_on_edge(edge, EdgeAttr.WRONG_SEMANTIC, True) + comp_graph.set_flag_on_edge(edge, EdgeFlag.WRONG_SEMANTIC, True) # fn edges - edges in gt_graph that aren't in induced graph for edge in tqdm(gt_graph.edges, "Evaluating FN edges"): source, target = edge[0], edge[1] # this edge is adjacent to an edge we didn't detect, so it definitely is an fn if ( - gt_graph.nodes[source][NodeAttr.FALSE_NEG] - or gt_graph.nodes[target][NodeAttr.FALSE_NEG] + gt_graph.nodes[source][NodeFlag.FALSE_NEG] + or gt_graph.nodes[target][NodeFlag.FALSE_NEG] ): - gt_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_NEG, True) + gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) continue source_comp_id = gt_comp_mapping[source] @@ -146,7 +146,7 @@ def get_edge_errors(matched_data: Matched): expected_comp_edge = (source_comp_id, target_comp_id) if expected_comp_edge not in induced_graph.edges: - gt_graph.set_flag_on_edge(edge, EdgeAttr.FALSE_NEG, True) + gt_graph.set_flag_on_edge(edge, EdgeFlag.FALSE_NEG, True) gt_graph.edge_errors = True comp_graph.edge_errors = True diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index 43626f29..cb2b5e50 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -6,7 +6,7 @@ from collections import Counter from typing import TYPE_CHECKING -from traccuracy._tracking_graph import NodeAttr +from traccuracy._tracking_graph import NodeFlag from traccuracy._utils import find_gt_node_matches, find_pred_node_matches if TYPE_CHECKING: @@ -62,7 +62,7 @@ def _find_pred_node_matches(pred_node): pred_node = _find_gt_node_matches(gt_node) # No matching node so division missed if pred_node is None: - g_gt.set_flag_on_node(gt_node, NodeAttr.FN_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeFlag.FN_DIV, True) # Check if the division has the correct daughters else: succ_gt = g_gt.get_succs(gt_node) @@ -73,11 +73,11 @@ def _find_pred_node_matches(pred_node): # If daughters are same, division is correct if Counter(succ_gt) == Counter(succ_pred): - g_gt.set_flag_on_node(gt_node, NodeAttr.TP_DIV, True) - g_pred.set_flag_on_node(pred_node, NodeAttr.TP_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeFlag.TP_DIV, True) + g_pred.set_flag_on_node(pred_node, NodeFlag.TP_DIV, True) # If daughters are at all mismatched, division is false negative else: - g_gt.set_flag_on_node(gt_node, NodeAttr.FN_DIV, True) + g_gt.set_flag_on_node(gt_node, NodeFlag.FN_DIV, True) # Remove res division to record that we have classified it if pred_node in div_pred: @@ -85,7 +85,7 @@ def _find_pred_node_matches(pred_node): # Any remaining pred divisions are false positives for fp_div in div_pred: - g_pred.set_flag_on_node(fp_div, NodeAttr.FP_DIV, True) + g_pred.set_flag_on_node(fp_div, NodeFlag.FP_DIV, True) # Set division annotation flag g_gt.division_annotations = True @@ -171,8 +171,8 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): ): raise ValueError("Mapping must be one-to-one") - fp_divs = g_pred.get_nodes_with_flag(NodeAttr.FP_DIV) - fn_divs = g_gt.get_nodes_with_flag(NodeAttr.FN_DIV) + fp_divs = g_pred.get_nodes_with_flag(NodeFlag.FP_DIV) + fn_divs = g_gt.get_nodes_with_flag(NodeFlag.FN_DIV) # Compare all pairs of fp and fn for fp_node, fn_node in itertools.product(fp_divs, fn_divs): @@ -229,12 +229,12 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): if correct: # Remove error annotations from pred graph - g_pred.set_flag_on_node(fp_node, NodeAttr.FP_DIV, False) - g_gt.set_flag_on_node(fn_node, NodeAttr.FN_DIV, False) + g_pred.set_flag_on_node(fp_node, NodeFlag.FP_DIV, False) + g_gt.set_flag_on_node(fn_node, NodeFlag.FN_DIV, False) # Add the tp divisions annotations - g_gt.set_flag_on_node(fn_node, NodeAttr.TP_DIV, True) - g_pred.set_flag_on_node(fp_node, NodeAttr.TP_DIV, True) + g_gt.set_flag_on_node(fn_node, NodeFlag.TP_DIV, True) + g_pred.set_flag_on_node(fp_node, NodeFlag.TP_DIV, True) return new_matched diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 8f3a0907..09100561 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -2,7 +2,7 @@ import networkx as nx import pytest -from traccuracy import EdgeAttr, NodeAttr, TrackingGraph +from traccuracy import EdgeFlag, NodeFlag, TrackingGraph @pytest.fixture @@ -127,11 +127,11 @@ def test_constructor(nx_comp1): with pytest.raises(AssertionError): TrackingGraph(nx_comp1, frame_key="f") with pytest.raises(ValueError): - TrackingGraph(nx_comp1, frame_key=NodeAttr.FALSE_NEG) + TrackingGraph(nx_comp1, frame_key=NodeFlag.FALSE_NEG) with pytest.raises(AssertionError): TrackingGraph(nx_comp1, location_keys=["x", "y", "z"]) with pytest.raises(ValueError): - TrackingGraph(nx_comp1, location_keys=["x", NodeAttr.FALSE_NEG]) + TrackingGraph(nx_comp1, location_keys=["x", NodeFlag.FALSE_NEG]) def test_get_cells_by_frame(simple_graph): @@ -141,15 +141,15 @@ def test_get_cells_by_frame(simple_graph): def test_get_nodes_with_flag(simple_graph): - assert simple_graph.get_nodes_with_flag(NodeAttr.TP_DIV) == ["1_1"] - assert simple_graph.get_nodes_with_flag(NodeAttr.FP_DIV) == [] + assert simple_graph.get_nodes_with_flag(NodeFlag.TP_DIV) == ["1_1"] + assert simple_graph.get_nodes_with_flag(NodeFlag.FP_DIV) == [] with pytest.raises(ValueError): assert simple_graph.get_nodes_with_flag("is_tp_division") def test_get_edges_with_flag(simple_graph): - assert simple_graph.get_edges_with_flag(EdgeAttr.TRUE_POS) == [("1_0", "1_1")] - assert simple_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG) == [] + assert simple_graph.get_edges_with_flag(EdgeFlag.TRUE_POS) == [("1_0", "1_1")] + assert simple_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG) == [] with pytest.raises(ValueError): assert simple_graph.get_nodes_with_flag("is_tp") @@ -204,22 +204,22 @@ def test_get_and_set_flag_on_node(simple_graph): "is_tp_division": True, } - simple_graph.set_flag_on_node("1_0", NodeAttr.FALSE_POS, value=False) + simple_graph.set_flag_on_node("1_0", NodeFlag.FALSE_POS, value=False) assert simple_graph.nodes()["1_0"] == { "id": "1_0", "t": 0, "y": 1, "x": 1, - NodeAttr.FALSE_POS: False, + NodeFlag.FALSE_POS: False, } - simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=False) + simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=False) for node in simple_graph.nodes: - assert simple_graph.nodes[node][NodeAttr.FALSE_POS] is False + assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is False - simple_graph.set_flag_on_all_nodes(NodeAttr.FALSE_POS, value=True) + simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=True) for node in simple_graph.nodes: - assert simple_graph.nodes[node][NodeAttr.FALSE_POS] is True + assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is True with pytest.raises(ValueError): simple_graph.set_flag_on_node("1_0", "x", 2) @@ -227,18 +227,18 @@ def test_get_and_set_flag_on_node(simple_graph): def test_get_and_set_flag_on_edge(simple_graph): print(simple_graph.edges()) - assert EdgeAttr.TRUE_POS not in simple_graph.edges()[("1_1", "1_3")] + assert EdgeFlag.TRUE_POS not in simple_graph.edges()[("1_1", "1_3")] - simple_graph.set_flag_on_edge(("1_1", "1_3"), EdgeAttr.TRUE_POS, value=False) - assert simple_graph.edges()[("1_1", "1_3")][EdgeAttr.TRUE_POS] is False + simple_graph.set_flag_on_edge(("1_1", "1_3"), EdgeFlag.TRUE_POS, value=False) + assert simple_graph.edges()[("1_1", "1_3")][EdgeFlag.TRUE_POS] is False - simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=False) + simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=False) for edge in simple_graph.edges: - assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] is False + assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is False - simple_graph.set_flag_on_all_edges(EdgeAttr.FALSE_POS, value=True) + simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=True) for edge in simple_graph.edges: - assert simple_graph.edges[edge][EdgeAttr.FALSE_POS] is True + assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is True with pytest.raises(ValueError): simple_graph.set_flag_on_edge(("1_1", "1_3"), "x", 2) diff --git a/tests/track_errors/test_ctc_errors.py b/tests/track_errors/test_ctc_errors.py index e8735c74..a66517d0 100644 --- a/tests/track_errors/test_ctc_errors.py +++ b/tests/track_errors/test_ctc_errors.py @@ -1,6 +1,6 @@ import networkx as nx import numpy as np -from traccuracy._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph +from traccuracy._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph from traccuracy.matchers import Matched from traccuracy.track_errors._ctc import get_edge_errors, get_vertex_errors @@ -38,22 +38,22 @@ def test_get_vertex_errors(): get_vertex_errors(matched_data) - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)) == 1 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.TRUE_POS)) == 3 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)) == 2 - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)) == 1 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.TRUE_POS)) == 3 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)) == 2 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)) == 3 - assert matched_data.gt_graph.nodes[15][NodeAttr.FALSE_NEG] - assert not matched_data.gt_graph.nodes[17][NodeAttr.FALSE_NEG] + assert matched_data.gt_graph.nodes[15][NodeFlag.FALSE_NEG] + assert not matched_data.gt_graph.nodes[17][NodeFlag.FALSE_NEG] - assert matched_data.pred_graph.nodes[3][NodeAttr.NON_SPLIT] - assert not matched_data.pred_graph.nodes[7][NodeAttr.NON_SPLIT] + assert matched_data.pred_graph.nodes[3][NodeFlag.NON_SPLIT] + assert not matched_data.pred_graph.nodes[7][NodeFlag.NON_SPLIT] - assert matched_data.pred_graph.nodes[7][NodeAttr.TRUE_POS] - assert not matched_data.pred_graph.nodes[3][NodeAttr.TRUE_POS] + assert matched_data.pred_graph.nodes[7][NodeFlag.TRUE_POS] + assert not matched_data.pred_graph.nodes[3][NodeFlag.TRUE_POS] - assert matched_data.pred_graph.nodes[10][NodeAttr.FALSE_POS] - assert not matched_data.pred_graph.nodes[7][NodeAttr.FALSE_POS] + assert matched_data.pred_graph.nodes[10][NodeFlag.FALSE_POS] + assert not matched_data.pred_graph.nodes[7][NodeFlag.FALSE_POS] def test_assign_edge_errors(): @@ -72,7 +72,7 @@ def test_assign_edge_errors(): comp_g = nx.DiGraph() comp_g.add_nodes_from(comp_ids) comp_g.add_edges_from(comp_edges) - nx.set_node_attributes(comp_g, True, NodeAttr.TRUE_POS) + nx.set_node_attributes(comp_g, True, NodeFlag.TRUE_POS) nx.set_node_attributes( comp_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in comp_ids}, @@ -83,7 +83,7 @@ def test_assign_edge_errors(): gt_g = nx.DiGraph() gt_g.add_nodes_from(gt_ids) gt_g.add_edges_from(gt_edges) - nx.set_node_attributes(gt_g, False, NodeAttr.FALSE_NEG) + nx.set_node_attributes(gt_g, False, NodeFlag.FALSE_NEG) nx.set_node_attributes( gt_g, {idx: {"t": 0, "segmentation_id": 1, "y": 0, "x": 0} for idx in gt_ids} ) @@ -93,8 +93,8 @@ def test_assign_edge_errors(): get_edge_errors(matched_data) - assert matched_data.pred_graph.edges[(7, 8)][EdgeAttr.FALSE_POS] - assert matched_data.gt_graph.edges[(17, 18)][EdgeAttr.FALSE_NEG] + assert matched_data.pred_graph.edges[(7, 8)][EdgeFlag.FALSE_POS] + assert matched_data.gt_graph.edges[(17, 18)][EdgeFlag.FALSE_NEG] def test_assign_edge_errors_semantics(): @@ -133,4 +133,4 @@ def test_assign_edge_errors_semantics(): get_edge_errors(matched_data) - assert matched_data.pred_graph.edges[("1_2", "1_3")][EdgeAttr.WRONG_SEMANTIC] + assert matched_data.pred_graph.edges[("1_2", "1_3")][EdgeFlag.WRONG_SEMANTIC] diff --git a/tests/track_errors/test_divisions.py b/tests/track_errors/test_divisions.py index bbcb1955..609c8ed5 100644 --- a/tests/track_errors/test_divisions.py +++ b/tests/track_errors/test_divisions.py @@ -1,7 +1,7 @@ import networkx as nx import numpy as np import pytest -from traccuracy import NodeAttr, TrackingGraph +from traccuracy import NodeFlag, TrackingGraph from traccuracy.matchers import Matched from traccuracy.track_errors.divisions import ( _classify_divisions, @@ -51,10 +51,10 @@ def test_classify_divisions_tp(g): # Test true positive _classify_divisions(matched_data) - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes["2_2"] - assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FN_DIV)) == 0 + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == 0 + assert NodeFlag.TP_DIV in matched_data.gt_graph.nodes["2_2"] + assert NodeFlag.TP_DIV in matched_data.pred_graph.nodes["2_2"] # Check division flag assert matched_data.gt_graph.division_annotations @@ -80,10 +80,10 @@ def test_classify_divisions_fp(g): _classify_divisions(matched_data) - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.FN_DIV)) == 0 - assert NodeAttr.FP_DIV in matched_data.pred_graph.nodes["1_2"] - assert NodeAttr.TP_DIV in matched_data.gt_graph.nodes["2_2"] - assert NodeAttr.TP_DIV in matched_data.pred_graph.nodes["2_2"] + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.FN_DIV)) == 0 + assert NodeFlag.FP_DIV in matched_data.pred_graph.nodes["1_2"] + assert NodeFlag.TP_DIV in matched_data.gt_graph.nodes["2_2"] + assert NodeFlag.TP_DIV in matched_data.pred_graph.nodes["2_2"] def test_classify_divisions_fn(g): @@ -100,9 +100,9 @@ def test_classify_divisions_fn(g): _classify_divisions(matched_data) - assert len(matched_data.pred_graph.get_nodes_with_flag(NodeAttr.FP_DIV)) == 0 - assert len(matched_data.gt_graph.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 - assert NodeAttr.FN_DIV in matched_data.gt_graph.nodes["2_2"] + assert len(matched_data.pred_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == 0 + assert len(matched_data.gt_graph.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0 + assert NodeFlag.FN_DIV in matched_data.gt_graph.nodes["2_2"] @pytest.fixture @@ -160,8 +160,8 @@ class Test_correct_shifted_divisions: def test_no_change(self): # Early division in gt g_pred, g_gt, mapper = get_division_graphs() - g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True - g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True + g_gt.nodes["1_1"][NodeFlag.FN_DIV] = True + g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) @@ -170,15 +170,15 @@ def test_no_change(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes["1_3"][NodeAttr.FP_DIV] is True - assert ng_gt.nodes["1_1"][NodeAttr.FN_DIV] is True - assert len(ng_gt.get_nodes_with_flag(NodeAttr.TP_DIV)) == 0 + assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is True + assert ng_gt.nodes["1_1"][NodeFlag.FN_DIV] is True + assert len(ng_gt.get_nodes_with_flag(NodeFlag.TP_DIV)) == 0 def test_fn_early(self): # Early division in gt g_pred, g_gt, mapper = get_division_graphs() - g_gt.nodes["1_1"][NodeAttr.FN_DIV] = True - g_pred.nodes["1_3"][NodeAttr.FP_DIV] = True + g_gt.nodes["1_1"][NodeFlag.FN_DIV] = True + g_pred.nodes["1_3"][NodeFlag.FP_DIV] = True matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) @@ -187,16 +187,16 @@ def test_fn_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes["1_3"][NodeAttr.FP_DIV] is False - assert ng_gt.nodes["1_1"][NodeAttr.FN_DIV] is False - assert ng_pred.nodes["1_3"][NodeAttr.TP_DIV] is True - assert ng_gt.nodes["1_1"][NodeAttr.TP_DIV] is True + assert ng_pred.nodes["1_3"][NodeFlag.FP_DIV] is False + assert ng_gt.nodes["1_1"][NodeFlag.FN_DIV] is False + assert ng_pred.nodes["1_3"][NodeFlag.TP_DIV] is True + assert ng_gt.nodes["1_1"][NodeFlag.TP_DIV] is True def test_fp_early(self): # Early division in pred g_gt, g_pred, mapper = get_division_graphs() - g_pred.nodes["1_1"][NodeAttr.FP_DIV] = True - g_gt.nodes["1_3"][NodeAttr.FN_DIV] = True + g_pred.nodes["1_1"][NodeFlag.FP_DIV] = True + g_gt.nodes["1_3"][NodeFlag.FN_DIV] = True matched_data = Matched(TrackingGraph(g_gt), TrackingGraph(g_pred), mapper) @@ -205,10 +205,10 @@ def test_fp_early(self): ng_pred = new_matched.pred_graph ng_gt = new_matched.gt_graph - assert ng_pred.nodes["1_1"][NodeAttr.FP_DIV] is False - assert ng_gt.nodes["1_3"][NodeAttr.FN_DIV] is False - assert ng_pred.nodes["1_1"][NodeAttr.TP_DIV] is True - assert ng_gt.nodes["1_3"][NodeAttr.TP_DIV] is True + assert ng_pred.nodes["1_1"][NodeFlag.FP_DIV] is False + assert ng_gt.nodes["1_3"][NodeFlag.FN_DIV] is False + assert ng_pred.nodes["1_1"][NodeFlag.TP_DIV] is True + assert ng_gt.nodes["1_3"][NodeFlag.TP_DIV] is True def test_evaluate_division_events(): From 99933edf49344ca0f88d2e8988aabbc51bbefeec Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 30 Nov 2023 15:28:45 -0500 Subject: [PATCH 20/25] Add typing annotations to TrackingGraph --- src/traccuracy/_tracking_graph.py | 112 +++++++++++++++++------------- src/traccuracy/matchers/_ctc.py | 5 +- 2 files changed, 66 insertions(+), 51 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index 2065a191..c66a2125 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -3,10 +3,14 @@ import copy import enum import logging -from typing import Hashable +from typing import TYPE_CHECKING, Hashable import networkx as nx +if TYPE_CHECKING: + import numpy as np + from networkx.classes.reportviews import NodeView, OutEdgeView + logger = logging.getLogger(__name__) @@ -104,11 +108,11 @@ class TrackingGraph: def __init__( self, - graph, - segmentation=None, - frame_key="t", - label_key="segmentation_id", - location_keys=("x", "y"), + graph: nx.DiGraph, + segmentation: np.ndarray | None = None, + frame_key: str = "t", + label_key: str = "segmentation_id", + location_keys: tuple[str, ...] = ("x", "y"), ): """A directed graph representing a tracking solution where edges go forward in time. @@ -161,9 +165,13 @@ def __init__( self.graph = graph # construct dictionaries from attributes to nodes/edges for easy lookup - self.nodes_by_frame = {} - self.nodes_by_flag = {flag: set() for flag in NodeFlag} - self.edges_by_flag = {flag: set() for flag in EdgeFlag} + self.nodes_by_frame: dict[int, set[Hashable]] = {} + self.nodes_by_flag: dict[NodeFlag, set[Hashable]] = { + flag: set() for flag in NodeFlag + } + self.edges_by_flag: dict[EdgeFlag, set[tuple[Hashable, Hashable]]] = { + flag: set() for flag in EdgeFlag + } for node, attrs in self.graph.nodes.items(): # check that every node has the time frame and location specified assert ( @@ -181,15 +189,15 @@ def __init__( else: self.nodes_by_frame[frame].add(node) # store node id in nodes_by_flag mapping - for flag in NodeFlag: - if flag in attrs and attrs[flag]: - self.nodes_by_flag[flag].add(node) + for node_flag in NodeFlag: + if node_flag in attrs and attrs[node_flag]: + self.nodes_by_flag[node_flag].add(node) # store edge id in edges_by_flag for edge, attrs in self.graph.edges.items(): - for flag in EdgeFlag: - if flag in attrs and attrs[flag]: - self.edges_by_flag[flag].add(edge) + for edge_flag in EdgeFlag: + if edge_flag in attrs and attrs[edge_flag]: + self.edges_by_flag[edge_flag].add(edge) # Store first and last frames for reference self.start_frame = min(self.nodes_by_frame.keys()) @@ -201,7 +209,7 @@ def __init__( self.edge_errors = False @property - def nodes(self): + def nodes(self) -> NodeView: """Get all the nodes in the graph, along with their attributes. Returns: @@ -210,7 +218,7 @@ def nodes(self): return self.graph.nodes @property - def edges(self): + def edges(self) -> OutEdgeView: """Get all the edges in the graph, along with their attributes. Returns: @@ -219,7 +227,7 @@ def edges(self): """ return self.graph.edges - def get_nodes_in_frame(self, frame): + def get_nodes_in_frame(self, frame: int) -> list[Hashable]: """Get the node ids of all nodes in the given frame. Args: @@ -235,7 +243,7 @@ def get_nodes_in_frame(self, frame): else: return [] - def get_location(self, node_id): + def get_location(self, node_id: Hashable) -> list[float]: """Get the spatial location of the node with node_id using self.location_keys. Args: @@ -246,35 +254,35 @@ def get_location(self, node_id): """ return [self.graph.nodes[node_id][key] for key in self.location_keys] - def get_nodes_with_flag(self, attr): + def get_nodes_with_flag(self, flag: NodeFlag) -> list[Hashable]: """Get all nodes with specified NodeFlag set to True. Args: - attr (traccuracy.NodeFlag): the node attribute to query for + flag (traccuracy.NodeFlag): the node flag to query for Returns: - (List(hashable)): A list of node_ids which have the given attribute + (List(hashable)): A list of node_ids which have the given flag and the value is True. """ - if not isinstance(attr, NodeFlag): - raise ValueError(f"Function takes NodeFlag arguments, not {type(attr)}.") - return list(self.nodes_by_flag[attr]) + if not isinstance(flag, NodeFlag): + raise ValueError(f"Function takes NodeFlag arguments, not {type(flag)}.") + return list(self.nodes_by_flag[flag]) - def get_edges_with_flag(self, attr): + def get_edges_with_flag(self, flag: EdgeFlag) -> list[tuple[Hashable, Hashable]]: """Get all edges with specified EdgeFlag set to True. Args: - attr (traccuracy.EdgeFlag): the edge attribute to query for + flag (traccuracy.EdgeFlag): the edge flag to query for Returns: - (List(hashable)): A list of edge ids which have the given attribute + (List(hashable)): A list of edge ids which have the given flag and the value is True. """ - if not isinstance(attr, EdgeFlag): - raise ValueError(f"Function takes EdgeFlag arguments, not {type(attr)}.") - return list(self.edges_by_flag[attr]) + if not isinstance(flag, EdgeFlag): + raise ValueError(f"Function takes EdgeFlag arguments, not {type(flag)}.") + return list(self.edges_by_flag[flag]) - def get_divisions(self): + def get_divisions(self) -> list[Hashable]: """Get all nodes that have at least two edges pointing to the next time frame Returns: @@ -282,7 +290,7 @@ def get_divisions(self): """ return [node for node, degree in self.graph.out_degree() if degree >= 2] - def get_merges(self): + def get_merges(self) -> list[Hashable]: """Get all nodes that have at least two incoming edges from the previous time frame Returns: @@ -290,7 +298,7 @@ def get_merges(self): """ return [node for node, degree in self.graph.in_degree() if degree >= 2] - def get_preds(self, node): + def get_preds(self, node: Hashable) -> list[Hashable]: """Get all predecessors of the given node. A predecessor node is any node from a previous time point that has an edge to @@ -306,7 +314,7 @@ def get_preds(self, node): """ return [pred for pred, _ in self.graph.in_edges(node)] - def get_succs(self, node): + def get_succs(self, node: Hashable) -> list[Hashable]: """Get all successor nodes of the given node. A successor node is any node from a later time point that has an edge @@ -322,7 +330,7 @@ def get_succs(self, node): """ return [succ for _, succ in self.graph.out_edges(node)] - def get_connected_components(self): + def get_connected_components(self) -> list[TrackingGraph]: """Get a list of TrackingGraphs, each corresponding to one track (i.e., a connected component in the track graph). @@ -335,7 +343,7 @@ def get_connected_components(self): return [self.get_subgraph(g) for g in nx.weakly_connected_components(graph)] - def get_subgraph(self, nodes): + def get_subgraph(self, nodes: list[Hashable]) -> TrackingGraph: """Returns a new TrackingGraph with the subgraph defined by the list of nodes Args: @@ -352,21 +360,23 @@ def get_subgraph(self, nodes): else: del new_trackgraph.nodes_by_frame[frame] - for flag in NodeFlag: - new_trackgraph.nodes_by_flag[flag] = self.nodes_by_flag[flag].intersection( - nodes - ) - for flag in EdgeFlag: - new_trackgraph.edges_by_flag[flag] = self.edges_by_flag[flag].intersection( - nodes - ) + for node_flag in NodeFlag: + new_trackgraph.nodes_by_flag[node_flag] = self.nodes_by_flag[ + node_flag + ].intersection(nodes) + for edge_flag in EdgeFlag: + new_trackgraph.edges_by_flag[edge_flag] = self.edges_by_flag[ + edge_flag + ].intersection(nodes) new_trackgraph.start_frame = min(new_trackgraph.nodes_by_frame.keys()) new_trackgraph.end_frame = max(new_trackgraph.nodes_by_frame.keys()) + 1 return new_trackgraph - def set_flag_on_node(self, _id: Hashable, flag: NodeFlag, value: bool = True): + def set_flag_on_node( + self, _id: Hashable, flag: NodeFlag, value: bool = True + ) -> None: """Set an attribute flag for a single node. If the id is not found in the graph, a KeyError will be raised. If the flag already exists, the existing value will be overwritten. @@ -394,7 +404,7 @@ def set_flag_on_node(self, _id: Hashable, flag: NodeFlag, value: bool = True): else: self.nodes_by_flag[flag].discard(_id) - def set_flag_on_all_nodes(self, flag: NodeFlag, value: bool = True): + def set_flag_on_all_nodes(self, flag: NodeFlag, value: bool = True) -> None: """Set an attribute flag for all nodes in the graph. If the flag already exists, the existing values will be overwritten. @@ -421,7 +431,7 @@ def set_flag_on_all_nodes(self, flag: NodeFlag, value: bool = True): def set_flag_on_edge( self, _id: tuple[Hashable, Hashable], flag: EdgeFlag, value: bool = True - ): + ) -> None: """Set an attribute flag for an edge. If the flag already exists, the existing value will be overwritten. @@ -448,7 +458,7 @@ def set_flag_on_edge( else: self.edges_by_flag[flag].discard(_id) - def set_flag_on_all_edges(self, flag: EdgeFlag, value: bool = True): + def set_flag_on_all_edges(self, flag: EdgeFlag, value: bool = True) -> None: """Set an attribute flag for all edges in the graph. If the flag already exists, the existing values will be overwritten. @@ -474,7 +484,9 @@ def set_flag_on_all_edges(self, flag: EdgeFlag, value: bool = True): else: self.edges_by_flag[flag] = set() - def get_tracklets(self, include_division_edges: bool = False): + def get_tracklets( + self, include_division_edges: bool = False + ) -> list[TrackingGraph]: """Gets a list of new TrackingGraph objects containing all tracklets of the current graph. Tracklet is defined as all connected components between divisions (daughter to next diff --git a/src/traccuracy/matchers/_ctc.py b/src/traccuracy/matchers/_ctc.py index 23f02be8..43c33f8b 100644 --- a/src/traccuracy/matchers/_ctc.py +++ b/src/traccuracy/matchers/_ctc.py @@ -37,7 +37,7 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): traccuracy.matchers.Matched: Matched data object containing the CTC mapping Raises: - ValueError: GT and pred segmentations must be the same shape + ValueError: if GT and pred segmentations are None or are not the same shape """ gt = gt_graph pred = pred_graph @@ -46,6 +46,9 @@ def _compute_mapping(self, gt_graph: TrackingGraph, pred_graph: TrackingGraph): G_gt, mask_gt = gt, gt.segmentation G_pred, mask_pred = pred, pred.segmentation + if mask_gt is None or mask_pred is None: + raise ValueError("Segmentation is None, cannot perform matching") + if mask_gt.shape != mask_pred.shape: raise ValueError("Segmentation shapes must match between gt and pred") From b68e3a863b881d97f820320c88035b75c8f9ce1c Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 5 Dec 2023 16:07:53 -0500 Subject: [PATCH 21/25] Return set from TrackingGraph node/edge_by_flag I tried to annotate the return type as a generic Iterable, to match the networkx conventions, but we do call `len` on it, so I stuck to the specific set type annotation. --- src/traccuracy/_tracking_graph.py | 26 +++++++++++++------------- tests/test_tracking_graph.py | 16 ++++++++++------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index c66a2125..96af931a 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -3,7 +3,7 @@ import copy import enum import logging -from typing import TYPE_CHECKING, Hashable +from typing import TYPE_CHECKING, Hashable, Iterable import networkx as nx @@ -227,21 +227,21 @@ def edges(self) -> OutEdgeView: """ return self.graph.edges - def get_nodes_in_frame(self, frame: int) -> list[Hashable]: + def get_nodes_in_frame(self, frame: int) -> set[Hashable]: """Get the node ids of all nodes in the given frame. Args: frame (int): The frame to return all node ids for. If the provided frame is outside of the range - (self.start_frame, self.end_frame), returns an empty list. + (self.start_frame, self.end_frame), returns an empty iterable. Returns: - list of node_ids: A list of node ids for all nodes in frame. + Iterable[Hashable]: An iterable of node ids for all nodes in frame. """ if frame in self.nodes_by_frame.keys(): - return list(self.nodes_by_frame[frame]) + return self.nodes_by_frame[frame] else: - return [] + return set() def get_location(self, node_id: Hashable) -> list[float]: """Get the spatial location of the node with node_id using self.location_keys. @@ -254,33 +254,33 @@ def get_location(self, node_id: Hashable) -> list[float]: """ return [self.graph.nodes[node_id][key] for key in self.location_keys] - def get_nodes_with_flag(self, flag: NodeFlag) -> list[Hashable]: + def get_nodes_with_flag(self, flag: NodeFlag) -> set[Hashable]: """Get all nodes with specified NodeFlag set to True. Args: flag (traccuracy.NodeFlag): the node flag to query for Returns: - (List(hashable)): A list of node_ids which have the given flag + (List(hashable)): An iterable of node_ids which have the given flag and the value is True. """ if not isinstance(flag, NodeFlag): raise ValueError(f"Function takes NodeFlag arguments, not {type(flag)}.") - return list(self.nodes_by_flag[flag]) + return self.nodes_by_flag[flag] - def get_edges_with_flag(self, flag: EdgeFlag) -> list[tuple[Hashable, Hashable]]: + def get_edges_with_flag(self, flag: EdgeFlag) -> set[tuple[Hashable, Hashable]]: """Get all edges with specified EdgeFlag set to True. Args: flag (traccuracy.EdgeFlag): the edge flag to query for Returns: - (List(hashable)): A list of edge ids which have the given flag + (List(hashable)): An iterable of edge ids which have the given flag and the value is True. """ if not isinstance(flag, EdgeFlag): raise ValueError(f"Function takes EdgeFlag arguments, not {type(flag)}.") - return list(self.edges_by_flag[flag]) + return self.edges_by_flag[flag] def get_divisions(self) -> list[Hashable]: """Get all nodes that have at least two edges pointing to the next time frame @@ -343,7 +343,7 @@ def get_connected_components(self) -> list[TrackingGraph]: return [self.get_subgraph(g) for g in nx.weakly_connected_components(graph)] - def get_subgraph(self, nodes: list[Hashable]) -> TrackingGraph: + def get_subgraph(self, nodes: Iterable[Hashable]) -> TrackingGraph: """Returns a new TrackingGraph with the subgraph defined by the list of nodes Args: diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 09100561..028919ec 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -135,21 +135,25 @@ def test_constructor(nx_comp1): def test_get_cells_by_frame(simple_graph): - assert simple_graph.get_nodes_in_frame(0) == ["1_0"] + assert Counter(simple_graph.get_nodes_in_frame(0)) == Counter({"1_0"}) assert Counter(simple_graph.get_nodes_in_frame(2)) == Counter(["1_2", "1_3"]) - assert simple_graph.get_nodes_in_frame(5) == [] + assert Counter(simple_graph.get_nodes_in_frame(5)) == Counter([]) def test_get_nodes_with_flag(simple_graph): - assert simple_graph.get_nodes_with_flag(NodeFlag.TP_DIV) == ["1_1"] - assert simple_graph.get_nodes_with_flag(NodeFlag.FP_DIV) == [] + assert Counter(simple_graph.get_nodes_with_flag(NodeFlag.TP_DIV)) == Counter( + ["1_1"] + ) + assert Counter(simple_graph.get_nodes_with_flag(NodeFlag.FP_DIV)) == Counter([]) with pytest.raises(ValueError): assert simple_graph.get_nodes_with_flag("is_tp_division") def test_get_edges_with_flag(simple_graph): - assert simple_graph.get_edges_with_flag(EdgeFlag.TRUE_POS) == [("1_0", "1_1")] - assert simple_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG) == [] + assert Counter(simple_graph.get_edges_with_flag(EdgeFlag.TRUE_POS)) == Counter( + [("1_0", "1_1")] + ) + assert Counter(simple_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)) == Counter([]) with pytest.raises(ValueError): assert simple_graph.get_nodes_with_flag("is_tp") From f8e49735b76be0975a74478ed6b0f4139fa4004f Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 8 Jan 2024 13:38:10 -0500 Subject: [PATCH 22/25] Separate out and test helper function in iou matcher --- src/traccuracy/matchers/_iou.py | 54 ++++++++++++++++++++------------- tests/matchers/test_iou.py | 25 ++++++++++++++- 2 files changed, 57 insertions(+), 22 deletions(-) diff --git a/src/traccuracy/matchers/_iou.py b/src/traccuracy/matchers/_iou.py index f2153055..5bcd7c39 100644 --- a/src/traccuracy/matchers/_iou.py +++ b/src/traccuracy/matchers/_iou.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Hashable + import numpy as np from tqdm import tqdm @@ -46,6 +48,35 @@ def _match_nodes(gt, res, threshold=1): return gtcells, rescells +def _construct_time_to_seg_id_map( + graph: TrackingGraph, +) -> dict[int, dict[Hashable, Hashable]]: + """For each time frame in the graph, create a mapping from segmentation ids + (the ids in the segmentation array, stored in graph.label_key) to the + node ids (the ids of the TrackingGraph nodes). + + Args: + graph(TrackingGraph): a tracking graph with a label_key on each node + + Returns: + dict[int, dict[Hashable, Hashable]]: a dictionary from {time: {segmentation_id: node_id}} + + Raises: + AssertionError: If two nodes in a time frame have the same segmentation_id + """ + time_to_seg_id_map: dict[int, dict[Hashable, Hashable]] = {} + for node_id, data in graph.nodes(data=True): + time = data[graph.frame_key] + seg_id = data[graph.label_key] + seg_id_to_node_id_map = time_to_seg_id_map.get(time, {}) + assert ( + seg_id not in seg_id_to_node_id_map + ), f"Segmentation ID {seg_id} occurred twice in frame {time}." + seg_id_to_node_id_map[seg_id] = node_id + time_to_seg_id_map[time] = seg_id_to_node_id_map + return time_to_seg_id_map + + def match_iou(gt, pred, threshold=0.6): """Identifies pairs of cells between gt and pred that have iou > threshold @@ -79,27 +110,8 @@ def match_iou(gt, pred, threshold=0.6): frame_range = range(gt.start_frame, gt.end_frame) total = len(list(frame_range)) - def construct_time_to_seg_id_map(graph): - """ - Args: - graph(TrackingGraph) - - Returns a dictionary {time: {segmentation_id: node_id}} - """ - time_to_seg_id_map = {} - for node_id, data in graph.nodes(data=True): - time = data[graph.frame_key] - seg_id = data[graph.label_key] - seg_id_to_node_id_map = time_to_seg_id_map.get(time, {}) - assert ( - seg_id not in seg_id_to_node_id_map - ), f"Segmentation ID {seg_id} occurred twice in frame {time}." - seg_id_to_node_id_map[seg_id] = node_id - time_to_seg_id_map[time] = seg_id_to_node_id_map - return time_to_seg_id_map - - gt_time_to_seg_id_map = construct_time_to_seg_id_map(gt) - pred_time_to_seg_id_map = construct_time_to_seg_id_map(pred) + gt_time_to_seg_id_map = _construct_time_to_seg_id_map(gt) + pred_time_to_seg_id_map = _construct_time_to_seg_id_map(pred) for i, t in tqdm(enumerate(frame_range), desc="Matching frames", total=total): matches = _match_nodes( diff --git a/tests/matchers/test_iou.py b/tests/matchers/test_iou.py index edf942b7..d1133a0d 100644 --- a/tests/matchers/test_iou.py +++ b/tests/matchers/test_iou.py @@ -2,7 +2,12 @@ import numpy as np import pytest from traccuracy._tracking_graph import TrackingGraph -from traccuracy.matchers._iou import IOUMatcher, _match_nodes, match_iou +from traccuracy.matchers._iou import ( + IOUMatcher, + _construct_time_to_seg_id_map, + _match_nodes, + match_iou, +) from tests.test_utils import get_annotated_image, get_movie_with_graph @@ -21,6 +26,24 @@ def test__match_nodes(): gtcells, rescells = _match_nodes(y1, y2) +def test__construct_time_to_seg_id_map(): + # Test 2d data + n_frames = 3 + n_labels = 3 + track_graph = get_movie_with_graph(ndims=3, n_frames=n_frames, n_labels=n_labels) + time_to_seg_id_map = _construct_time_to_seg_id_map(track_graph) + for t in range(n_frames): + for i in range(1, n_labels): + assert time_to_seg_id_map[t][i] == f"{i}_{t}" + + # Test 3d data + track_graph = get_movie_with_graph(ndims=4, n_frames=n_frames, n_labels=n_labels) + time_to_seg_id_map = _construct_time_to_seg_id_map(track_graph) + for t in range(n_frames): + for i in range(1, n_labels): + assert time_to_seg_id_map[t][i] == f"{i}_{t}" + + def test_match_iou(): # Bad input with pytest.raises(ValueError): From 87dda146a6d8beb5340a8d510ff49bd6c534c260 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 8 Jan 2024 13:51:42 -0500 Subject: [PATCH 23/25] Test nodes/edges_by_flag dict when updating flags --- tests/test_tracking_graph.py | 46 +++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 028919ec..1a863d1f 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -208,6 +208,16 @@ def test_get_and_set_flag_on_node(simple_graph): "is_tp_division": True, } + simple_graph.set_flag_on_node("1_0", NodeFlag.FALSE_POS, value=True) + assert simple_graph.nodes()["1_0"] == { + "id": "1_0", + "t": 0, + "y": 1, + "x": 1, + NodeFlag.FALSE_POS: True, + } + assert "1_0" in simple_graph.nodes_by_flag[NodeFlag.FALSE_POS] + simple_graph.set_flag_on_node("1_0", NodeFlag.FALSE_POS, value=False) assert simple_graph.nodes()["1_0"] == { "id": "1_0", @@ -216,33 +226,47 @@ def test_get_and_set_flag_on_node(simple_graph): "x": 1, NodeFlag.FALSE_POS: False, } - - simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=False) - for node in simple_graph.nodes: - assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is False + assert "1_0" not in simple_graph.nodes_by_flag[NodeFlag.FALSE_POS] simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=True) for node in simple_graph.nodes: assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is True + assert Counter(simple_graph.nodes_by_flag[NodeFlag.FALSE_POS]) == Counter( + list(simple_graph.nodes()) + ) + + simple_graph.set_flag_on_all_nodes(NodeFlag.FALSE_POS, value=False) + for node in simple_graph.nodes: + assert simple_graph.nodes[node][NodeFlag.FALSE_POS] is False + assert not simple_graph.nodes_by_flag[NodeFlag.FALSE_POS] with pytest.raises(ValueError): simple_graph.set_flag_on_node("1_0", "x", 2) def test_get_and_set_flag_on_edge(simple_graph): - print(simple_graph.edges()) - assert EdgeFlag.TRUE_POS not in simple_graph.edges()[("1_1", "1_3")] + edge_id = ("1_1", "1_3") + assert EdgeFlag.TRUE_POS not in simple_graph.edges()[edge_id] - simple_graph.set_flag_on_edge(("1_1", "1_3"), EdgeFlag.TRUE_POS, value=False) - assert simple_graph.edges()[("1_1", "1_3")][EdgeFlag.TRUE_POS] is False + simple_graph.set_flag_on_edge(edge_id, EdgeFlag.TRUE_POS, value=True) + assert simple_graph.edges()[edge_id][EdgeFlag.TRUE_POS] is True + assert edge_id in simple_graph.edges_by_flag[EdgeFlag.TRUE_POS] - simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=False) - for edge in simple_graph.edges: - assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is False + simple_graph.set_flag_on_edge(edge_id, EdgeFlag.TRUE_POS, value=False) + assert simple_graph.edges()[edge_id][EdgeFlag.TRUE_POS] is False + assert edge_id not in simple_graph.edges_by_flag[EdgeFlag.TRUE_POS] simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=True) for edge in simple_graph.edges: assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is True + assert Counter(simple_graph.edges_by_flag[EdgeFlag.FALSE_POS]) == Counter( + list(simple_graph.edges) + ) + + simple_graph.set_flag_on_all_edges(EdgeFlag.FALSE_POS, value=False) + for edge in simple_graph.edges: + assert simple_graph.edges[edge][EdgeFlag.FALSE_POS] is False + assert not simple_graph.edges_by_flag[EdgeFlag.FALSE_POS] with pytest.raises(ValueError): simple_graph.set_flag_on_edge(("1_1", "1_3"), "x", 2) From 528323bff7f68879628285992863fdf241209ef0 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Mon, 8 Jan 2024 13:58:55 -0500 Subject: [PATCH 24/25] Remove get from test names for flag setting --- tests/test_tracking_graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 1a863d1f..05964054 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -198,7 +198,7 @@ def test_get_connected_components(complex_graph, nx_comp1, nx_comp2): assert track2.graph.edges == nx_comp2.edges -def test_get_and_set_flag_on_node(simple_graph): +def test_set_flag_on_node(simple_graph): assert simple_graph.nodes()["1_0"] == {"id": "1_0", "t": 0, "y": 1, "x": 1} assert simple_graph.nodes()["1_1"] == { "id": "1_1", @@ -244,7 +244,7 @@ def test_get_and_set_flag_on_node(simple_graph): simple_graph.set_flag_on_node("1_0", "x", 2) -def test_get_and_set_flag_on_edge(simple_graph): +def test_set_flag_on_edge(simple_graph): edge_id = ("1_1", "1_3") assert EdgeFlag.TRUE_POS not in simple_graph.edges()[edge_id] From f036f2835b764e0da49ec2f358c260644373b347 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Wed, 10 Jan 2024 14:14:40 -0500 Subject: [PATCH 25/25] Remove get_preds and get_succs --- src/traccuracy/_tracking_graph.py | 34 +----------------------- src/traccuracy/track_errors/_ctc.py | 4 +-- src/traccuracy/track_errors/divisions.py | 16 +++++------ tests/test_tracking_graph.py | 17 ------------ 4 files changed, 11 insertions(+), 60 deletions(-) diff --git a/src/traccuracy/_tracking_graph.py b/src/traccuracy/_tracking_graph.py index a8d78019..f6dd21fc 100644 --- a/src/traccuracy/_tracking_graph.py +++ b/src/traccuracy/_tracking_graph.py @@ -302,38 +302,6 @@ def get_merges(self) -> list[Hashable]: """ return [node for node, degree in self.graph.in_degree() if degree >= 2] - def get_preds(self, node: Hashable) -> list[Hashable]: - """Get all predecessors of the given node. - - A predecessor node is any node from a previous time point that has an edge to - the given node. In a case where merges are not allowed, each node will have a - maximum of one predecessor. - - Args: - node (hashable): A node id - - Returns: - list of hashable: A list of node ids containing all nodes that - have an edge to the given node. - """ - return [pred for pred, _ in self.graph.in_edges(node)] - - def get_succs(self, node: Hashable) -> list[Hashable]: - """Get all successor nodes of the given node. - - A successor node is any node from a later time point that has an edge - from the given node. In a case where divisions are not allowed, - a node will have a maximum of one successor. - - Args: - node (hashable): A node id - - Returns: - list of hashable: A list of node ids containing all nodes that have - an edge from the given node. - """ - return [succ for _, succ in self.graph.out_edges(node)] - def get_connected_components(self) -> list[TrackingGraph]: """Get a list of TrackingGraphs, each corresponding to one track (i.e., a connected component in the track graph). @@ -506,7 +474,7 @@ def get_tracklets( # Remove all intertrack edges from a copy of the original graph removed_edges = [] for parent in self.get_divisions(): - for daughter in self.get_succs(parent): + for daughter in self.graph.successors(parent): graph_copy.remove_edge(parent, daughter) removed_edges.append((parent, daughter)) diff --git a/src/traccuracy/track_errors/_ctc.py b/src/traccuracy/track_errors/_ctc.py index cf8a3c6c..01a7c7bb 100644 --- a/src/traccuracy/track_errors/_ctc.py +++ b/src/traccuracy/track_errors/_ctc.py @@ -104,13 +104,13 @@ def get_edge_errors(matched_data: Matched): graph.set_flag_on_all_edges(EdgeFlag.INTERTRACK_EDGE, False) for parent in graph.get_divisions(): - for daughter in graph.get_succs(parent): + for daughter in graph.graph.successors(parent): graph.set_flag_on_edge( (parent, daughter), EdgeFlag.INTERTRACK_EDGE, True ) for merge in graph.get_merges(): - for parent in graph.get_preds(merge): + for parent in graph.graph.predecessors(merge): graph.set_flag_on_edge((parent, merge), EdgeFlag.INTERTRACK_EDGE, True) # fp edges - edges in induced_graph that aren't in gt_graph diff --git a/src/traccuracy/track_errors/divisions.py b/src/traccuracy/track_errors/divisions.py index cb2b5e50..a0c41b0e 100644 --- a/src/traccuracy/track_errors/divisions.py +++ b/src/traccuracy/track_errors/divisions.py @@ -65,10 +65,10 @@ def _find_pred_node_matches(pred_node): g_gt.set_flag_on_node(gt_node, NodeFlag.FN_DIV, True) # Check if the division has the correct daughters else: - succ_gt = g_gt.get_succs(gt_node) + succ_gt = g_gt.graph.successors(gt_node) # Map pred succ nodes onto gt, unmapped nodes will return as None succ_pred = [ - _find_pred_node_matches(n) for n in g_pred.get_succs(pred_node) + _find_pred_node_matches(n) for n in g_pred.graph.successors(pred_node) ] # If daughters are same, division is correct @@ -107,7 +107,7 @@ def _get_pred_by_t(g, node, delta_frames): hashable: Node key of predecessor in target frame """ for _ in range(delta_frames): - nodes = g.get_preds(node) + nodes = list(g.graph.predecessors(node)) # Exit if there are no predecessors if len(nodes) == 0: return None @@ -133,7 +133,7 @@ def _get_succ_by_t(g, node, delta_frames): hashable: Node id of successor """ for _ in range(delta_frames): - nodes = g.get_succs(node) + nodes = list(g.graph.successors(node)) # Exit if there are no successors another division if len(nodes) == 0 or len(nodes) >= 2: return None @@ -196,9 +196,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): # Check if daughters match fp_succ = [ _get_succ_by_t(g_pred, node, t_fn - t_fp) - for node in g_pred.get_succs(fp_node) + for node in g_pred.graph.successors(fp_node) ] - fn_succ = g_gt.get_succs(fn_node) + fn_succ = g_gt.graph.successors(fn_node) if Counter(fp_succ) != Counter(fn_succ): # Daughters don't match so division cannot match continue @@ -217,9 +217,9 @@ def _correct_shifted_divisions(matched_data: Matched, n_frames=1): # Check if daughters match fn_succ = [ _get_succ_by_t(g_gt, node, t_fp - t_fn) - for node in g_gt.get_succs(fn_node) + for node in g_gt.graph.successors(fn_node) ] - fp_succ = g_pred.get_succs(fp_node) + fp_succ = g_pred.graph.successors(fp_node) if Counter(fp_succ) != Counter(fn_succ): # Daughters don't match so division cannot match continue diff --git a/tests/test_tracking_graph.py b/tests/test_tracking_graph.py index 05964054..045584a4 100644 --- a/tests/test_tracking_graph.py +++ b/tests/test_tracking_graph.py @@ -166,23 +166,6 @@ def test_get_merges(merge_graph): assert merge_graph.get_merges() == ["3_2"] -def test_get_preds(simple_graph, merge_graph): - # Division graph - assert simple_graph.get_preds("1_0") == [] - assert simple_graph.get_preds("1_1") == ["1_0"] - assert simple_graph.get_preds("1_2") == ["1_1"] - - # Merge graph - assert merge_graph.get_preds("3_3") == ["3_2"] - assert merge_graph.get_preds("3_2") == ["3_1", "3_5"] - - -def test_get_succs(simple_graph): - assert simple_graph.get_succs("1_0") == ["1_1"] - assert Counter(simple_graph.get_succs("1_1")) == Counter(["1_2", "1_3"]) - assert simple_graph.get_succs("1_2") == [] - - def test_get_connected_components(complex_graph, nx_comp1, nx_comp2): tracks = complex_graph.get_connected_components() assert len(tracks) == 2