From b66b2b3e62112264aee2eea90f6ce29c7d1875d9 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 23 Oct 2023 23:50:23 +0200 Subject: [PATCH 1/3] Implement singleton state and use it in the 2d annotator --- micro_sam/sam_annotator/_state.py | 23 +++++++++++++ micro_sam/sam_annotator/annotator_2d.py | 44 ++++++++++++++----------- 2 files changed, 47 insertions(+), 20 deletions(-) create mode 100644 micro_sam/sam_annotator/_state.py diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py new file mode 100644 index 00000000..477fe3b7 --- /dev/null +++ b/micro_sam/sam_annotator/_state.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Optional + +from micro_sam.instance_segmentation import AMGBase +from micro_sam.util import ImageEmbeddings +from segment_anything import SamPredictor + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +# TODO the shape should also go in here +@dataclass +class AnnotatorState(metaclass=Singleton): + image_embeddings: Optional[ImageEmbeddings] = None + predictor: Optional[SamPredictor] = None + amg: Optional[AMGBase] = None diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 47d84bcb..d186d18c 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -13,6 +13,7 @@ from ..visualization import project_embeddings_for_visualization from . import util as vutil from .gui_utils import show_wrong_file_warning +from ._state import AnnotatorState @magicgui(call_button="Segment Object [S]") @@ -23,14 +24,17 @@ def _segment_widget(v: Viewer, box_extension: float = 0.1) -> None: boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape) points, labels = vutil.point_layer_to_prompts(v.layers["point_prompts"], with_stop_annotation=False) - if IMAGE_EMBEDDINGS["original_size"] is None: # tiled prediction + # TODO we should check that the image embeddings are initialized + predictor = AnnotatorState().predictor + image_embeddings = AnnotatorState().image_embeddings + if image_embeddings["original_size"] is None: # tiled prediction seg = vutil.prompt_segmentation( - PREDICTOR, points, labels, boxes, masks, shape, image_embeddings=IMAGE_EMBEDDINGS, + predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings, multiple_box_prompts=True, box_extension=box_extension, ) else: # normal prediction and we have set the precomputed embeddings already seg = vutil.prompt_segmentation( - PREDICTOR, points, labels, boxes, masks, shape, multiple_box_prompts=True, box_extension=box_extension, + predictor, points, labels, boxes, masks, shape, multiple_box_prompts=True, box_extension=box_extension, ) # no prompts were given or prompts were invalid, skip segmentation @@ -62,15 +66,17 @@ def _autosegment_widget( min_object_size: int = 100, with_background: bool = True, ) -> None: - global AMG - is_tiled = IMAGE_EMBEDDINGS["input_size"] is None - if AMG is None: - AMG = instance_segmentation.get_amg(PREDICTOR, is_tiled) + state = AnnotatorState() - if not AMG.is_initialized: - AMG.initialize(v.layers["raw"].data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True) + # TODO check that the image embeddings are initialized + is_tiled = state.image_embeddings["input_size"] is None + if state.amg is None: + state.amg = instance_segmentation.get_amg(state.predictor, is_tiled) - seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) + if not state.amg.is_initialized: + state.amg.initialize(v.layers["raw"].data, image_embeddings=state.image_embeddings, verbose=True) + + seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) shape = v.layers["raw"].data.shape[:2] seg = instance_segmentation.mask_data_to_segmentation( @@ -112,7 +118,7 @@ def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings): # show the PCA of the image embeddings if show_embeddings: - embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS) + embedding_vis, scale = project_embeddings_for_visualization(AnnotatorState().image_embeddings) v.add_image(embedding_vis, name="embeddings", scale=scale) labels = ["positive", "negative"] @@ -224,26 +230,24 @@ def annotator_2d( Returns: The napari viewer, only returned if `return_viewer=True`. """ - # for access to the predictor and the image embeddings in the widgets - global PREDICTOR, IMAGE_EMBEDDINGS, AMG - AMG = None + state = AnnotatorState() if predictor is None: - PREDICTOR = util.get_sam_model(model_type=model_type) + state.predictor = util.get_sam_model(model_type=model_type) else: - PREDICTOR = predictor + state.predictor = predictor - IMAGE_EMBEDDINGS = util.precompute_image_embeddings( - PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, + state.image_embeddings = util.precompute_image_embeddings( + state.predictor, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, wrong_file_callback=show_wrong_file_warning ) if precompute_amg_state and (embedding_path is not None): - AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path) + state.amg = cache_amg_state(state.predictor, raw, state.image_embeddings, embedding_path) # we set the pre-computed image embeddings if we don't use tiling # (if we use tiling we cannot directly set it because the tile will be chosen dynamically) if tile_shape is None: - util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS) + util.set_precomputed(state.predictor, state.image_embeddings) # viewer is freshly initialized if v is None: From e60ede2408b151674d22f8fbde759ec3ea981683 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 24 Oct 2023 21:57:01 +0200 Subject: [PATCH 2/3] Use singleton state in all annotator tools --- micro_sam/sam_annotator/_state.py | 45 ++++++++- micro_sam/sam_annotator/annotator_2d.py | 10 +- micro_sam/sam_annotator/annotator_3d.py | 69 +++++++------ micro_sam/sam_annotator/annotator_tracking.py | 99 ++++++++++--------- 4 files changed, 142 insertions(+), 81 deletions(-) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 477fe3b7..e80cb337 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import Dict, Optional, Tuple from micro_sam.instance_segmentation import AMGBase from micro_sam.util import ImageEmbeddings @@ -15,9 +15,50 @@ def __call__(cls, *args, **kwargs): return cls._instances[cls] -# TODO the shape should also go in here @dataclass class AnnotatorState(metaclass=Singleton): + + # predictor, image_embeddings and image_shape: + # This needs to be initialized for the interactive segmentation fucntionality. image_embeddings: Optional[ImageEmbeddings] = None predictor: Optional[SamPredictor] = None + image_shape: Optional[Tuple[int, int]] = None + + # amg: needs to be initialized for the automatic segmentation functionality. + # amg_state: for storing the instance segmentation state for the 3d segmentation tool. amg: Optional[AMGBase] = None + amg_state: Optional[Dict] = None + + # current_track_id, lineage: + # State for the tracking annotator to keep track of lineage information. + current_track_id: Optional[int] = None + lineage: Optional[Dict] = None + + def initialized_for_interactive_segmentation(self): + have_image_embeddings = self.image_embeddings is not None + have_predictor = self.predictor is not None + have_image_shape = self.image_shape is not None + init_sum = sum((have_image_embeddings, have_predictor, have_image_shape)) + if init_sum == 3: + return True + elif init_sum == 0: + return False + else: + raise RuntimeError( + f"Invalid AnnotatorState: {init_sum} / 3 parts of the state " + "needed for interactive segmentation are initialized." + ) + + def initialized_for_tracking(self): + have_current_track_id = self.current_track_id is not None + have_lineage = self.lineage is not None + init_sum = sum((have_current_track_id, have_lineage)) + if init_sum == 2: + return True + elif init_sum == 0: + return False + else: + raise RuntimeError( + f"Invalid AnnotatorState: {init_sum} / 2 parts of the state " + "needed for tracking are initialized." + ) diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index d186d18c..9c8e002e 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -24,7 +24,6 @@ def _segment_widget(v: Viewer, box_extension: float = 0.1) -> None: boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape) points, labels = vutil.point_layer_to_prompts(v.layers["point_prompts"], with_stop_annotation=False) - # TODO we should check that the image embeddings are initialized predictor = AnnotatorState().predictor image_embeddings = AnnotatorState().image_embeddings if image_embeddings["original_size"] is None: # tiled prediction @@ -68,17 +67,19 @@ def _autosegment_widget( ) -> None: state = AnnotatorState() - # TODO check that the image embeddings are initialized is_tiled = state.image_embeddings["input_size"] is None if state.amg is None: state.amg = instance_segmentation.get_amg(state.predictor, is_tiled) + shape = state.image_shape if not state.amg.is_initialized: - state.amg.initialize(v.layers["raw"].data, image_embeddings=state.image_embeddings, verbose=True) + # we don't need to pass the actual image data here, since the embeddings are passed + # (the image data is only used by the amg to compute image embeddings, so not needed here) + dummy_image = np.zeros(shape, dtype="uint8") + state.amg.initialize(dummy_image, image_embeddings=state.image_embeddings, verbose=True) seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) - shape = v.layers["raw"].data.shape[:2] seg = instance_segmentation.mask_data_to_segmentation( seg, shape, with_background=with_background, min_object_size=min_object_size ) @@ -236,6 +237,7 @@ def annotator_2d( state.predictor = util.get_sam_model(model_type=model_type) else: state.predictor = predictor + state.image_shape = _get_shape(raw) state.image_embeddings = util.precompute_image_embeddings( state.predictor, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index b49258ce..91264417 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -18,6 +18,7 @@ from ..visualization import project_embeddings_for_visualization from . import util as vutil from .gui_utils import show_wrong_file_warning +from ._state import AnnotatorState # @@ -39,9 +40,10 @@ def _segment_slice_wigdet(v: Viewer, box_extension: float = 0.1) -> None: boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape, i=z) points, labels = point_prompts + state = AnnotatorState() seg = vutil.prompt_segmentation( - PREDICTOR, points, labels, boxes, masks, shape, multiple_box_prompts=False, - image_embeddings=IMAGE_EMBEDDINGS, i=z, box_extension=box_extension, + state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False, + image_embeddings=state.image_embeddings, i=z, box_extension=box_extension, ) # no prompts were given or prompts were invalid, skip segmentation @@ -54,19 +56,20 @@ def _segment_slice_wigdet(v: Viewer, box_extension: float = 0.1) -> None: def _segment_volume_for_current_object(v, projection, iou_threshold, box_extension): - shape = v.layers["raw"].data.shape + state = AnnotatorState() + shape = state.image_shape with progress(total=shape[0]) as progress_bar: # step 1: segment all slices with prompts seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts( - PREDICTOR, v.layers["point_prompts"], v.layers["prompts"], IMAGE_EMBEDDINGS, shape, + state.predictor, v.layers["point_prompts"], v.layers["prompts"], state.image_embeddings, shape, progress_bar=progress_bar, ) # step 2: segment the rest of the volume based on smart prompting seg = segment_mask_in_volume( - seg, PREDICTOR, IMAGE_EMBEDDINGS, slices, + seg, state.predictor, state.image_embeddings, slices, stop_lower, stop_upper, iou_threshold=iou_threshold, projection=projection, progress_bar=progress_bar, box_extension=box_extension, @@ -89,12 +92,13 @@ def _segment_volume_for_auto_segmentation( seg[:start_slice] = 0 seg[(start_slice+1):] + state = AnnotatorState() for object_id in progress(object_ids): object_seg = seg == object_id segmented_slices = np.array([start_slice]) object_seg = segment_mask_in_volume( - segmentation=object_seg, predictor=PREDICTOR, - image_embeddings=IMAGE_EMBEDDINGS, segmented_slices=segmented_slices, + segmentation=object_seg, predictor=state.predictor, + image_embeddings=state.image_embeddings, segmented_slices=segmented_slices, stop_lower=False, stop_upper=False, iou_threshold=iou_threshold, projection=projection, box_extension=box_extension, ) @@ -146,30 +150,36 @@ def _autosegment_widget( min_object_size: int = 100, with_background: bool = True, ) -> None: - global AMG, AMG_STATE - is_tiled = IMAGE_EMBEDDINGS["input_size"] is None - if AMG is None: - AMG = instance_segmentation.get_amg(PREDICTOR, is_tiled) + state = AnnotatorState() + + is_tiled = state.image_embeddings["input_size"] is None + if state.amg is None: + state.amg = instance_segmentation.get_amg(state.predictor, is_tiled) i = int(v.cursor.position[0]) - if i in AMG_STATE: - state = AMG_STATE[i] - AMG.set_state(state) + shape = state.image_shape[-2:] + + if i in state.amg_state: + amg_state_i = state.amg_state[i] + state.amg.set_state(amg_state_i) else: - image_data = v.layers["raw"].data[i] - AMG.initialize(image_data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True, i=i) - state = AMG.get_state() + # we don't need to pass the actual image data here, since the embeddings are passed + # (the image data is only used by the amg to compute image embeddings, so not needed here) + dummy_image = np.zeros(shape, dtype="uint8") + + state.amg.initialize(dummy_image, image_embeddings=state.image_embeddings, verbose=True, i=i) + amg_state_i = state.amg.get_state() - cache_folder = AMG_STATE["cache_folder"] + state.amg_state[i] = amg_state_i + cache_folder = state.amg_state["cache_folder"] if cache_folder is not None: cache_path = os.path.join(cache_folder, f"state-{i}.pkl") with open(cache_path, "wb") as f: - pickle.dump(state, f) + pickle.dump(amg_state_i, f) - seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) + seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh) - shape = v.layers["raw"].data.shape[-2:] seg = instance_segmentation.mask_data_to_segmentation( seg, shape, with_background=with_background, min_object_size=min_object_size ) @@ -230,20 +240,19 @@ def annotator_3d( Returns: The napari viewer, only returned if `return_viewer=True`. """ - # for access to the predictor and the image embeddings in the widgets - global PREDICTOR, IMAGE_EMBEDDINGS, AMG, AMG_STATE - AMG = None + state = AnnotatorState() if predictor is None: - PREDICTOR = util.get_sam_model(model_type=model_type) + state.predictor = util.get_sam_model(model_type=model_type) else: - PREDICTOR = predictor - IMAGE_EMBEDDINGS = util.precompute_image_embeddings( - PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo, + state.predictor = predictor + state.image_embeddings = util.precompute_image_embeddings( + state.predictor, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo, wrong_file_callback=show_wrong_file_warning, ) + state.image_shape = raw.shape - AMG_STATE = _load_amg_state(embedding_path) + state.amg_state = _load_amg_state(embedding_path) # # initialize the viewer and add layers @@ -263,7 +272,7 @@ def annotator_3d( # show the PCA of the image embeddings if show_embeddings: - embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS) + embedding_vis, scale = project_embeddings_for_visualization(state.image_embeddings) v.add_image(embedding_vis, name="embeddings", scale=scale) labels = ["positive", "negative"] diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index c026aaeb..f0683696 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -21,6 +21,7 @@ from ..visualization import project_embeddings_for_visualization from . import util as vutil from .gui_utils import show_wrong_file_warning +from ._state import AnnotatorState # Cyan (track) and Magenta (division) STATE_COLOR_CYCLE = ["#00FFFF", "#FF00FF", ] @@ -160,18 +161,20 @@ def _update_motion_model(seg, t, t0, motion_model): def _update_lineage(): - global LINEAGE, TRACKING_WIDGET - mother = CURRENT_TRACK_ID - assert mother in LINEAGE - assert len(LINEAGE[mother]) == 0 + global TRACKING_WIDGET + state = AnnotatorState() + + mother = state.current_track_id + assert mother in state.lineage + assert len(state.lineage[mother]) == 0 - daughter1, daughter2 = CURRENT_TRACK_ID + 1, CURRENT_TRACK_ID + 2 - LINEAGE[mother] = [daughter1, daughter2] - LINEAGE[daughter1] = [] - LINEAGE[daughter2] = [] + daughter1, daughter2 = state.current_track_id + 1, state.current_track_id + 2 + state.lineage[mother] = [daughter1, daughter2] + state.lineage[daughter1] = [] + state.lineage[daughter2] = [] # update the choices in the track_id menu - track_ids = list(map(str, LINEAGE.keys())) + track_ids = list(map(str, state.lineage.keys())) TRACKING_WIDGET[1].choices = track_ids # not sure if this does the right thing. @@ -187,21 +190,22 @@ def _update_lineage(): @magicgui(call_button="Segment Frame [S]") def _segment_frame_wigdet(v: Viewer) -> None: + state = AnnotatorState() shape = v.layers["current_track"].data.shape[1:] position = v.cursor.position t = int(position[0]) - point_prompts = vutil.point_layer_to_prompts(v.layers["point_prompts"], i=t, track_id=CURRENT_TRACK_ID) + point_prompts = vutil.point_layer_to_prompts(v.layers["point_prompts"], i=t, track_id=state.current_track_id) # this is a stop prompt, we do nothing if not point_prompts: return - boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape, i=t, track_id=CURRENT_TRACK_ID) + boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape, i=t, track_id=state.current_track_id) points, labels = point_prompts seg = vutil.prompt_segmentation( - PREDICTOR, points, labels, boxes, masks, shape, multiple_box_prompts=False, - image_embeddings=IMAGE_EMBEDDINGS, i=t + state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False, + image_embeddings=state.image_embeddings, i=t ) # no prompts were given or prompts were invalid, skip segmentation @@ -210,11 +214,11 @@ def _segment_frame_wigdet(v: Viewer) -> None: return # clear the old segmentation for this track_id - old_mask = v.layers["current_track"].data[t] == CURRENT_TRACK_ID + old_mask = v.layers["current_track"].data[t] == state.current_track_id v.layers["current_track"].data[t][old_mask] = 0 # set the new segmentation new_mask = seg.squeeze() == 1 - v.layers["current_track"].data[t][new_mask] = CURRENT_TRACK_ID + v.layers["current_track"].data[t][new_mask] = state.current_track_id v.layers["current_track"].refresh() @@ -223,7 +227,8 @@ def _track_object_widget( v: Viewer, iou_threshold: float = 0.5, projection: str = "default", motion_smoothing: float = 0.5, box_extension: float = 0.1, ) -> None: - shape = v.layers["raw"].data.shape + state = AnnotatorState() + shape = state.image_shape # we use the bounding box projection method as default which generally seems to work better for larger changes # between frames (which is pretty tyipical for tracking compared to 3d segmentation) @@ -232,14 +237,14 @@ def _track_object_widget( with progress(total=shape[0]) as progress_bar: # step 1: segment all slices with prompts seg, slices, _, stop_upper = vutil.segment_slices_with_prompts( - PREDICTOR, v.layers["point_prompts"], v.layers["prompts"], IMAGE_EMBEDDINGS, shape, - progress_bar=progress_bar, track_id=CURRENT_TRACK_ID + state.predictor, v.layers["point_prompts"], v.layers["prompts"], state.image_embeddings, shape, + progress_bar=progress_bar, track_id=state.current_track_id ) # step 2: track the object starting from the lowest annotated slice seg, has_division = _track_from_prompts( v.layers["point_prompts"], v.layers["prompts"], seg, - PREDICTOR, slices, IMAGE_EMBEDDINGS, stop_upper, + state.predictor, slices, state.image_embeddings, stop_upper, threshold=iou_threshold, projection=projection_, progress_bar=progress_bar, motion_smoothing=motion_smoothing, box_extension=box_extension, @@ -247,18 +252,20 @@ def _track_object_widget( # if a division has occurred and it's the first time it occurred for this track # we need to create the two daughter tracks and update the lineage - if has_division and (len(LINEAGE[CURRENT_TRACK_ID]) == 0): + if has_division and (len(state.lineage[state.current_track_id]) == 0): _update_lineage() # clear the old track mask - v.layers["current_track"].data[v.layers["current_track"].data == CURRENT_TRACK_ID] = 0 + v.layers["current_track"].data[v.layers["current_track"].data == state.current_track_id] = 0 # set the new track mask - v.layers["current_track"].data[seg == 1] = CURRENT_TRACK_ID + v.layers["current_track"].data[seg == 1] = state.current_track_id v.layers["current_track"].refresh() def create_tracking_menu(points_layer, box_layer, states, track_ids): """@private""" + state = AnnotatorState() + state_menu = ComboBox(label="track_state", choices=states) track_id_menu = ComboBox(label="track_id", choices=list(map(str, track_ids))) tracking_widget = Container(widgets=[state_menu, track_id_menu]) @@ -269,11 +276,10 @@ def update_state(event): state_menu.value = new_state def update_track_id(event): - global CURRENT_TRACK_ID new_id = str(points_layer.current_properties["track_id"][0]) if new_id != track_id_menu.value: track_id_menu.value = new_id - CURRENT_TRACK_ID = int(new_id) + state.current_track_id = int(new_id) # def update_state_boxes(event): # new_state = str(box_layer.current_properties["state"][0]) @@ -281,11 +287,10 @@ def update_track_id(event): # state_menu.value = new_state def update_track_id_boxes(event): - global CURRENT_TRACK_ID new_id = str(box_layer.current_properties["track_id"][0]) if new_id != track_id_menu.value: track_id_menu.value = new_id - CURRENT_TRACK_ID = int(new_id) + state.current_track_id = int(new_id) points_layer.events.current_properties.connect(update_state) points_layer.events.current_properties.connect(update_track_id) @@ -299,11 +304,10 @@ def state_changed(new_state): points_layer.refresh_colors() def track_id_changed(new_track_id): - global CURRENT_TRACK_ID current_properties = points_layer.current_properties current_properties["track_id"] = np.array([new_track_id]) points_layer.current_properties = current_properties - CURRENT_TRACK_ID = int(new_track_id) + state.current_track_id = int(new_track_id) # def state_changed_boxes(new_state): # current_properties = box_layer.current_properties @@ -312,11 +316,10 @@ def track_id_changed(new_track_id): # box_layer.refresh_colors() def track_id_changed_boxes(new_track_id): - global CURRENT_TRACK_ID current_properties = box_layer.current_properties current_properties["track_id"] = np.array([new_track_id]) box_layer.current_properties = current_properties - CURRENT_TRACK_ID = int(new_track_id) + state.current_track_id = int(new_track_id) state_menu.changed.connect(state_changed) track_id_menu.changed.connect(track_id_changed) @@ -328,20 +331,22 @@ def track_id_changed_boxes(new_track_id): def _reset_tracking_state(): - global CURRENT_TRACK_ID, LINEAGE, TRACKING_WIDGET + global TRACKING_WIDGET + state = AnnotatorState() # reset the lineage and track id - CURRENT_TRACK_ID = 1 - LINEAGE = {1: []} + state.current_track_id = 1 + state.lineage = {1: []} # reset the choices in the track_id menu - track_ids = list(map(str, LINEAGE.keys())) + track_ids = list(map(str, state.lineage.keys())) TRACKING_WIDGET[1].choices = track_ids @magicgui(call_button="Commit [C]", layer={"choices": ["current_track"]}) def _commit_tracking_widget(v: Viewer, layer: str = "current_track") -> None: global COMMITTED_LINEAGES + state = AnnotatorState() seg = v.layers[layer].data @@ -356,7 +361,7 @@ def _commit_tracking_widget(v: Viewer, layer: str = "current_track") -> None: v.layers[layer].refresh() updated_lineage = { - parent + id_offset: [child + id_offset for child in children] for parent, children in LINEAGE.items() + parent + id_offset: [child + id_offset for child in children] for parent, children in state.lineage.items() } COMMITTED_LINEAGES.append(updated_lineage) @@ -411,21 +416,25 @@ def annotator_tracking( Returns: The napari viewer, only returned if `return_viewer=True`. """ - # global state - global PREDICTOR, IMAGE_EMBEDDINGS, CURRENT_TRACK_ID, LINEAGE + # NOTE: the tracking widget is left as global state for now. + # The fact that it is in the state is quite a hack. When building a plugin for the + # tracking annotator this needs to be redesigned! global TRACKING_WIDGET + state = AnnotatorState() + if predictor is None: - PREDICTOR = util.get_sam_model(model_type=model_type) + state.predictor = util.get_sam_model(model_type=model_type) else: - PREDICTOR = predictor - IMAGE_EMBEDDINGS = util.precompute_image_embeddings( - PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo, + state.predictor = predictor + state.image_embeddings = util.precompute_image_embeddings( + state.predictor, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo, wrong_file_callback=show_wrong_file_warning, ) + state.image_shape = raw.shape - CURRENT_TRACK_ID = 1 - LINEAGE = {1: []} + state.current_track_id = 1 + state.lineage = {1: []} # # initialize the viewer and add layers @@ -444,7 +453,7 @@ def annotator_tracking( # show the PCA of the image embeddings if show_embeddings: - embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS) + embedding_vis, scale = project_embeddings_for_visualization(state.image_embeddings) v.add_image(embedding_vis, name="embeddings", scale=scale) # @@ -499,7 +508,7 @@ def annotator_tracking( prompt_widget = vutil.create_prompt_menu(prompts, labels) v.window.add_dock_widget(prompt_widget) - TRACKING_WIDGET = create_tracking_menu(prompts, box_prompts, state_labels, list(LINEAGE.keys())) + TRACKING_WIDGET = create_tracking_menu(prompts, box_prompts, state_labels, list(state.lineage.keys())) v.window.add_dock_widget(TRACKING_WIDGET) v.window.add_dock_widget(_segment_frame_wigdet) From 11f170d54ad05041c7eca5bf5ebe2a701bc87c4b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 30 Oct 2023 09:58:55 +0100 Subject: [PATCH 3/3] Add unittests for the state --- micro_sam/sam_annotator/_state.py | 5 +++++ test/test_sam_annotator/test_state.py | 32 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 test/test_sam_annotator/test_state.py diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index e80cb337..2642f2d1 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -1,3 +1,8 @@ +"""Implements a singleton class for the state of the annotation tools. +The singleton is implemented following the metaclass design described here: +https://itnext.io/deciding-the-best-singleton-approach-in-python-65c61e90cdc4 +""" + from dataclasses import dataclass from typing import Dict, Optional, Tuple diff --git a/test/test_sam_annotator/test_state.py b/test/test_sam_annotator/test_state.py new file mode 100644 index 00000000..694049da --- /dev/null +++ b/test/test_sam_annotator/test_state.py @@ -0,0 +1,32 @@ +import unittest +import micro_sam.util as util + +from skimage.data import binary_blobs + + +class TestState(unittest.TestCase): + model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b" + + def test_state_for_interactive_segmentation(self): + from micro_sam.sam_annotator._state import AnnotatorState + image = binary_blobs(512) + predictor = util.get_sam_model(model_type=self.model_type) + image_embeddings = util.precompute_image_embeddings(predictor, image) + + state = AnnotatorState() + state.image_embeddings = image_embeddings + state.predictor = predictor + state.image_shape = image.shape + self.assertTrue(state.initialized_for_interactive_segmentation()) + + def test_state_for_tracking(self): + from micro_sam.sam_annotator._state import AnnotatorState + + state = AnnotatorState() + state.current_track_id = 1 + state.lineage = {1: {}} + self.assertTrue(state.initialized_for_tracking()) + + +if __name__ == "__main__": + unittest.main()