Skip to content

Commit

Permalink
Update docstrings for tracks controller
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Jan 15, 2025
1 parent 1a5431c commit 3e38638
Showing 1 changed file with 57 additions and 54 deletions.
111 changes: 57 additions & 54 deletions src/motile_tracker/data_model/tracks_controller.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from collections.abc import Iterable
from typing import Any

import numpy as np
from motile_toolbox.candidate_graph import NodeAttr
from napari.utils.notifications import show_warning
from napari.utils.notifications import show_info, show_warning
from qtpy.QtWidgets import QMessageBox

from .action_history import ActionHistory
Expand Down Expand Up @@ -40,8 +39,9 @@ def add_nodes(
"""Calls the _add_nodes function to add nodes. Calls the refresh signal when finished.
Args:
nodes (np.ndarray[int]):an array of node ids
attributes (dict[str, np.ndarray]): dictionary containing at least time and position attributes
attributes (Attrs): dictionary containing at least time and position attributes
pixels (list[SegMask] | None, optional): The pixels associated with each node,
if a segmentation is present. Defaults to None.
"""
action, nodes = self._add_nodes(attributes, pixels)
self.action_history.add_new_action(action)
Expand All @@ -64,14 +64,12 @@ def _get_pred_and_succ(
track id, and the first node after time with the given track id,
or Nones if there are no such nodes.
"""
if track_id not in list(self.tracks.node_id_to_track_id.values()):
if (
track_id not in self.tracks.track_id_to_node
or len(self.tracks.track_id_to_node[track_id]) == 0
):
return None, None

candidates = [
node
for node, tid in self.tracks.node_id_to_track_id.items()
if tid == track_id and self.tracks.get_time(node) != time
]
candidates = self.tracks.track_id_to_node[track_id]
candidates.sort(key=lambda n: self.tracks.get_time(n))

pred = None
Expand Down Expand Up @@ -109,7 +107,7 @@ def _add_nodes(
predecessor and successor with the same track_id, if any)
Args:
attributes (Attributes): dictionary containing at least time and track id,
attributes (Attrs): dictionary containing at least time and track id,
and either node_id (if pixels are provided) or position (if not)
pixels (list[SegMask] | None): A list of pixels associated with the node,
or None if there is no segmentation. These pixels will be updated
Expand Down Expand Up @@ -195,19 +193,19 @@ def _add_nodes(

return ActionGroup(self.tracks, actions), nodes

def delete_nodes(self, nodes: Iterable[Any]) -> None:
def delete_nodes(self, nodes: Iterable[Node]) -> None:
"""Calls the _delete_nodes function and then emits the refresh signal
Args:
nodes (np.ndarray): array of node_ids to be deleted
nodes (Iterable[Node]): array of node_ids to be deleted
"""

action = self._delete_nodes(nodes)
self.action_history.add_new_action(action)
self.tracks.refresh.emit()

def _delete_nodes(
self, nodes: np.ndarray[Any], pixels: list[SegMask] | None = None
self, nodes: Iterable[Node], pixels: list[SegMask] | None = None
) -> TracksAction:
"""Delete the nodes provided by the array from the graph but maintain successor
track_ids. Reconnect to the nearest predecessor and/or nearest successor
Expand Down Expand Up @@ -269,46 +267,29 @@ def _delete_nodes(

return ActionGroup(self.tracks, actions=actions)

def update_node_segs(
self, nodes: Iterable[Node], attributes: dict[str, np.ndarray]
) -> None:
"""Calls the _update_node_segs function to update the node attributtes in given array.
Then calls the refresh signal.
Args:
nodes (np.ndarray[int]):an array of node ids
attributes (dict[str, np.ndarray]): dictionary containing the attributes to be updated
"""
action = self._update_node_segs(nodes, attributes)
self.action_history.add_new_action(action)
self.tracks.refresh.emit()

def update_node_attrs(self, nodes: Iterable[Node], attributes: Attrs):
action = self._update_node_attrs(nodes, attributes)
self.action_history.add_new_action(action)
self.tracks.refresh.emit()

def _update_node_attrs(self, nodes: Iterable[Node], attributes: Attrs):
return UpdateNodeAttrs(self.tracks, nodes, attributes)

def _update_node_segs(
self,
nodes: np.ndarray[Any],
nodes: Iterable[Node],
pixels: list[SegMask],
added=False,
) -> TracksAction:
"""Update the segmentation and segmentation-managed attributes for
a set of nodes.
Args:
nodes (np.ndarray[int]):an array of node ids
attributes (dict[str, np.ndarray]): dictionary containing the attributes to be updated
nodes (Iterable[Node]): The nodes to update
pixels (list[SegMask]): The pixels for each node that were edited
added (bool, optional): If the pixels were added to the nodes (True)
or deleted (False). Defaults to False. Cannot mix adding and removing
pixels in one call.
Returns:
TracksAction: _description_
"""
return UpdateNodeSegs(self.tracks, nodes, pixels, added=added)

def add_edges(self, edges: np.ndarray[int]) -> None:
"""Add edges and attributes to the graph. Also update the track ids and
"""Add edges to the graph. Also update the track ids and
corresponding segmentations if applicable
Args:
Expand All @@ -332,19 +313,43 @@ def add_edges(self, edges: np.ndarray[int]) -> None:
self.action_history.add_new_action(action)
self.tracks.refresh.emit()

def update_node_attrs(self, nodes: Iterable[Node], attributes: Attrs):
"""Update the user provided node attributes (not the managed attributes).
Also adds the action to the history and emits the refresh signal.
Args:
nodes (Iterable[Node]): The nodes to update the attributes for
attributes (Attrs): A mapping from user-provided attributes to values for
each node.
"""
action = self._update_node_attrs(nodes, attributes)
self.action_history.add_new_action(action)
self.tracks.refresh.emit()

def _update_node_attrs(
self, nodes: Iterable[Node], attributes: Attrs
) -> TracksAction:
"""Update the user provided node attributes (not the managed attributes).
Args:
nodes (Iterable[Node]): The nodes to update the attributes for
attributes (Attrs): A mapping from user-provided attributes to values for
each node.
Returns: A TracksAction object that performed the update
"""
return UpdateNodeAttrs(self.tracks, nodes, attributes)

def _add_edges(self, edges: np.ndarray[int]) -> TracksAction:
"""Add edges and attributes to the graph. Also update the track ids and
corresponding segmentations of the target node tracks and potentially sibling
tracks.
"""Add edges and attributes to the graph. Also update the track ids of the target
node tracks and potentially sibling tracks.
Args:
edges (np.array[int]): An Nx2 array of N edges, each with source and target
node ids
attributes (dict[str, np.ndarray]): dictionary mapping attribute names to
an array of values, where the index in the array matches the edge index
Returns:
True if the edges were successfully added, False if any edge was invalid.
A TracksAction containing all edits performed in this call
"""
actions = []
for edge in edges:
Expand Down Expand Up @@ -463,7 +468,7 @@ def delete_edges(self, edges: np.ndarray):
"""Delete edges from the graph.
Args:
edges (np.ndarray): _description_
edges (np.ndarray): The Nx2 array of edges to be deleted
"""

for edge in edges:
Expand Down Expand Up @@ -497,7 +502,7 @@ def update_segmentations(
to_remove: list[Node], # (node_ids, pixels)
to_update_smaller: list[tuple], # (node_id, pixels)
to_update_bigger: list[tuple], # (node_id, pixels)
to_add: list[tuple], # (seg_id, track_id, pixels)
to_add: list[tuple], # (node_id, track_id, pixels)
current_timepoint: int,
) -> None:
"""Handle a change in the segmentation mask, checking for node addition, deletion, and attribute updates.
Expand Down Expand Up @@ -528,14 +533,12 @@ def update_segmentations(
if len(to_add) > 0:
nodes = [node for node, _, _ in to_add]
pixels = [pix for _, _, pix in to_add]
seg_ids = [val for val, _, _ in to_add]
track_ids = [
val if val is not None else self.tracks.get_next_track_id()
for _, val, _ in to_add
]
times = [pix[0][0] for pix in pixels]
attributes = {
NodeAttr.SEG_ID.value: seg_ids,
NodeAttr.TRACK_ID.value: track_ids,
self.tracks.time_attr: times,
"node_id": nodes,
Expand Down Expand Up @@ -563,14 +566,14 @@ def undo(self) -> None:
if self.action_history.undo():
self.tracks.refresh.emit()
else:
show_warning("No more actions to undo")
show_info("No more actions to undo")

def redo(self) -> None:
"""Obtain the action to redo from the history"""
if self.action_history.redo():
self.tracks.refresh.emit()
else:
show_warning("No more actions to redo")
show_info("No more actions to redo")

def _get_new_node_ids(self, n: int) -> list[Node]:
"""Get a list of new node ids for creating new nodes.
Expand Down

0 comments on commit 3e38638

Please sign in to comment.