Skip to content

Commit

Permalink
Update the image series annotator (#738)
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape authored Oct 16, 2024
1 parent cd4418a commit 896ea00
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 54 deletions.
2 changes: 1 addition & 1 deletion micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _precompute_state_for_file(
# Precompute the image embeddings.
output_path = Path(output_path).with_suffix(".zarr")
embeddings = util.precompute_image_embeddings(
predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=False
)

# Precompute the state for automatic instance segmnetaiton (AMG or AIS).
Expand Down
145 changes: 92 additions & 53 deletions micro_sam/sam_annotator/image_series_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,46 @@ def _get_input_shape(image, is_volumetric=False):
return image_shape


def _initialize_annotator(
viewer, image, image_embedding_path,
model_type, halo, tile_shape, predictor, decoder, is_volumetric,
precompute_amg_state, checkpoint_path, device,
embedding_path,
):
if viewer is None:
viewer = napari.Viewer()
viewer.add_image(image, name="image")

state = AnnotatorState()
state.initialize_predictor(
image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
predictor=predictor, decoder=decoder,
ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
checkpoint_path=checkpoint_path, device=device, skip_load=False,
)
state.image_shape = _get_input_shape(image, is_volumetric)

if is_volumetric:
if image.ndim not in [3, 4]:
raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
annotator = Annotator3d(viewer)
else:
if image.ndim not in (2, 3):
raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
annotator = Annotator2d(viewer)

annotator._update_image()

# Add the annotator widget to the viewer and sync widgets.
viewer.window.add_dock_widget(annotator)
_sync_embedding_widget(
state.widgets["embeddings"], model_type,
save_path=embedding_path, checkpoint_path=checkpoint_path,
device=device, tile_shape=tile_shape, halo=halo
)
return viewer, annotator


def image_series_annotator(
images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
output_folder: str,
Expand All @@ -94,6 +134,7 @@ def image_series_annotator(
is_volumetric: bool = False,
device: Optional[Union[str, torch.device]] = None,
prefer_decoder: bool = True,
skip_segmented: bool = True,
) -> Optional["napari.viewer.Viewer"]:
"""Run the annotation tool for a series of images (supported for both 2d and 3d images).
Expand All @@ -116,13 +157,13 @@ def image_series_annotator(
is_volumetric: Whether to use the 3d annotator.
prefer_decoder: Whether to use decoder based instance segmentation if
the model used has an additional decoder for instance segmentation.
skip_segmented: Whether to skip images that were already segmented.
Returns:
The napari viewer, only returned if `return_viewer=True`.
"""

end_msg = "You have annotated the last image. Do you wish to close napari?"
os.makedirs(output_folder, exist_ok=True)
next_image_id = 0

# Precompute embeddings and amg state (if corresponding options set).
predictor, decoder, embedding_paths = _precompute(
Expand All @@ -132,57 +173,48 @@ def image_series_annotator(
ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
)

# Load the first image and intialize the viewer, annotator and state.
if isinstance(images[next_image_id], np.ndarray):
image = images[next_image_id]
have_inputs_as_arrays = True
else:
image = imageio.imread(images[next_image_id])
have_inputs_as_arrays = False

image_embedding_path = embedding_paths[next_image_id]

if viewer is None:
viewer = napari.Viewer()
viewer.add_image(image, name="image")

state = AnnotatorState()
state.initialize_predictor(
image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
predictor=predictor, decoder=decoder,
ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
checkpoint_path=checkpoint_path, device=device, skip_load=False,
)
state.image_shape = _get_input_shape(image, is_volumetric)

if is_volumetric:
if image.ndim not in [3, 4]:
raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
annotator = Annotator3d(viewer)
else:
if image.ndim not in (2, 3):
raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
annotator = Annotator2d(viewer)

annotator._update_image()

# Add the annotator widget to the viewer and sync widgets.
viewer.window.add_dock_widget(annotator)
_sync_embedding_widget(
state.widgets["embeddings"], model_type,
save_path=embedding_path, checkpoint_path=checkpoint_path,
device=device, tile_shape=tile_shape, halo=halo
)
next_image_id = 0
have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray)

def _save_segmentation(image_path, current_idx, segmentation):
def _get_save_path(image_path, current_idx):
if have_inputs_as_arrays:
fname = f"seg_{current_idx:05}.tif"
else:
fname = os.path.basename(image_path)
fname = os.path.splitext(fname)[0] + ".tif"
return os.path.join(output_folder, fname)

# Check which image to load next if we skip segmented images.
image_embedding_path = None
if skip_segmented:
while True:
if next_image_id == len(images):
print(end_msg)
return

out_path = os.path.join(output_folder, fname)
imageio.imwrite(out_path, segmentation)
save_path = _get_save_path(images[next_image_id], next_image_id)
if not os.path.exists(save_path):
print("The first image to annotate is image number", next_image_id)
image = images[next_image_id]
if not have_inputs_as_arrays:
image = imageio.imread(image)
image_embedding_path = embedding_paths[next_image_id]
break

next_image_id += 1

# Initialize the viewer and annotator for this image.
state = AnnotatorState()
viewer, annotator = _initialize_annotator(
viewer, image, image_embedding_path,
model_type, halo, tile_shape, predictor, decoder, is_volumetric,
precompute_amg_state, checkpoint_path, device,
embedding_path,
)

def _save_segmentation(image_path, current_idx, segmentation):
save_path = _get_save_path(image_path, next_image_id)
imageio.imwrite(save_path, segmentation, compression="zlib")

# Add functionality for going to the next image.
@magicgui(call_button="Next Image [N]")
Expand All @@ -203,14 +235,20 @@ def next_image(*args):
# Clear the segmentation already to avoid lagging removal.
viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)

# Load the next image.
# Go to the next images, if skipping images that are already segmented check if we have to load it.
next_image_id += 1
if skip_segmented:
save_path = _get_save_path(images[next_image_id], next_image_id)
while os.path.exists(save_path):
next_image_id += 1
if next_image_id == len(images):
break
save_path = _get_save_path(images[next_image_id], next_image_id)

# Load the next image.
if next_image_id == len(images):
msg = "You have annotated the last image. Do you wish to close napari?"
print(msg)
abort = False
# inform the user via dialog
abort = widgets._generate_message("info", msg)
# Inform the user via dialog.
abort = widgets._generate_message("info", end_msg)
if not abort:
viewer.close()
return
Expand Down Expand Up @@ -459,6 +497,7 @@ def main():
)
parser.add_argument("--precompute_amg_state", action="store_true")
parser.add_argument("--prefer_decoder", action="store_false")
parser.add_argument("--skip_segmented", action="store_false")

args = parser.parse_args()

Expand All @@ -467,5 +506,5 @@ def main():
embedding_path=args.embedding_path, model_type=args.model_type,
tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state,
checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric,
prefer_decoder=args.prefer_decoder,
prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
)
2 changes: 2 additions & 0 deletions micro_sam/sam_annotator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
viewer.layers["point_prompts"].data = []
viewer.layers["point_prompts"].refresh()
if "prompts" in viewer.layers:
# Select all prompts and then remove them.
viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
viewer.layers["prompts"].remove_selected()
viewer.layers["prompts"].refresh()
if not clear_segmentations:
Expand Down

0 comments on commit 896ea00

Please sign in to comment.