diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py
index 520f2baec..736182603 100644
--- a/micro_sam/precompute_state.py
+++ b/micro_sam/precompute_state.py
@@ -164,7 +164,7 @@ def _precompute_state_for_file(
     # Precompute the image embeddings.
     output_path = Path(output_path).with_suffix(".zarr")
     embeddings = util.precompute_image_embeddings(
-        predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
+        predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=False
     )
 
     # Precompute the state for automatic instance segmnetaiton (AMG or AIS).
diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py
index aec6ec028..cd5a15e2c 100644
--- a/micro_sam/sam_annotator/image_series_annotator.py
+++ b/micro_sam/sam_annotator/image_series_annotator.py
@@ -80,6 +80,46 @@ def _get_input_shape(image, is_volumetric=False):
     return image_shape
 
 
+def _initialize_annotator(
+    viewer, image, image_embedding_path,
+    model_type, halo, tile_shape, predictor, decoder, is_volumetric,
+    precompute_amg_state, checkpoint_path, device,
+    embedding_path,
+):
+    if viewer is None:
+        viewer = napari.Viewer()
+    viewer.add_image(image, name="image")
+
+    state = AnnotatorState()
+    state.initialize_predictor(
+        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
+        predictor=predictor, decoder=decoder,
+        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
+        checkpoint_path=checkpoint_path, device=device, skip_load=False,
+    )
+    state.image_shape = _get_input_shape(image, is_volumetric)
+
+    if is_volumetric:
+        if image.ndim not in [3, 4]:
+            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
+        annotator = Annotator3d(viewer)
+    else:
+        if image.ndim not in (2, 3):
+            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
+        annotator = Annotator2d(viewer)
+
+    annotator._update_image()
+
+    # Add the annotator widget to the viewer and sync widgets.
+    viewer.window.add_dock_widget(annotator)
+    _sync_embedding_widget(
+        state.widgets["embeddings"], model_type,
+        save_path=embedding_path, checkpoint_path=checkpoint_path,
+        device=device, tile_shape=tile_shape, halo=halo
+    )
+    return viewer, annotator
+
+
 def image_series_annotator(
     images: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
     output_folder: str,
@@ -94,6 +134,7 @@ def image_series_annotator(
     is_volumetric: bool = False,
     device: Optional[Union[str, torch.device]] = None,
     prefer_decoder: bool = True,
+    skip_segmented: bool = True,
 ) -> Optional["napari.viewer.Viewer"]:
     """Run the annotation tool for a series of images (supported for both 2d and 3d images).
 
@@ -116,13 +157,13 @@ def image_series_annotator(
         is_volumetric: Whether to use the 3d annotator.
         prefer_decoder: Whether to use decoder based instance segmentation if
            the model used has an additional decoder for instance segmentation.
+        skip_segmented: Whether to skip images that were already segmented.
 
     Returns:
         The napari viewer, only returned if `return_viewer=True`.
     """
-
+    end_msg = "You have annotated the last image. Do you wish to close napari?"
     os.makedirs(output_folder, exist_ok=True)
-    next_image_id = 0
 
     # Precompute embeddings and amg state (if corresponding options set).
     predictor, decoder, embedding_paths = _precompute(
@@ -132,57 +173,52 @@ def image_series_annotator(
         ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder,
     )
 
-    # Load the first image and intialize the viewer, annotator and state.
-    if isinstance(images[next_image_id], np.ndarray):
-        image = images[next_image_id]
-        have_inputs_as_arrays = True
-    else:
-        image = imageio.imread(images[next_image_id])
-        have_inputs_as_arrays = False
-
-    image_embedding_path = embedding_paths[next_image_id]
-
-    if viewer is None:
-        viewer = napari.Viewer()
-    viewer.add_image(image, name="image")
-
-    state = AnnotatorState()
-    state.initialize_predictor(
-        image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
-        predictor=predictor, decoder=decoder,
-        ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
-        checkpoint_path=checkpoint_path, device=device, skip_load=False,
-    )
-    state.image_shape = _get_input_shape(image, is_volumetric)
-
-    if is_volumetric:
-        if image.ndim not in [3, 4]:
-            raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}")
-        annotator = Annotator3d(viewer)
-    else:
-        if image.ndim not in (2, 3):
-            raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}")
-        annotator = Annotator2d(viewer)
-
-    annotator._update_image()
-
-    # Add the annotator widget to the viewer and sync widgets.
-    viewer.window.add_dock_widget(annotator)
-    _sync_embedding_widget(
-        state.widgets["embeddings"], model_type,
-        save_path=embedding_path, checkpoint_path=checkpoint_path,
-        device=device, tile_shape=tile_shape, halo=halo
-    )
+    next_image_id = 0
+    have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray)
 
-    def _save_segmentation(image_path, current_idx, segmentation):
+    def _get_save_path(image_path, current_idx):
         if have_inputs_as_arrays:
             fname = f"seg_{current_idx:05}.tif"
         else:
             fname = os.path.basename(image_path)
             fname = os.path.splitext(fname)[0] + ".tif"
+        return os.path.join(output_folder, fname)
+
+    # Find the first image to annotate: skip images with a saved segmentation if skip_segmented is set.
+    if skip_segmented:
+        while True:
+            if next_image_id == len(images):
+                print("All images have already been segmented. Nothing left to annotate.")
+                return
 
-        out_path = os.path.join(output_folder, fname)
-        imageio.imwrite(out_path, segmentation)
+            save_path = _get_save_path(images[next_image_id], next_image_id)
+            if not os.path.exists(save_path):
+                print("The first image to annotate is image number", next_image_id)
+                image = images[next_image_id]
+                if not have_inputs_as_arrays:
+                    image = imageio.imread(image)
+                image_embedding_path = embedding_paths[next_image_id]
+                break
+
+            next_image_id += 1
+    else:
+        image = images[next_image_id]
+        if not have_inputs_as_arrays:
+            image = imageio.imread(image)
+        image_embedding_path = embedding_paths[next_image_id]
+
+    # Initialize the viewer and annotator for this image.
+    state = AnnotatorState()
+    viewer, annotator = _initialize_annotator(
+        viewer, image, image_embedding_path,
+        model_type, halo, tile_shape, predictor, decoder, is_volumetric,
+        precompute_amg_state, checkpoint_path, device,
+        embedding_path,
+    )
+
+    def _save_segmentation(image_path, current_idx, segmentation):
+        save_path = _get_save_path(image_path, current_idx)
+        imageio.imwrite(save_path, segmentation, compression="zlib")
 
     # Add functionality for going to the next image.
     @magicgui(call_button="Next Image [N]")
@@ -203,14 +239,19 @@ def next_image(*args):
         # Clear the segmentation already to avoid lagging removal.
         viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data)
 
-        # Load the next image.
+        # Go to the next image; if we skip segmented images, find the next one without a saved segmentation.
         next_image_id += 1
+        if skip_segmented:
+            while next_image_id < len(images):
+                save_path = _get_save_path(images[next_image_id], next_image_id)
+                if not os.path.exists(save_path):
+                    break
+                next_image_id += 1
+
+        # Load the next image.
         if next_image_id == len(images):
-            msg = "You have annotated the last image. Do you wish to close napari?"
-            print(msg)
-            abort = False
-            # inform the user via dialog
-            abort = widgets._generate_message("info", msg)
+            # Inform the user via dialog.
+            abort = widgets._generate_message("info", end_msg)
             if not abort:
                 viewer.close()
             return
@@ -459,6 +500,7 @@ def main():
     )
     parser.add_argument("--precompute_amg_state", action="store_true")
     parser.add_argument("--prefer_decoder", action="store_false")
+    parser.add_argument("--skip_segmented", action="store_false")
 
     args = parser.parse_args()
 
@@ -467,5 +509,5 @@ def main():
         embedding_path=args.embedding_path, model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo,
         precompute_amg_state=args.precompute_amg_state, checkpoint_path=args.checkpoint, device=args.device,
         is_volumetric=args.is_volumetric,
-        prefer_decoder=args.prefer_decoder,
+        prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented
     )
diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py
index db2ec5187..e0b3b88a5 100644
--- a/micro_sam/sam_annotator/util.py
+++ b/micro_sam/sam_annotator/util.py
@@ -118,6 +118,8 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None:
     viewer.layers["point_prompts"].data = []
     viewer.layers["point_prompts"].refresh()
     if "prompts" in viewer.layers:
+        # Select all prompts and then remove them.
+        viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data)))
         viewer.layers["prompts"].remove_selected()
         viewer.layers["prompts"].refresh()
     if not clear_segmentations:
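Note on the util.py hunk: napari's Shapes layer only removes shapes that are listed in its selected_data set, which is empty unless the user has selected shapes by hand, so remove_selected() alone was a no-op. A minimal standalone sketch of the pattern; the viewer, the "prompts" layer name, and the dummy rectangle here are illustrative, not taken from the codebase:

    import numpy as np
    import napari

    viewer = napari.Viewer()
    prompts = viewer.add_shapes(name="prompts")

    # Add a dummy box prompt (corner coordinates of a rectangle).
    prompts.add(
        np.array([[0.0, 0.0], [0.0, 10.0], [10.0, 10.0], [10.0, 0.0]]),
        shape_type="rectangle",
    )

    # remove_selected() only deletes shapes in selected_data, so mark every
    # shape as selected first to clear the layer completely.
    prompts.selected_data = set(range(len(prompts.data)))
    prompts.remove_selected()
    prompts.refresh()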
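For context, a minimal usage sketch of the resume behavior this patch adds; the folder layout and model choice are assumptions for illustration. With skip_segmented=True (the new default), the series annotator starts at the first image whose segmentation file does not yet exist in the output folder, so an interrupted annotation session picks up where it stopped:

    import os
    from glob import glob

    from micro_sam.sam_annotator import image_series_annotator

    # Hypothetical data layout: a folder of 2d .tif images to annotate.
    image_paths = sorted(glob(os.path.join("data", "images", "*.tif")))

    # Segmentations are saved as .tif files in the output folder; images whose
    # output file already exists are skipped when skip_segmented is set.
    image_series_annotator(
        image_paths,
        output_folder="data/segmentations",
        model_type="vit_b_lm",
        skip_segmented=True,
    )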