Implement singleton state and use it in the 2d annotator #240

Merged 3 commits on Oct 30, 2023
Changes from 2 commits
64 changes: 64 additions & 0 deletions micro_sam/sam_annotator/_state.py
@@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

from micro_sam.instance_segmentation import AMGBase
from micro_sam.util import ImageEmbeddings
from segment_anything import SamPredictor


class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


@dataclass
class AnnotatorState(metaclass=Singleton):

    # predictor, image_embeddings and image_shape:
    # These need to be initialized for the interactive segmentation functionality.
    image_embeddings: Optional[ImageEmbeddings] = None
    predictor: Optional[SamPredictor] = None
    image_shape: Optional[Tuple[int, int]] = None

    # amg: needs to be initialized for the automatic segmentation functionality.
    # amg_state: for storing the instance segmentation state for the 3d segmentation tool.
    amg: Optional[AMGBase] = None
    amg_state: Optional[Dict] = None

    # current_track_id, lineage:
    # State for the tracking annotator to keep track of lineage information.
    current_track_id: Optional[int] = None
    lineage: Optional[Dict] = None

    def initialized_for_interactive_segmentation(self):
        have_image_embeddings = self.image_embeddings is not None
        have_predictor = self.predictor is not None
        have_image_shape = self.image_shape is not None
        init_sum = sum((have_image_embeddings, have_predictor, have_image_shape))
        if init_sum == 3:
            return True
        elif init_sum == 0:
            return False
        else:
            raise RuntimeError(
                f"Invalid AnnotatorState: {init_sum} / 3 parts of the state "
                "needed for interactive segmentation are initialized."
            )

    def initialized_for_tracking(self):
        have_current_track_id = self.current_track_id is not None
        have_lineage = self.lineage is not None
        init_sum = sum((have_current_track_id, have_lineage))
        if init_sum == 2:
            return True
        elif init_sum == 0:
            return False
        else:
            raise RuntimeError(
                f"Invalid AnnotatorState: {init_sum} / 2 parts of the state "
                "needed for tracking are initialized."
            )
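
The metaclass-based singleton means every call to `AnnotatorState()` returns the same object, so the annotator widgets can share the predictor, embeddings and segmentation state without module-level globals. A minimal usage sketch (assuming the module path introduced in this PR):

```python
from micro_sam.sam_annotator._state import AnnotatorState

# every construction returns the one shared instance
state_a = AnnotatorState()
state_b = AnnotatorState()
assert state_a is state_b

# fields set through one reference are visible through the other
state_a.current_track_id = 1
state_a.lineage = {1: []}
assert state_b.initialized_for_tracking()

# a partially initialized state raises instead of failing silently
state_b.lineage = None
state_b.initialized_for_tracking()  # RuntimeError: 1 / 2 parts of the state ...
```

One side effect of `Singleton.__call__` constructing the instance only once: arguments passed to any later `AnnotatorState(...)` call are silently ignored.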
48 changes: 27 additions & 21 deletions micro_sam/sam_annotator/annotator_2d.py
@@ -13,6 +13,7 @@
from ..visualization import project_embeddings_for_visualization
from . import util as vutil
from .gui_utils import show_wrong_file_warning
from ._state import AnnotatorState


@magicgui(call_button="Segment Object [S]")
@@ -23,14 +24,16 @@ def _segment_widget(v: Viewer, box_extension: float = 0.1) -> None:
boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape)
points, labels = vutil.point_layer_to_prompts(v.layers["point_prompts"], with_stop_annotation=False)

if IMAGE_EMBEDDINGS["original_size"] is None: # tiled prediction
predictor = AnnotatorState().predictor
image_embeddings = AnnotatorState().image_embeddings
if image_embeddings["original_size"] is None: # tiled prediction
seg = vutil.prompt_segmentation(
PREDICTOR, points, labels, boxes, masks, shape, image_embeddings=IMAGE_EMBEDDINGS,
predictor, points, labels, boxes, masks, shape, image_embeddings=image_embeddings,
multiple_box_prompts=True, box_extension=box_extension,
)
else: # normal prediction and we have set the precomputed embeddings already
seg = vutil.prompt_segmentation(
PREDICTOR, points, labels, boxes, masks, shape, multiple_box_prompts=True, box_extension=box_extension,
predictor, points, labels, boxes, masks, shape, multiple_box_prompts=True, box_extension=box_extension,
)

# no prompts were given or prompts were invalid, skip segmentation
@@ -62,17 +65,21 @@ def _autosegment_widget(
min_object_size: int = 100,
with_background: bool = True,
) -> None:
global AMG
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
if AMG is None:
AMG = instance_segmentation.get_amg(PREDICTOR, is_tiled)
state = AnnotatorState()

if not AMG.is_initialized:
AMG.initialize(v.layers["raw"].data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
is_tiled = state.image_embeddings["input_size"] is None
if state.amg is None:
state.amg = instance_segmentation.get_amg(state.predictor, is_tiled)

seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
shape = state.image_shape
if not state.amg.is_initialized:
# we don't need to pass the actual image data here, since the embeddings are passed
# (the image data is only used by the amg to compute image embeddings, so not needed here)
dummy_image = np.zeros(shape, dtype="uint8")
state.amg.initialize(dummy_image, image_embeddings=state.image_embeddings, verbose=True)

seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)

shape = v.layers["raw"].data.shape[:2]
seg = instance_segmentation.mask_data_to_segmentation(
seg, shape, with_background=with_background, min_object_size=min_object_size
)
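
The dummy-image trick above works because the AMG only consumes the raw image to compute embeddings, and those are already precomputed here. A hedged sketch that consolidates the pattern into a hypothetical helper (the helper name and the zero-filled array are illustrative; `get_amg` and `initialize` are used exactly as in this diff):

```python
import numpy as np

from micro_sam import instance_segmentation
from micro_sam.sam_annotator._state import AnnotatorState

def ensure_amg_initialized(state: AnnotatorState):
    # for tiled embeddings "input_size" is stored as None
    is_tiled = state.image_embeddings["input_size"] is None
    if state.amg is None:
        state.amg = instance_segmentation.get_amg(state.predictor, is_tiled)
    if not state.amg.is_initialized:
        # any array of the right shape works, since the embeddings are passed in
        dummy_image = np.zeros(state.image_shape, dtype="uint8")
        state.amg.initialize(dummy_image, image_embeddings=state.image_embeddings, verbose=True)
    return state.amg
```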
@@ -112,7 +119,7 @@ def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings):

# show the PCA of the image embeddings
if show_embeddings:
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
embedding_vis, scale = project_embeddings_for_visualization(AnnotatorState().image_embeddings)
v.add_image(embedding_vis, name="embeddings", scale=scale)

labels = ["positive", "negative"]
@@ -224,26 +231,25 @@ def annotator_2d(
Returns:
The napari viewer, only returned if `return_viewer=True`.
"""
# for access to the predictor and the image embeddings in the widgets
global PREDICTOR, IMAGE_EMBEDDINGS, AMG
AMG = None
state = AnnotatorState()

if predictor is None:
PREDICTOR = util.get_sam_model(model_type=model_type)
state.predictor = util.get_sam_model(model_type=model_type)
else:
PREDICTOR = predictor
state.predictor = predictor
state.image_shape = _get_shape(raw)

IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
state.image_embeddings = util.precompute_image_embeddings(
state.predictor, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
wrong_file_callback=show_wrong_file_warning
)
if precompute_amg_state and (embedding_path is not None):
AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)
state.amg = cache_amg_state(state.predictor, raw, state.image_embeddings, embedding_path)

# we set the pre-computed image embeddings if we don't use tiling
# (if we use tiling we cannot directly set it because the tile will be chosen dynamically)
if tile_shape is None:
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
util.set_precomputed(state.predictor, state.image_embeddings)

# viewer is freshly initialized
if v is None:
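
One consequence of a process-wide singleton worth noting: state from a previous annotator session survives within the same Python process. A hedged sketch of a reset helper (hypothetical, not part of this PR) that a caller could run before starting a new session:

```python
from micro_sam.sam_annotator._state import AnnotatorState

def reset_annotator_state():
    # clear every field of the shared singleton back to its defaults
    state = AnnotatorState()
    state.image_embeddings = None
    state.predictor = None
    state.image_shape = None
    state.amg = None
    state.amg_state = None
    state.current_track_id = None
    state.lineage = None
```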
69 changes: 39 additions & 30 deletions micro_sam/sam_annotator/annotator_3d.py
@@ -18,6 +18,7 @@
from ..visualization import project_embeddings_for_visualization
from . import util as vutil
from .gui_utils import show_wrong_file_warning
from ._state import AnnotatorState


#
@@ -39,9 +40,10 @@ def _segment_slice_wigdet(v: Viewer, box_extension: float = 0.1) -> None:
boxes, masks = vutil.shape_layer_to_prompts(v.layers["prompts"], shape, i=z)
points, labels = point_prompts

state = AnnotatorState()
seg = vutil.prompt_segmentation(
PREDICTOR, points, labels, boxes, masks, shape, multiple_box_prompts=False,
image_embeddings=IMAGE_EMBEDDINGS, i=z, box_extension=box_extension,
state.predictor, points, labels, boxes, masks, shape, multiple_box_prompts=False,
image_embeddings=state.image_embeddings, i=z, box_extension=box_extension,
)

# no prompts were given or prompts were invalid, skip segmentation
@@ -54,19 +56,20 @@ def _segment_slice_wigdet(v: Viewer, box_extension: float = 0.1) -> None:


def _segment_volume_for_current_object(v, projection, iou_threshold, box_extension):
shape = v.layers["raw"].data.shape
state = AnnotatorState()
shape = state.image_shape

with progress(total=shape[0]) as progress_bar:

# step 1: segment all slices with prompts
seg, slices, stop_lower, stop_upper = vutil.segment_slices_with_prompts(
PREDICTOR, v.layers["point_prompts"], v.layers["prompts"], IMAGE_EMBEDDINGS, shape,
state.predictor, v.layers["point_prompts"], v.layers["prompts"], state.image_embeddings, shape,
progress_bar=progress_bar,
)

# step 2: segment the rest of the volume based on smart prompting
seg = segment_mask_in_volume(
seg, PREDICTOR, IMAGE_EMBEDDINGS, slices,
seg, state.predictor, state.image_embeddings, slices,
stop_lower, stop_upper,
iou_threshold=iou_threshold, projection=projection,
progress_bar=progress_bar, box_extension=box_extension,
@@ -89,12 +92,13 @@ def _segment_volume_for_auto_segmentation(
seg[:start_slice] = 0
seg[(start_slice+1):] = 0  # keep only the segmentation in the start slice

state = AnnotatorState()
for object_id in progress(object_ids):
object_seg = seg == object_id
segmented_slices = np.array([start_slice])
object_seg = segment_mask_in_volume(
segmentation=object_seg, predictor=PREDICTOR,
image_embeddings=IMAGE_EMBEDDINGS, segmented_slices=segmented_slices,
segmentation=object_seg, predictor=state.predictor,
image_embeddings=state.image_embeddings, segmented_slices=segmented_slices,
stop_lower=False, stop_upper=False, iou_threshold=iou_threshold,
projection=projection, box_extension=box_extension,
)
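
To make the slicing above concrete: the two assignments zero out every plane except `start_slice`, so each object is projected through the volume starting from that single segmented plane. A toy illustration:

```python
import numpy as np

seg = np.ones((5, 4, 4), dtype="uint32")  # toy (z, y, x) label volume
start_slice = 2
seg[:start_slice] = 0
seg[(start_slice + 1):] = 0
assert seg[start_slice].all() and seg[0].sum() == 0  # only z == start_slice keeps labels
```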
@@ -146,30 +150,36 @@ def _autosegment_widget(
min_object_size: int = 100,
with_background: bool = True,
) -> None:
global AMG, AMG_STATE
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
if AMG is None:
AMG = instance_segmentation.get_amg(PREDICTOR, is_tiled)
state = AnnotatorState()

is_tiled = state.image_embeddings["input_size"] is None
if state.amg is None:
state.amg = instance_segmentation.get_amg(state.predictor, is_tiled)

i = int(v.cursor.position[0])
if i in AMG_STATE:
state = AMG_STATE[i]
AMG.set_state(state)
shape = state.image_shape[-2:]

if i in state.amg_state:
amg_state_i = state.amg_state[i]
state.amg.set_state(amg_state_i)

else:
image_data = v.layers["raw"].data[i]
AMG.initialize(image_data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True, i=i)
state = AMG.get_state()
# we don't need to pass the actual image data here, since the embeddings are passed
# (the image data is only used by the amg to compute image embeddings, so not needed here)
dummy_image = np.zeros(shape, dtype="uint8")

state.amg.initialize(dummy_image, image_embeddings=state.image_embeddings, verbose=True, i=i)
amg_state_i = state.amg.get_state()

cache_folder = AMG_STATE["cache_folder"]
state.amg_state[i] = amg_state_i
cache_folder = state.amg_state["cache_folder"]
if cache_folder is not None:
cache_path = os.path.join(cache_folder, f"state-{i}.pkl")
with open(cache_path, "wb") as f:
pickle.dump(state, f)
pickle.dump(amg_state_i, f)

seg = AMG.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
seg = state.amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)

shape = v.layers["raw"].data.shape[-2:]
seg = instance_segmentation.mask_data_to_segmentation(
seg, shape, with_background=with_background, min_object_size=min_object_size
)
@@ -230,20 +240,19 @@ def annotator_3d(
Returns:
The napari viewer, only returned if `return_viewer=True`.
"""
# for access to the predictor and the image embeddings in the widgets
global PREDICTOR, IMAGE_EMBEDDINGS, AMG, AMG_STATE
AMG = None
state = AnnotatorState()

if predictor is None:
PREDICTOR = util.get_sam_model(model_type=model_type)
state.predictor = util.get_sam_model(model_type=model_type)
else:
PREDICTOR = predictor
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
state.predictor = predictor
state.image_embeddings = util.precompute_image_embeddings(
state.predictor, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
wrong_file_callback=show_wrong_file_warning,
)
state.image_shape = raw.shape

AMG_STATE = _load_amg_state(embedding_path)
state.amg_state = _load_amg_state(embedding_path)

#
# initialize the viewer and add layers
Expand All @@ -263,7 +272,7 @@ def annotator_3d(

# show the PCA of the image embeddings
if show_embeddings:
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS)
embedding_vis, scale = project_embeddings_for_visualization(state.image_embeddings)
v.add_image(embedding_vis, name="embeddings", scale=scale)

labels = ["positive", "negative"]
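
The per-slice AMG state is pickled to `state-{i}.pkl` inside a cache folder recorded under the `"cache_folder"` key. A minimal sketch of what the matching `_load_amg_state` could look like (an assumption inferred from the dump side shown above; the actual implementation may differ):

```python
import os
import pickle
from glob import glob

def load_amg_state(embedding_path):
    # without an embedding path there is nowhere to cache, so only record that
    if embedding_path is None:
        return {"cache_folder": None}
    cache_folder = os.path.join(embedding_path, "amg_state")
    os.makedirs(cache_folder, exist_ok=True)
    amg_state = {"cache_folder": cache_folder}
    # restore any per-slice states that were pickled in earlier sessions
    for path in glob(os.path.join(cache_folder, "state-*.pkl")):
        i = int(os.path.basename(path)[len("state-"):-len(".pkl")])
        with open(path, "rb") as f:
            amg_state[i] = pickle.load(f)
    return amg_state
```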