From 3fe600cfb16d5ea761456bb1373f06e1718bee48 Mon Sep 17 00:00:00 2001
From: Constantin Pape
Date: Tue, 27 Aug 2024 21:23:13 +0200
Subject: [PATCH] Start implementation for automated tracking

---
 examples/annotator_tracking.py                |   9 +-
 micro_sam/multi_dimensional_segmentation.py   | 188 ++++++++++++++++--
 micro_sam/sam_annotator/_tooltips.py          |   4 +
 micro_sam/sam_annotator/_widgets.py           |  89 ++++++++-
 micro_sam/sam_annotator/annotator_tracking.py |  16 +-
 test/test_multi_dimensional_segmentation.py   |  18 ++
 6 files changed, 298 insertions(+), 26 deletions(-)

diff --git a/examples/annotator_tracking.py b/examples/annotator_tracking.py
index 003da5ce..dc0c0530 100644
--- a/examples/annotator_tracking.py
+++ b/examples/annotator_tracking.py
@@ -22,17 +22,22 @@ def track_ctc_data(use_finetuned_model):
     if use_finetuned_model:
         embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr")
         model_type = "vit_b_lm"
+        precompute_amg_state = True
     else:
         embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
         model_type = "vit_h"
+        precompute_amg_state = False

     # start the annotator with cached embeddings
-    annotator_tracking(timeseries, embedding_path=embedding_path, model_type=model_type)
+    annotator_tracking(
+        timeseries, embedding_path=embedding_path, model_type=model_type,
+        precompute_amg_state=precompute_amg_state,
+    )


 def main():
     # Whether to use the fine-tuned SAM model.
-    use_finetuned_model = False
+    use_finetuned_model = True

     track_ctc_data(use_finetuned_model)
diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py
index 2c65d4f1..723a536d 100644
--- a/micro_sam/multi_dimensional_segmentation.py
+++ b/micro_sam/multi_dimensional_segmentation.py
@@ -353,7 +353,34 @@ def merge_instance_segmentation_3d(
     return segmentation


-# TODO: Enable tiling
+def _segment_slices(
+    data, predictor, segmentor, embedding_path, verbose, with_background=True, **kwargs
+):
+    assert data.ndim == 3
+
+    image_embeddings = util.precompute_image_embeddings(predictor, data, save_path=embedding_path, ndim=3)
+
+    offset = 0
+    segmentation = np.zeros(data.shape, dtype="uint32")
+
+    min_object_size = kwargs.pop("min_object_size", 0)
+    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
+        segmentor.initialize(data[i], image_embeddings=image_embeddings, verbose=False, i=i)
+        seg = segmentor.generate(**kwargs)
+        if len(seg) == 0:
+            continue
+        else:
+            seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
+            max_z = seg.max()
+            if max_z == 0:
+                continue
+            seg[seg != 0] += offset
+            offset = max_z + offset
+            segmentation[i] = seg
+
+    return segmentation
+
+
 def automatic_3d_segmentation(
     volume: np.ndarray,
     predictor: SamPredictor,
@@ -386,28 +413,149 @@ def automatic_3d_segmentation(
     Returns:
         The segmentation.
     """
-    offset = 0
-    segmentation = np.zeros(volume.shape, dtype="uint32")
-
-    min_object_size = kwargs.pop("min_object_size", 0)
-    image_embeddings = util.precompute_image_embeddings(predictor, volume, save_path=embedding_path, ndim=3)
-
-    for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
-        segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
-        seg = segmentor.generate(**kwargs)
-        if len(seg) == 0:
-            continue
-        else:
-            seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size)
-            max_z = seg.max()
-            if max_z == 0:
-                continue
-            seg[seg != 0] += offset
-            offset = max_z + offset
-            segmentation[i] = seg
+    segmentation = _segment_slices(
+        volume, predictor, segmentor, embedding_path, verbose, with_background=with_background, **kwargs
+    )

     segmentation = merge_instance_segmentation_3d(
         segmentation, beta=0.5, with_background=with_background, gap_closing=gap_closing, min_z_extent=min_z_extent
     )

     return segmentation
+
+
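+# Note: the helpers below implement automatic tracking in two stages. The
+# timeseries is first segmented per frame (via _segment_slices) and the
+# per-frame objects are then linked over time by solving a tracking problem
+# with motile. The solver result is parsed into a segmentation that is
+# relabeled by track id, plus a lineage dict mapping mother to daughter tracks.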
""" - offset = 0 - segmentation = np.zeros(volume.shape, dtype="uint32") - - min_object_size = kwargs.pop("min_object_size", 0) - image_embeddings = util.precompute_image_embeddings(predictor, volume, save_path=embedding_path, ndim=3) - - for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose): - segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i) - seg = segmentor.generate(**kwargs) - if len(seg) == 0: - continue - else: - seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size) - max_z = seg.max() - if max_z == 0: - continue - seg[seg != 0] += offset - offset = max_z + offset - segmentation[i] = seg + segmentation = _segment_slices( + volume, predictor, segmentor, embedding_path, verbose, with_background=with_background, **kwargs + ) segmentation = merge_instance_segmentation_3d( segmentation, beta=0.5, with_background=with_background, gap_closing=gap_closing, min_z_extent=min_z_extent ) return segmentation + + +def _filter_tracks(tracking_result, min_track_length): + props = regionprops(tracking_result) + discard_ids = [] + for prop in props: + label_id = prop.label + z_start, z_stop = prop.bbox[0], prop.bbox[3] + if z_stop - z_start < min_track_length: + discard_ids.append(label_id) + tracking_result[np.isin(tracking_result, discard_ids)] = 0 + tracking_result, _, _ = relabel_sequential(tracking_result) + return tracking_result + + +def _parse_result(slice_segmentation, solver, graph): + import networkx as nx + import motile + from nifty.tools import takeDict + + lineage_graph = nx.DiGraph() + + node_indicators = solver.get_variables(motile.variables.NodeSelected) + edge_indicators = solver.get_variables(motile.variables.EdgeSelected) + + # build new graphs that contain the selected nodes and tracking / lineage results + for node, index in node_indicators.items(): + if solver.solution[index] > 0.5: + lineage_graph.add_node(node, **graph.nodes[node]) + + for edge, index in edge_indicators.items(): + if solver.solution[index] > 0.5: + lineage_graph.add_edge(*edge, **graph.edges[edge]) + + # Use connected components to find the lineages in the result graph. + components = nx.weakly_connected_components(lineage_graph) + + # Compute the track assignments and the lineages + # (according to the representation expected by micro_sam) + track_assignment = {} + lineages = {} + + # Initialize the track id with 1. + track_id = 1 + + # Iterate over all lineages in the graph. + for lineage_nodes in components: + + # Extract the sub-graph for thhis lineage, make sure it's a tree and + # then find its root node. + node_list = sorted(list(lineage_nodes)) + sub_graph = lineage_graph.subgraph(node_list) + assert nx.is_tree(sub_graph) + root = [node for node in sub_graph.nodes() if sub_graph.in_degree(node) == 0] + assert len(root) == 1 + root = root[0] + + # Perform depth first search over the graph to map all nodes + # to their track id and build the lineage information for this sub-graph. + for u, v in nx.dfs_edges(sub_graph, root): + # Assign u to the current track_id if it has not been assigned yet. + if u not in track_assignment: + track_assignment[u] = track_id + + degree_u = sub_graph.out_degree(u) + assert degree_u in (1, 2) # The only allowed degrees for u. + if degree_u == 2: # Division -> increase track id and record lineage. 
+                track_id += 1
+                mother_track = track_assignment[u]
+                if mother_track in lineages:
+                    lineages[mother_track].append(track_id)
+                else:
+                    lineages[mother_track] = [track_id]
+
+            degree_v = sub_graph.out_degree(v)
+            assert degree_v in (0, 1, 2)  # The only allowed degrees for v.
+            if degree_v == 0:  # The track stops here. Assign v to the track id and increase it.
+                track_assignment[v] = track_id
+                track_id += 1
+
+    # Recolor the segmentation according to the track assignment.
+
+    # Map non-selected nodes and background to zero.
+    seg_ids = np.unique(slice_segmentation)
+    not_selected = list(set(seg_ids) - set(track_assignment.keys()))
+    track_assignment.update({not_select: 0 for not_select in not_selected})
+    track_assignment[0] = 0
+
+    segmentation = takeDict(track_assignment, slice_segmentation)
+    return segmentation, lineages
+
+
+def track_across_frames(
+    slice_segmentation: np.ndarray,
+    gap_closing: Optional[int] = None,
+    min_time_extent: Optional[int] = None,
+    verbose: bool = True,
+    pbar_init: Optional[callable] = None,
+    pbar_update: Optional[callable] = None,
+):
+    """Track segmented objects across the frames of a timeseries.
+
+    Args:
+        slice_segmentation: The per-frame instance segmentation, with time as the first axis.
+        gap_closing: If given, gaps in the segmentation are closed with a morphological
+            operation across this many frames.
+        min_time_extent: Require a minimal extent in time for the tracks.
+        verbose: Whether to print progress.
+        pbar_init: Callback to initialize an external progress bar.
+        pbar_update: Callback to update an external progress bar.
+
+    Returns:
+        The tracking result as a segmentation recolored by track id, and the lineages.
+    """
+    from elf.tracking.tracking_utils import preprocess_closing
+    from elf.tracking.motile_tracking import _track_with_motile_impl
+
+    _, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init=pbar_init, pbar_update=pbar_update)
+
+    if gap_closing is not None and gap_closing > 0:
+        slice_segmentation = preprocess_closing(slice_segmentation, gap_closing, pbar_update)
+
+    solver, graph = _track_with_motile_impl(slice_segmentation, max_children=2)
+
+    segmentation, lineage = _parse_result(slice_segmentation, solver, graph)
+
+    if min_time_extent is not None and min_time_extent > 0:
+        # Note: this discards and relabels short tracks; updating the lineages
+        # accordingly is not yet implemented.
+        segmentation = _filter_tracks(segmentation, min_time_extent)
+
+    pbar_close()
+    return segmentation, lineage
+
+
+def automatic_tracking(
+    timeseries: np.ndarray,
+    predictor: SamPredictor,
+    segmentor: AMGBase,
+    embedding_path: Optional[Union[str, os.PathLike]] = None,
+    gap_closing: Optional[int] = None,
+    min_time_extent: Optional[int] = None,
+    verbose: bool = True,
+    **kwargs,
+):
+    """Automatically track objects in a timeseries, based on per-frame automatic instance segmentation.
+
+    Args:
+        timeseries: The input timeseries of images.
+        predictor: The SAM model.
+        segmentor: The instance segmentation class.
+        embedding_path: The path to save pre-computed embeddings.
+        gap_closing: If given, gaps in the tracking are closed with a morphological
+            operation across this many frames.
+        min_time_extent: Require a minimal extent in time for the tracks.
+        verbose: Whether to print progress.
+        kwargs: Keyword arguments for the generate method of the segmentor.
+
+    Returns:
+        The tracking result as a segmentation recolored by track id, and the lineages.
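+
+    Example:
+        A minimal sketch, assuming the micro_sam helpers `util.get_sam_model`
+        and `instance_segmentation.get_amg` for the model setup:
+
+            predictor = util.get_sam_model(model_type="vit_b_lm")
+            segmentor = instance_segmentation.get_amg(predictor, is_tiled=False)
+            segmentation, lineages = automatic_tracking(timeseries, predictor, segmentor)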
+    """
+    segmentation = _segment_slices(timeseries, predictor, segmentor, embedding_path, verbose, **kwargs)
+    segmentation, lineage = track_across_frames(
+        segmentation, gap_closing=gap_closing, min_time_extent=min_time_extent,
+        verbose=verbose,
+    )
+
+    return segmentation, lineage
diff --git a/micro_sam/sam_annotator/_tooltips.py b/micro_sam/sam_annotator/_tooltips.py
index 068dd44d..9e104491 100644
--- a/micro_sam/sam_annotator/_tooltips.py
+++ b/micro_sam/sam_annotator/_tooltips.py
@@ -35,6 +35,10 @@
         "pred_iou_thresh": "Enter the threshold for filtering objects based on the predicted IOU.",
         "stability_score_thresh": "Enter the threshold for filtering objects based on the stability score.",
     },
+    "autotrack": {
+        "run_button": "Run automatic tracking.",
+        "run_tracking": "Choose whether to run tracking for the whole timeseries or to segment only the current timeframe.",
+    },
     "prompt_menu": {
         "labels": "Choose positive prompts to inlcude regions or negative ones to exclude regions. Toggle between the settings by pressing [t].",
     },
diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py
index 50a190e0..37ff931b 100644
--- a/micro_sam/sam_annotator/_widgets.py
+++ b/micro_sam/sam_annotator/_widgets.py
@@ -27,7 +27,9 @@
 from . import util as vutil
 from ._tooltips import get_tooltip
 from .. import instance_segmentation, util
-from ..multi_dimensional_segmentation import segment_mask_in_volume, merge_instance_segmentation_3d, PROJECTION_MODES
+from ..multi_dimensional_segmentation import (
+    segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES
+)


 #
@@ -1413,7 +1415,7 @@ def _create_volumetric_switch(self):
         return self._add_boolean_param(
             "apply_to_volume", self.apply_to_volume, title="Apply to Volume",
             tooltip=get_tooltip("autosegment", "apply_to_volume")
-            )
+        )

     def _add_common_settings(self, settings):
         # Create the UI element for min object size.
@@ -1646,3 +1648,86 @@ def __call__(self):
         worker = self._run_segmentation_2d(kwargs)
         _select_layer(self._viewer, "auto_segmentation")
         return worker
+
+
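+# The AutoTrackWidget re-uses the automatic segmentation logic of AutoSegmentWidget:
+# the 2d code path segments the current frame, while the volumetric code path is
+# replaced by automatic tracking across the frames of the timeseries.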
+class AutoTrackWidget(AutoSegmentWidget):
+    def _create_tracking_switch(self):
+        self.apply_to_volume = False
+        return self._add_boolean_param(
+            "apply_to_volume", self.apply_to_volume, title="Track Timeseries",
+            tooltip=get_tooltip("autotrack", "run_tracking")
+        )
+
+    def _create_widget(self):
+        # Add the switch for segmenting the slice vs. tracking the timeseries.
+        self.layout().addWidget(self._create_tracking_switch())
+
+        # Add the nested settings widget.
+        self.settings = self._create_settings()
+        self.layout().addWidget(self.settings)
+
+        # Add the run button.
+        self.run_button = QtWidgets.QPushButton("Automatic Tracking")
+        self.run_button.clicked.connect(self.__call__)
+        self.run_button.setToolTip(get_tooltip("autotrack", "run_button"))
+        self.layout().addWidget(self.run_button)
+
+    def _run_segmentation_3d(self, kwargs):
+        allow_segment_3d = self._allow_segment_3d()
+        if not allow_segment_3d:
+            val_results = {
+                "message_type": "error",
+                "message": "Tracking with AMG is only supported if you have a GPU."
+            }
+            return _generate_message(val_results["message_type"], val_results["message"])
+
+        pbar, pbar_signals = _create_pbar_for_threadworker()
+
+        @thread_worker
+        def seg_impl():
+            segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data)
+            offset = 0
+
+            def pbar_init(total, description):
+                pbar_signals.pbar_total.emit(total)
+                pbar_signals.pbar_description.emit(description)
+
+            pbar_init(segmentation.shape[0], "Run tracking")
+
+            # Further optimization: parallelize if state is precomputed for all slices.
+            for i in range(segmentation.shape[0]):
+                seg = _instance_segmentation_impl(self.with_background, self.min_object_size, i=i, **kwargs)
+                seg_max = seg.max()
+                if seg_max == 0:
+                    continue
+                seg[seg != 0] += offset
+                offset = seg_max + offset
+                segmentation[i] = seg
+                pbar_signals.pbar_update.emit(1)
+
+            pbar_signals.pbar_reset.emit()
+            segmentation, lineage = track_across_frames(
+                segmentation,
+                verbose=True, pbar_init=pbar_init,
+                pbar_update=lambda update: pbar_signals.pbar_update.emit(1),
+            )
+            pbar_signals.pbar_stop.emit()
+            return (segmentation, lineage)
+
+        # TODO update the tracking result
+        def update_segmentation(result):
+            segmentation, lineage = result
+            is_empty = segmentation.max() == 0
+            if is_empty:
+                self._empty_segmentation_warning()
+
+            state = AnnotatorState()
+            state.lineage = lineage
+
+            self._viewer.layers["auto_segmentation"].data = segmentation
+            self._viewer.layers["auto_segmentation"].refresh()
+
+        worker = seg_impl()
+        worker.returned.connect(update_segmentation)
+        worker.start()
+        return worker
diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py
index 183678d5..0f7bfc72 100644
--- a/micro_sam/sam_annotator/annotator_tracking.py
+++ b/micro_sam/sam_annotator/annotator_tracking.py
@@ -152,10 +152,12 @@ def _get_widgets(self):
             states=self._track_state_labels, track_ids=list(state.lineage.keys()),
         )
         segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
+        autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
         return {
             "tracking": self._tracking_widget,
             "segment": widgets.segment_frame(),
             "segment_nd": segment_nd,
+            "autosegment": autotrack,
             "commit": widgets.commit_track(),
             "clear": widgets.clear_track(),
         }
@@ -163,6 +165,7 @@ def _get_widgets(self):
     def __init__(self, viewer: "napari.viewer.Viewer") -> None:
         # Initialize the state for tracking.
         self._init_track_state()
+        self._with_decoder = AnnotatorState().decoder is not None
         super().__init__(viewer=viewer, ndim=3)
         # Go to t=0.
         self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])
@@ -176,6 +179,11 @@ def _init_track_state(self):
     def _update_image(self):
         super()._update_image()
         self._init_track_state()
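+        # Load the state for automatic mask generation: the instance segmentation
+        # state if a decoder is available, otherwise the AMG state.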
+        state = AnnotatorState()
+        if self._with_decoder:
+            state.amg_state = vutil._load_is_state(state.embedding_path)
+        else:
+            state.amg_state = vutil._load_amg_state(state.embedding_path)


 def annotator_tracking(
@@ -187,6 +195,7 @@ def annotator_tracking(
     halo: Optional[Tuple[int, int]] = None,
     return_viewer: bool = False,
     viewer: Optional["napari.viewer.Viewer"] = None,
+    precompute_amg_state: bool = False,
     checkpoint_path: Optional[str] = None,
     device: Optional[Union[str, torch.device]] = None,
 ) -> Optional["napari.viewer.Viewer"]:
@@ -203,6 +212,9 @@ def annotator_tracking(
         return_viewer: Whether to return the napari viewer to further modify it before starting the tool.
         viewer: The viewer to which the SegmentAnything functionality should be added.
             This enables using a pre-initialized viewer.
+        precompute_amg_state: Whether to precompute the state for automatic mask generation.
+            This will take more time when precomputing embeddings, but will then make
+            automatic mask generation much faster.
         checkpoint_path: Path to a custom checkpoint from which to load the SAM model.
         device: The computational device to use for the SAM model.

@@ -210,13 +222,13 @@ def annotator_tracking(
     Returns:
         The napari viewer, only returned if `return_viewer=True`.
     """
-    # TODO update this to match the new annotator design
     # Initialize the predictor state.
     state = AnnotatorState()
     state.initialize_predictor(
         image, model_type=model_type, save_path=embedding_path,
-        halo=halo, tile_shape=tile_shape, prefer_decoder=False,
+        halo=halo, tile_shape=tile_shape, prefer_decoder=True,
         ndim=3, checkpoint_path=checkpoint_path, device=device,
+        precompute_amg_state=precompute_amg_state,
     )
     state.image_shape = image.shape[:-1] if image.ndim == 4 else image.shape
diff --git a/test/test_multi_dimensional_segmentation.py b/test/test_multi_dimensional_segmentation.py
index 9fc24c0a..f7e04a50 100644
--- a/test/test_multi_dimensional_segmentation.py
+++ b/test/test_multi_dimensional_segmentation.py
@@ -60,6 +60,24 @@ def test_merge_instance_segmentation_3d_with_closing(self):
         for z in range(1, n_slices):
             self.assertTrue(np.array_equal(ids0, np.unique(merged_seg[z])))

+    def test_track_across_frames(self):
+        from micro_sam.multi_dimensional_segmentation import track_across_frames
+
+        n_slices = 5
+        data = binary_blobs(512)
+        seg = label(data)
+
+        # Stack the same segmentation with offset label ids, so that the
+        # expected tracking result is one track per object over all frames.
+        stacked_seg = []
+        offset = 0
+        for _ in range(n_slices):
+            stack_seg = seg.copy()
+            stack_seg[stack_seg != 0] += offset
+            offset = stack_seg.max()
+            stacked_seg.append(stack_seg)
+        stacked_seg = np.stack(stacked_seg)
+
+        seg, lineage = track_across_frames(stacked_seg)
+        self.assertEqual(seg.shape, stacked_seg.shape)

 if __name__ == "__main__":
     unittest.main()