Start implementation for automated tracking
constantinpape committed Aug 27, 2024
1 parent 36e191c commit 3fe600c
Showing 6 changed files with 298 additions and 26 deletions.
9 changes: 7 additions & 2 deletions examples/annotator_tracking.py
@@ -22,17 +22,22 @@ def track_ctc_data(use_finetuned_model):
if use_finetuned_model:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr")
model_type = "vit_b_lm"
precompute_amg_state = True
else:
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
model_type = "vit_h"
precompute_amg_state = False

# start the annotator with cached embeddings
annotator_tracking(timeseries, embedding_path=embedding_path, model_type=model_type)
annotator_tracking(
timeseries, embedding_path=embedding_path, model_type=model_type,
precompute_amg_state=precompute_amg_state,
)


def main():
# Whether to use the fine-tuned SAM model.
use_finetuned_model = False
use_finetuned_model = True
track_ctc_data(use_finetuned_model)


188 changes: 168 additions & 20 deletions micro_sam/multi_dimensional_segmentation.py
@@ -353,7 +353,34 @@ def merge_instance_segmentation_3d(
return segmentation


# TODO: Enable tiling
def _segment_slices(
data, predictor, segmentor, embedding_path, verbose, with_background=True, **kwargs
):
assert data.ndim == 3

image_embeddings = util.precompute_image_embeddings(predictor, data, save_path=embedding_path, ndim=3)

offset = 0
segmentation = np.zeros(data.shape, dtype="uint32")

min_object_size = kwargs.pop("min_object_size", 0)
for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
seg = segmentor.generate(**kwargs)
if len(seg) == 0:
continue
else:
seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
seg_max = seg.max()
if seg_max == 0:
continue
seg[seg != 0] += offset
offset = seg_max + offset
segmentation[i] = seg

return segmentation
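The id offset in the loop above keeps segment ids unique across slices before they are merged or tracked. A minimal, self-contained sketch of that step (plain numpy, toy data only):

import numpy as np

# Two independently segmented slices; ids restart at 1 in each slice.
slice_segs = [
    np.array([[0, 1], [2, 0]], dtype="uint32"),
    np.array([[1, 1], [0, 2]], dtype="uint32"),
]

stack = np.zeros((2, 2, 2), dtype="uint32")
offset = 0
for i, seg in enumerate(slice_segs):
    seg = seg.copy()
    seg_max = seg.max()
    if seg_max == 0:
        continue
    # Shift the foreground ids so they do not collide with earlier slices.
    seg[seg != 0] += offset
    offset += seg_max
    stack[i] = seg

print(stack[0].max(), stack[1].max())  # 2 4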


def automatic_3d_segmentation(
volume: np.ndarray,
predictor: SamPredictor,
@@ -386,28 +413,149 @@ def automatic_3d_segmentation(
Returns:
The segmentation.
"""
offset = 0
segmentation = np.zeros(volume.shape, dtype="uint32")

min_object_size = kwargs.pop("min_object_size", 0)
image_embeddings = util.precompute_image_embeddings(predictor, volume, save_path=embedding_path, ndim=3)

for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
seg = segmentor.generate(**kwargs)
if len(seg) == 0:
continue
else:
seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
max_z = seg.max()
if max_z == 0:
continue
seg[seg != 0] += offset
offset = max_z + offset
segmentation[i] = seg
segmentation = _segment_slices(
volume, predictor, segmentor, embedding_path, verbose, with_background=with_background, **kwargs
)

segmentation = merge_instance_segmentation_3d(
segmentation, beta=0.5, with_background=with_background, gap_closing=gap_closing, min_z_extent=min_z_extent
)

return segmentation


def _filter_tracks(tracking_result, min_track_length):
props = regionprops(tracking_result)
discard_ids = []
for prop in props:
label_id = prop.label
z_start, z_stop = prop.bbox[0], prop.bbox[3]
if z_stop - z_start < min_track_length:
discard_ids.append(label_id)
tracking_result[np.isin(tracking_result, discard_ids)] = 0
tracking_result, _, _ = relabel_sequential(tracking_result)
return tracking_result
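_filter_tracks measures the temporal extent of each track via its bounding box along the first (time) axis and removes tracks shorter than min_track_length. The same logic as a small self-contained sketch on toy data:

import numpy as np
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential

# Toy (t, y, x) tracking result: track 1 spans two frames, track 2 only one.
tracks = np.zeros((3, 2, 2), dtype="uint32")
tracks[0, 0, 0] = 1
tracks[1, 0, 0] = 1
tracks[2, 1, 1] = 2

min_track_length = 2
discard_ids = [
    prop.label for prop in regionprops(tracks)
    if prop.bbox[3] - prop.bbox[0] < min_track_length  # extent along the time axis
]
tracks[np.isin(tracks, discard_ids)] = 0
tracks, _, _ = relabel_sequential(tracks)
# Track 2 is removed, track 1 is kept.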


def _parse_result(slice_segmentation, solver, graph):
import networkx as nx
import motile
from nifty.tools import takeDict

lineage_graph = nx.DiGraph()

node_indicators = solver.get_variables(motile.variables.NodeSelected)
edge_indicators = solver.get_variables(motile.variables.EdgeSelected)

# build new graphs that contain the selected nodes and tracking / lineage results
for node, index in node_indicators.items():
if solver.solution[index] > 0.5:
lineage_graph.add_node(node, **graph.nodes[node])

for edge, index in edge_indicators.items():
if solver.solution[index] > 0.5:
lineage_graph.add_edge(*edge, **graph.edges[edge])

# Use connected components to find the lineages in the result graph.
components = nx.weakly_connected_components(lineage_graph)

# Compute the track assignments and the lineages
# (according to the representation expected by micro_sam)
track_assignment = {}
lineages = {}

# Initialize the track id with 1.
track_id = 1

# Iterate over all lineages in the graph.
for lineage_nodes in components:

# Extract the sub-graph for this lineage, make sure it's a tree and
# then find its root node.
node_list = sorted(list(lineage_nodes))
sub_graph = lineage_graph.subgraph(node_list)
assert nx.is_tree(sub_graph)
root = [node for node in sub_graph.nodes() if sub_graph.in_degree(node) == 0]
assert len(root) == 1
root = root[0]

# Perform depth first search over the graph to map all nodes
# to their track id and build the lineage information for this sub-graph.
for u, v in nx.dfs_edges(sub_graph, root):
# Assign u to the current track_id if it has not been assigned yet.
if u not in track_assignment:
track_assignment[u] = track_id

degree_u = sub_graph.out_degree(u)
assert degree_u in (1, 2) # The only allowed degrees for u.
if degree_u == 2: # Division -> increase track id and record lineage.
track_id += 1
mother_track = track_assignment[u]
if mother_track in lineages:
lineages[mother_track].append(track_id)
else:
lineages[mother_track] = [track_id]

degree_v = sub_graph.out_degree(v)
assert degree_v in (0, 1, 2) # The only allowed degrees for v.
if degree_v == 0: # The track stops here. Assign v to the track id and increase it.
track_assignment[v] = track_id
track_id += 1

# Recolor the segmentation according to the track assignment.

# Map non-selected nodes and background to zero.
seg_ids = np.unique(slice_segmentation)
not_selected = list(set(seg_ids) - set(track_assignment.keys()))
track_assignment.update({not_select: 0 for not_select in not_selected})
track_assignment[0] = 0

segmentation = takeDict(track_assignment, slice_segmentation)
return segmentation, lineages
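The takeDict call from nifty.tools recolors the per-slice segmentation by replacing each segment id with its assigned track id; unselected ids and the background are mapped to zero. A plain-numpy sketch of the same mapping step on hypothetical ids:

import numpy as np

slice_segmentation = np.array([[0, 5], [7, 7]], dtype="uint32")
track_assignment = {0: 0, 5: 1, 7: 2}

# Equivalent of takeDict(track_assignment, slice_segmentation).
recolored = np.vectorize(track_assignment.get)(slice_segmentation)
# array([[0, 1], [2, 2]])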


def track_across_frames(
slice_segmentation: np.ndarray,
gap_closing: Optional[int] = None,
min_time_extent: Optional[int] = None,
verbose: bool = True,
pbar_init: Optional[callable] = None,
pbar_update: Optional[callable] = None,
):
"""TODO
"""
from elf.tracking.tracking_utils import preprocess_closing
from elf.tracking.motile_tracking import _track_with_motile_impl

_, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)

if gap_closing is not None and gap_closing > 0:
slice_segmentation = preprocess_closing(slice_segmentation, gap_closing, pbar_update)

solver, graph = _track_with_motile_impl(slice_segmentation, max_children=2)

segmentation, lineage = _parse_result(slice_segmentation, solver, graph)

return segmentation, lineage
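Assuming a (t, y, x) per-frame instance segmentation whose ids are unique across all frames (as produced by _segment_slices above), the new function could be used roughly like this; per_frame_seg is a placeholder name:

from micro_sam.multi_dimensional_segmentation import track_across_frames

tracks, lineages = track_across_frames(per_frame_seg, gap_closing=3, verbose=True)
# "tracks" is the segmentation recolored by track id,
# "lineages" maps a mother track id to the ids of its daughter tracks.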


def automatic_tracking(
timeseries: np.ndarray,
predictor: SamPredictor,
segmentor: AMGBase,
embedding_path: Optional[Union[str, os.PathLike]] = None,
gap_closing: Optional[int] = None,
min_time_extent: Optional[int] = None,
verbose: bool = True,
**kwargs,
):
"""TODO
"""

segmentation = _segment_slices(timeseries, predictor, segmentor, embedding_path, verbose, **kwargs)
segmentation, lineage = track_across_frames(
segmentation, gap_closing=gap_closing, min_time_extent=min_time_extent,
verbose=verbose,
)

return segmentation, lineage
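A hedged end-to-end sketch of the new entry point. get_sam_model and get_amg are the existing micro_sam helpers for loading a predictor and an automatic mask generator; the data path and parameter choices are hypothetical:

import imageio.v3 as imageio
from micro_sam import instance_segmentation, util
from micro_sam.multi_dimensional_segmentation import automatic_tracking

timeseries = imageio.imread("data/ctc-hela.tif")  # (t, y, x) timeseries, hypothetical path
predictor = util.get_sam_model(model_type="vit_b")
segmentor = instance_segmentation.get_amg(predictor, is_tiled=False)

segmentation, lineage = automatic_tracking(
    timeseries, predictor, segmentor,
    embedding_path="embeddings-ctc.zarr", gap_closing=3,
)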
4 changes: 4 additions & 0 deletions micro_sam/sam_annotator/_tooltips.py
@@ -35,6 +35,10 @@
"pred_iou_thresh": "Enter the threshold for filtering objects based on the predicted IOU.",
"stability_score_thresh": "Enter the threshold for filtering objects based on the stability score.",
},
"autotrack": {
"run_button": "Run automatic tracking.",
"run_tracking": "Choose if to run tracking for the whole timeseries or if to segment only the current timeframe."
},
"prompt_menu": {
"labels": "Choose positive prompts to inlcude regions or negative ones to exclude regions. Toggle between the settings by pressing [t].",
},
89 changes: 87 additions & 2 deletions micro_sam/sam_annotator/_widgets.py
@@ -27,7 +27,9 @@
from . import util as vutil
from ._tooltips import get_tooltip
from .. import instance_segmentation, util
from ..multi_dimensional_segmentation import segment_mask_in_volume, merge_instance_segmentation_3d, PROJECTION_MODES
from ..multi_dimensional_segmentation import (
segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES
)


#
@@ -1413,7 +1415,7 @@ def _create_volumetric_switch(self):
return self._add_boolean_param(
"apply_to_volume", self.apply_to_volume, title="Apply to Volume",
tooltip=get_tooltip("autosegment", "apply_to_volume")
)
)

def _add_common_settings(self, settings):
# Create the UI element for min object size.
@@ -1646,3 +1648,86 @@ def __call__(self):
worker = self._run_segmentation_2d(kwargs)
_select_layer(self._viewer, "auto_segmentation")
return worker


class AutoTrackWidget(AutoSegmentWidget):
def _create_tracking_switch(self):
self.apply_to_volume = False
return self._add_boolean_param(
"apply_to_volume", self.apply_to_volume, title="Track Timeseries",
tooltip=get_tooltip("autotrack", "run_tracking")
)

def _create_widget(self):
# Add the switch for segmenting the slice vs. tracking the timeseries.
self.layout().addWidget(self._create_tracking_switch())

# Add the nested settings widget.
self.settings = self._create_settings()
self.layout().addWidget(self.settings)

# Add the run button.
self.run_button = QtWidgets.QPushButton("Automatic Tracking")
self.run_button.clicked.connect(self.__call__)
self.run_button.setToolTip(get_tooltip("autotrack", "run_button"))
self.layout().addWidget(self.run_button)

def _run_segmentation_3d(self, kwargs):
allow_segment_3d = self._allow_segment_3d()
if not allow_segment_3d:
val_results = {
"message_type": "error",
"message": "Tracking with AMG is only supported if you have a GPU."
}
return _generate_message(val_results["message_type"], val_results["message"])

pbar, pbar_signals = _create_pbar_for_threadworker()

@thread_worker
def seg_impl():
segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
offset = 0

def pbar_init(total, description):
pbar_signals.pbar_total.emit(total)
pbar_signals.pbar_description.emit(description)

pbar_init(segmentation.shape[0], "Run tracking")

# Further optimization: parallelize if state is precomputed for all slices
for i in range(segmentation.shape[0]):
seg = _instance_segmentation_impl(self.with_background, self.min_object_size, i=i, **kwargs)
seg_max = seg.max()
if seg_max == 0:
continue
seg[seg != 0] += offset
offset = seg_max + offset
segmentation[i] = seg
pbar_signals.pbar_update.emit(1)

pbar_signals.pbar_reset.emit()
segmentation, lineage = track_across_frames(
segmentation,
verbose=True, pbar_init=pbar_init,
pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
)
pbar_signals.pbar_stop.emit()
return (segmentation, lineage)

# TODO update the tracking result
def update_segmentation(result):
segmentation, lineage = result
is_empty = segmentation.max() == 0
if is_empty:
self._empty_segmentation_warning()

state = AnnotatorState()
state.lineage = lineage

self._viewer.layers["auto_segmentation"].data = segmentation
self._viewer.layers["auto_segmentation"].refresh()

worker = seg_impl()
worker.returned.connect(update_segmentation)
worker.start()
return worker