Skip to content

Commit

Permalink
Change from Node/EdgeAttr to Node/EdgeFlag
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Nov 29, 2023
1 parent a5e8d62 commit 6fc9015
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 186 deletions.
4 changes: 2 additions & 2 deletions src/traccuracy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
__version__ = "uninstalled"

from ._run_metrics import run_metrics
from ._tracking_graph import EdgeAttr, NodeAttr, TrackingGraph
from ._tracking_graph import EdgeFlag, NodeFlag, TrackingGraph

__all__ = ["TrackingGraph", "run_metrics", "NodeAttr", "EdgeAttr"]
__all__ = ["TrackingGraph", "run_metrics", "NodeFlag", "EdgeFlag"]
129 changes: 61 additions & 68 deletions src/traccuracy/_tracking_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@


@enum.unique
class NodeAttr(str, enum.Enum):
"""An enum containing all valid attributes that can be used to
annotate the nodes of a TrackingGraph. If new metrics require new
annotations, they should be added here to ensure strings do not overlap and
are standardized. Note that the user specified frame and location
class NodeFlag(str, enum.Enum):
"""An enum containing standard flags that are used to annotate the nodes
of a TrackingGraph. Note that the user specified frame and location
attributes are also valid node attributes that will be stored on the graph
and should not overlap with these values. Additionally, if a graph already
has annotations using these strings before becoming a TrackGraph,
has annotations using these strings before becoming a TrackingGraph,
this will likely ruin metrics computation!
"""

Expand Down Expand Up @@ -55,12 +53,10 @@ def has_value(cls, value):


@enum.unique
class EdgeAttr(str, enum.Enum):
"""An enum containing all valid attributes that can be used to
annotate the edges of a TrackingGraph. If new metrics require new
annotations, they should be added here to ensure strings do not overlap and
are standardized. Additionally, if a graph already
has annotations using these strings before becoming a TrackGraph,
class EdgeFlag(str, enum.Enum):
"""An enum containing standard flags that are used to
annotate the edges of a TrackingGraph. If a graph already has
annotations using these strings before becoming a TrackingGraph,
this will likely ruin metrics computation!
"""

Expand Down Expand Up @@ -118,14 +114,14 @@ def __init__(
forward in time.
If the provided graph already has annotations that are strings
included in NodeAttrs or EdgeAttrs, this will likely ruin
included in NodeFlags or EdgeFlags, this will likely ruin
metric computation!
Args:
graph (networkx.DiGraph): A directed graph representing a tracking
solution where edges go forward in time. If the graph already
has annotations that are strings included in NodeAttrs or
EdgeAttrs, this will likely ruin metrics computation!
has annotations that are strings included in NodeFlags or
EdgeFlags, this will likely ruin metrics computation!
segmentation (numpy-like array, optional): A numpy-like array of segmentations.
The location of each node in tracking_graph is assumed to be inside the
area of the corresponding segmentation. Defaults to None.
Expand All @@ -142,20 +138,20 @@ def __init__(
Defaults to ('x', 'y').
"""
self.segmentation = segmentation
if NodeAttr.has_value(frame_key):
if NodeFlag.has_value(frame_key):
raise ValueError(
f"Specified frame key {frame_key} is reserved for graph "
"annotation. Please change the frame key."
)
self.frame_key = frame_key
if label_key is not None and NodeAttr.has_value(label_key):
if label_key is not None and NodeFlag.has_value(label_key):
raise ValueError(
f"Specified label key {label_key} is reserved for graph"
"annotation. Please change the label key."
)
self.label_key = label_key
for loc_key in location_keys:
if NodeAttr.has_value(loc_key):
if NodeFlag.has_value(loc_key):
raise ValueError(
f"Specified location key {loc_key} is reserved for graph"
"annotation. Please change the location key."
Expand All @@ -166,8 +162,8 @@ def __init__(

# construct dictionaries from attributes to nodes/edges for easy lookup
self.nodes_by_frame = {}
self.nodes_by_flag = {flag: set() for flag in NodeAttr}
self.edges_by_flag = {flag: set() for flag in EdgeAttr}
self.nodes_by_flag = {flag: set() for flag in NodeFlag}
self.edges_by_flag = {flag: set() for flag in EdgeFlag}
for node, attrs in self.graph.nodes.items():
# check that every node has the time frame and location specified
assert (
Expand All @@ -185,13 +181,13 @@ def __init__(
else:
self.nodes_by_frame[frame].add(node)
# store node id in nodes_by_flag mapping
for flag in NodeAttr:
for flag in NodeFlag:
if flag in attrs and attrs[flag]:
self.nodes_by_flag[flag].add(node)

# store edge id in edges_by_flag
for edge, attrs in self.graph.edges.items():
for flag in EdgeAttr:
for flag in EdgeFlag:
if flag in attrs and attrs[flag]:
self.edges_by_flag[flag].add(edge)

Expand Down Expand Up @@ -251,31 +247,31 @@ def get_location(self, node_id):
return [self.graph.nodes[node_id][key] for key in self.location_keys]

def get_nodes_with_flag(self, attr):
"""Get all nodes with specified NodeAttr set to True.
"""Get all nodes with specified NodeFlag set to True.
Args:
attr (traccuracy.NodeAttr): the node attribute to query for
attr (traccuracy.NodeFlag): the node attribute to query for
Returns:
(List(hashable)): A list of node_ids which have the given attribute
and the value is True.
"""
if not isinstance(attr, NodeAttr):
raise ValueError(f"Function takes NodeAttr arguments, not {type(attr)}.")
if not isinstance(attr, NodeFlag):
raise ValueError(f"Function takes NodeFlag arguments, not {type(attr)}.")
return list(self.nodes_by_flag[attr])

def get_edges_with_flag(self, attr):
"""Get all edges with specified EdgeAttr set to True.
"""Get all edges with specified EdgeFlag set to True.
Args:
attr (traccuracy.EdgeAttr): the edge attribute to query for
attr (traccuracy.EdgeFlag): the edge attribute to query for
Returns:
(List(hashable)): A list of edge ids which have the given attribute
and the value is True.
"""
if not isinstance(attr, EdgeAttr):
raise ValueError(f"Function takes EdgeAttr arguments, not {type(attr)}.")
if not isinstance(attr, EdgeFlag):
raise ValueError(f"Function takes EdgeFlag arguments, not {type(attr)}.")
return list(self.edges_by_flag[attr])

def get_divisions(self):
Expand Down Expand Up @@ -356,12 +352,12 @@ def get_subgraph(self, nodes):
else:
del new_trackgraph.nodes_by_frame[frame]

for attr in NodeAttr:
new_trackgraph.nodes_by_flag[attr] = self.nodes_by_flag[attr].intersection(
for flag in NodeFlag:
new_trackgraph.nodes_by_flag[flag] = self.nodes_by_flag[flag].intersection(
nodes
)
for attr in EdgeAttr:
new_trackgraph.edges_by_flag[attr] = self.edges_by_flag[attr].intersection(
for flag in EdgeFlag:
new_trackgraph.edges_by_flag[flag] = self.edges_by_flag[flag].intersection(
nodes
)

Expand All @@ -370,54 +366,52 @@ def get_subgraph(self, nodes):

return new_trackgraph

def set_flag_on_node(self, _id: Hashable, flag: NodeAttr, value: bool = True):
def set_flag_on_node(self, _id: Hashable, flag: NodeFlag, value: bool = True):
"""Set an attribute flag for a single node.
If the id is not found in the graph, a KeyError will be raised.
If the flag already exists, the existing value will be overwritten.
Args:
_id (Hashable): The node id on which to set the flag.
flag (traccuracy.NodeAttr): The node flag to set. Must be
of type NodeAttr - you may not not pass strings, even if they
are included in the NodeAttr enum values.
value (bool, optional): Attributes are flags and can only be set to
flag (traccuracy.NodeFlag): The node flag to set. Must be
of type NodeFlag - you may not not pass strings, even if they
are included in the NodeFlag enum values.
value (bool, optional): Flags can only be set to
True or False. Defaults to True.
Raises:
KeyError if the provided id is not in the graph.
ValueError if the provided flag is not a NodeAttr
ValueError if the provided flag is not a NodeFlag
"""
if not isinstance(flag, NodeAttr):
if not isinstance(flag, NodeFlag):
raise ValueError(
f"Provided flag {flag} is not of type NodeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
f"Provided flag {flag} is not of type NodeFlag. "
"Please use the enum instead of passing string values."
)
self.graph.nodes[_id][flag] = value
if value:
self.nodes_by_flag[flag].add(_id)
else:
self.nodes_by_flag[flag].discard(_id)

def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True):
def set_flag_on_all_nodes(self, flag: NodeFlag, value: bool = True):
"""Set an attribute flag for all nodes in the graph.
If the flag already exists, the existing values will be overwritten.
Args:
flag (traccuracy.NodeAttr): The node flag to set. Must be
of type NodeAttr - you may not not pass strings, even if they
are included in the NodeAttr enum values.
flag (traccuracy.NodeFlag): The node flag to set. Must be
of type NodeFlag - you may not not pass strings, even if they
are included in the NodeFlag enum values.
value (bool, optional): Flags can only be set to True or False.
Defaults to True.
Raises:
ValueError if the provided flag is not a NodeAttr.
ValueError if the provided flag is not a NodeFlag.
"""
if not isinstance(flag, NodeAttr):
if not isinstance(flag, NodeFlag):
raise ValueError(
f"Provided flag {flag} is not of type NodeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
f"Provided flag {flag} is not of type NodeFlag. "
"Please use the enum instead of passing string values."
)
nx.set_node_attributes(self.graph, value, name=flag)
if value:
Expand All @@ -426,52 +420,51 @@ def set_flag_on_all_nodes(self, flag: NodeAttr, value: bool = True):
self.nodes_by_flag[flag] = set()

def set_flag_on_edge(
self, _id: tuple[Hashable, Hashable], flag: EdgeAttr, value: bool = True
self, _id: tuple[Hashable, Hashable], flag: EdgeFlag, value: bool = True
):
"""Set an attribute flag for an edge.
If the flag already exists, the existing value will be overwritten.
Args:
ids (tuple[Hashable]): The edge id or list of edge ids
to set the attribute for. Edge ids are a 2-tuple of node ids.
flag (traccuracy.EdgeAttr): The edge flag to set. Must be
of type EdgeAttr - you may not pass strings, even if they are
included in the EdgeAttr enum values.
flag (traccuracy.EdgeFlag): The edge flag to set. Must be
of type EdgeFlag - you may not pass strings, even if they are
included in the EdgeFlag enum values.
value (bool): Flags can only be set to True or False.
Defaults to True.
Raises:
KeyError if edge with _id not in graph.
"""
if not isinstance(flag, EdgeAttr):
if not isinstance(flag, EdgeFlag):
raise ValueError(
f"Provided attribute {flag} is not of type EdgeAttr. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
f"Provided attribute {flag} is not of type EdgeFlag. "
"Please use the enum instead of passing string values."
)
self.graph.edges[_id][flag] = value
if value:
self.edges_by_flag[flag].add(_id)
else:
self.edges_by_flag[flag].discard(_id)

def set_flag_on_all_edges(self, flag: EdgeAttr, value: bool = True):
def set_flag_on_all_edges(self, flag: EdgeFlag, value: bool = True):
"""Set an attribute flag for all edges in the graph.
If the flag already exists, the existing values will be overwritten.
Args:
flag (traccuracy.EdgeAttr): The edge flag to set. Must be
of type EdgeAttr - you may not not pass strings, even if they
are included in the EdgeAttr enum values.
flag (traccuracy.EdgeFlag): The edge flag to set. Must be
of type EdgeFlag - you may not not pass strings, even if they
are included in the EdgeFlag enum values.
value (bool, optional): Flags can only be set to True or False.
Defaults to True.
Raises:
ValueError if the provided flag is not an EdgeAttr.
ValueError if the provided flag is not an EdgeFlag.
"""
if not isinstance(flag, EdgeAttr):
if not isinstance(flag, EdgeFlag):
raise ValueError(
f"Provided flag {flag} is not of type EdgeAttr. "
f"Provided flag {flag} is not of type EdgeFlag. "
"Please use the enum instead of passing string values, "
"and add new attributes to the class to avoid key collision."
)
Expand Down
14 changes: 7 additions & 7 deletions src/traccuracy/metrics/_ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import EdgeAttr, NodeAttr
from traccuracy._tracking_graph import EdgeFlag, NodeFlag
from traccuracy.track_errors._ctc import evaluate_ctc_events

from ._base import Metric
Expand Down Expand Up @@ -38,14 +38,14 @@ def compute(self, data: Matched):
evaluate_ctc_events(data)

vertex_error_counts = {
"ns": len(data.pred_graph.get_nodes_with_flag(NodeAttr.NON_SPLIT)),
"fp": len(data.pred_graph.get_nodes_with_flag(NodeAttr.FALSE_POS)),
"fn": len(data.gt_graph.get_nodes_with_flag(NodeAttr.FALSE_NEG)),
"ns": len(data.pred_graph.get_nodes_with_flag(NodeFlag.NON_SPLIT)),
"fp": len(data.pred_graph.get_nodes_with_flag(NodeFlag.FALSE_POS)),
"fn": len(data.gt_graph.get_nodes_with_flag(NodeFlag.FALSE_NEG)),
}
edge_error_counts = {
"ws": len(data.pred_graph.get_edges_with_flag(EdgeAttr.WRONG_SEMANTIC)),
"fp": len(data.pred_graph.get_edges_with_flag(EdgeAttr.FALSE_POS)),
"fn": len(data.gt_graph.get_edges_with_flag(EdgeAttr.FALSE_NEG)),
"ws": len(data.pred_graph.get_edges_with_flag(EdgeFlag.WRONG_SEMANTIC)),
"fp": len(data.pred_graph.get_edges_with_flag(EdgeFlag.FALSE_POS)),
"fn": len(data.gt_graph.get_edges_with_flag(EdgeFlag.FALSE_NEG)),
}
error_sum = get_weighted_error_sum(
vertex_error_counts,
Expand Down
8 changes: 4 additions & 4 deletions src/traccuracy/metrics/_divisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from typing import TYPE_CHECKING

from traccuracy._tracking_graph import NodeAttr
from traccuracy._tracking_graph import NodeFlag
from traccuracy.track_errors.divisions import _evaluate_division_events

from ._base import Metric
Expand Down Expand Up @@ -90,9 +90,9 @@ def compute(self, data: Matched):
}

def _calculate_metrics(self, g_gt, g_pred):
tp_division_count = len(g_gt.get_nodes_with_flag(NodeAttr.TP_DIV))
fn_division_count = len(g_gt.get_nodes_with_flag(NodeAttr.FN_DIV))
fp_division_count = len(g_pred.get_nodes_with_flag(NodeAttr.FP_DIV))
tp_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.TP_DIV))
fn_division_count = len(g_gt.get_nodes_with_flag(NodeFlag.FN_DIV))
fp_division_count = len(g_pred.get_nodes_with_flag(NodeFlag.FP_DIV))

try:
recall = tp_division_count / (tp_division_count + fn_division_count)
Expand Down
Loading

1 comment on commit 6fc9015

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Mean (s) BASE ef50a44 Mean (s) HEAD 6fc9015 Percent Change
test_load_gt_data 1.27863 1.38075 7.99
test_load_pred_data 1.16853 1.23938 6.06
test_ctc_matched 2.28306 2.42893 6.39
test_ctc_metrics 0.54236 0.506 -6.7
test_ctc_div_metrics 0.27516 0.28441 3.36
test_iou_matched 9.48363 9.10111 -4.03
test_iou_div_metrics 0.29817 0.27492 -7.8

Please sign in to comment.