diff --git a/examples/use_as_library/instance_segmentation.py b/examples/use_as_library/instance_segmentation.py
index be9300d3..a447ed0a 100644
--- a/examples/use_as_library/instance_segmentation.py
+++ b/examples/use_as_library/instance_segmentation.py
@@ -2,6 +2,7 @@
 import napari
 
 from micro_sam import instance_segmentation, util
+from micro_sam.multi_dimensional_segmentation import segment_3d_from_slice
 
 
 def cell_segmentation():
@@ -32,36 +33,15 @@ def cell_segmentation():
 
     # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
     # without having to call initialize again.
-    instances_amg = amg.generate(pred_iou_thresh=0.88)
-    instances_amg = instance_segmentation.mask_data_to_segmentation(
-        instances_amg, shape=image.shape, with_background=True
-    )
-
-    # Use the mutex waterhsed based instance segmentation logic.
-    # Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
-    # These initial masks are used as prompts for the actual instance segmentation.
-    # This class uses the same overall design as 'AutomaticMaskGenerator'.
-
-    # Create the automatic mask generator class.
-    amg_mws = instance_segmentation.EmbeddingMaskGenerator(predictor, min_initial_size=10)
-
-    # Initialize the mask generator with the image and the pre-computed embeddings.
-    amg_mws.initialize(image, embeddings, verbose=True)
-
-    # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
-    # without having to call initialize again.
-    # NOTE: the main advantage of this method is that it's faster than the original implementation,
-    # however the quality is not as high as the original instance segmentation quality yet.
-    instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
-    instances_mws = instance_segmentation.mask_data_to_segmentation(
-        instances_mws, shape=image.shape, with_background=True
+    instances = amg.generate(pred_iou_thresh=0.88)
+    instances = instance_segmentation.mask_data_to_segmentation(
+        instances, shape=image.shape, with_background=True
     )
 
     # Show the results.
     v = napari.Viewer()
     v.add_image(image)
-    v.add_labels(instances_amg)
-    v.add_labels(instances_mws)
+    v.add_labels(instances)
     napari.run()
 
 
@@ -94,39 +74,71 @@ def cell_segmentation_with_tiling():
 
     # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
     # without having to call initialize again.
-    instances_amg = amg.generate(pred_iou_thresh=0.88)
-    instances_amg = instance_segmentation.mask_data_to_segmentation(
-        instances_amg, shape=image.shape, with_background=True
+    instances = amg.generate(pred_iou_thresh=0.88)
+    instances = instance_segmentation.mask_data_to_segmentation(
+        instances, shape=image.shape, with_background=True
     )
 
-    # Use the mutex waterhsed based instance segmentation logic.
-    # Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
-    # These initial masks are used as prompts for the actual instance segmentation.
-    # This class uses the same overall design as 'AutomaticMaskGenerator'.
-
-    # Create the automatic mask generator class.
-    amg_mws = instance_segmentation.TiledEmbeddingMaskGenerator(predictor, min_initial_size=10)
+    # Show the results.
+    v = napari.Viewer()
+    v.add_image(image)
+    v.add_labels(instances)
+    napari.run()
 
-    # Initialize the mask generator with the image and the pre-computed embeddings.
-    amg_mws.initialize(image, embeddings, verbose=True)
 
-    # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
-    # without having to call initialize again.
-    # NOTE: the main advantage of this method is that it's faster than the original implementation.
-    # however the quality is not as high as the original instance segmentation quality yet.
-    instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
+def segmentation_in_3d():
+    """Run instance segmentation in 3d to segment all objects that intersect
+    with a given slice. If you use a fine-tuned model for this then you should
+    first find good parameters for 2d segmentation.
+    """
+    import imageio.v3 as imageio
+    from micro_sam.sample_data import fetch_nucleus_3d_example_data
+
+    # Load the example image data: 3d nucleus segmentation.
+    path = fetch_nucleus_3d_example_data("./data")
+    data = imageio.imread(path)
+
+    # Load the SAM model for prediction.
+    model_type = "vit_b"  # The model-type to use: vit_h, vit_l, vit_b etc.
+    checkpoint_path = None  # You can specify the path to a custom (fine-tuned) model here.
+    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+
+    # Run 3d segmentation for a given slice. This will segment all objects found
+    # in that slice throughout the volume.
+
+    # The slice that is used for segmentation in 2d. If you don't specify a slice
+    # then the middle slice is used.
+    z_slice = data.shape[0] // 2
+
+    # The threshold for filtering objects in the 2d segmentation based on the model's
+    # predicted iou score. If you use a custom model you should first find a good setting
+    # for this value, e.g. with the 2d annotation tool.
+    pred_iou_thresh = 0.88
+
+    # The threshold for filtering objects in the 2d segmentation based on the model's
+    # stability score for a given object. If you use a custom model you should first find a good setting
+    # for this value, e.g. with the 2d annotation tool.
+    stability_score_thresh = 0.95
+
+    instances = segment_3d_from_slice(
+        predictor, data, z=z_slice,
+        pred_iou_thresh=pred_iou_thresh,
+        stability_score_thresh=stability_score_thresh,
+        verbose=True
+    )
 
     # Show the results.
     v = napari.Viewer()
-    v.add_image(image)
-    v.add_labels(instances_amg)
-    v.add_labels(instances_mws)
+    v.add_image(data)
+    v.add_labels(instances)
     napari.run()
 
 
 def main():
-    cell_segmentation()
+    # cell_segmentation()
     # cell_segmentation_with_tiling()
+    segmentation_in_3d()
 
 
 if __name__ == "__main__":
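For context: the example above recomputes the image embeddings on every run. Since `segment_3d_from_slice` (added below in `micro_sam/multi_dimensional_segmentation.py`) also accepts an `embedding_path`, the expensive state can be cached on disk between runs. A minimal sketch using only functions added in this diff; the `./data` and `./embeddings` paths are illustrative:

```python
import imageio.v3 as imageio

from micro_sam import util
from micro_sam.multi_dimensional_segmentation import segment_3d_from_slice
from micro_sam.sample_data import fetch_nucleus_3d_example_data

# Download and load the example volume.
data = imageio.imread(fetch_nucleus_3d_example_data("./data"))
predictor = util.get_sam_model(model_type="vit_b")

# Passing 'embedding_path' caches the image embeddings and the per-slice AMG
# state on disk, so re-runs with different thresholds skip the precomputation.
instances = segment_3d_from_slice(
    predictor, data, embedding_path="./embeddings",
    pred_iou_thresh=0.88, stability_score_thresh=0.95, verbose=True,
)
```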
diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index 3ac5a24b..3abb1f33 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -53,6 +53,7 @@ def mask_data_to_segmentation(
     shape: Tuple[int, ...],
     with_background: bool,
     min_object_size: int = 0,
+    max_object_size: Optional[int] = None,
 ) -> np.ndarray:
     """Convert the output of the automatic mask generation to an instance segmentation.
 
@@ -63,6 +64,7 @@
         with_background: Whether the segmentation has background. If yes this function assures that the largest
             object in the output will be mapped to zero (the background value).
         min_object_size: The minimal size of an object in pixels.
+        max_object_size: The maximal size of an object in pixels.
 
     Returns:
         The instance segmentation.
     """
@@ -77,6 +79,8 @@ def require_numpy(mask):
     for mask in masks:
         if mask["area"] < min_object_size:
             continue
+        if max_object_size is not None and mask["area"] > max_object_size:
+            continue
 
         this_seg_id = mask.get("seg_id", seg_id)
         segmentation[require_numpy(mask["segmentation"])] = this_seg_id
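To illustrate the new `max_object_size` parameter of `mask_data_to_segmentation`, here is a minimal 2d sketch that filters out both very small and very large masks. The size thresholds (50 and 10000 pixels) and the use of the synthetic sample data are illustrative assumptions, not values from this diff:

```python
from micro_sam import instance_segmentation, util
from micro_sam.sample_data import synthetic_data

# Create a small synthetic 2d image (reproducible via the new 'seed' argument).
# For real use you would load your own 2d image here instead.
image, _ = synthetic_data(shape=(512, 512), seed=42)

predictor = util.get_sam_model(model_type="vit_b")
embeddings = util.precompute_image_embeddings(predictor, image)

amg = instance_segmentation.AutomaticMaskGenerator(predictor)
amg.initialize(image, embeddings)
masks = amg.generate(pred_iou_thresh=0.88)

# Masks smaller than 'min_object_size' or larger than 'max_object_size'
# (both in pixels) are excluded from the instance segmentation.
segmentation = instance_segmentation.mask_data_to_segmentation(
    masks, shape=image.shape, with_background=True,
    min_object_size=50, max_object_size=10_000,
)
```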
""" @@ -77,6 +79,8 @@ def require_numpy(mask): for mask in masks: if mask["area"] < min_object_size: continue + if max_object_size is not None and mask["area"] > max_object_size: + continue this_seg_id = mask.get("seg_id", seg_id) segmentation[require_numpy(mask["segmentation"])] = this_seg_id diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index d4bccaaa..49020ef1 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -1,13 +1,16 @@ -""" -Multi-dimensional segmentation with segment anything. +"""Multi-dimensional segmentation with segment anything. """ -from typing import Any, Optional +import os +from typing import Any, Optional, Union import numpy as np from segment_anything.predictor import SamPredictor +from tqdm import tqdm from . import util +from .instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation +from .precompute_state import cache_amg_state from .prompt_based_segmentation import segment_from_mask @@ -21,7 +24,7 @@ def segment_mask_in_volume( iou_threshold: float, projection: str, progress_bar: Optional[Any] = None, - box_extension: int = 0, + box_extension: float = 0.0, ) -> np.ndarray: """Segment an object mask in in volumetric data. @@ -35,7 +38,7 @@ def segment_mask_in_volume( iou_threshold: The IOU threshold for continuing segmentation across 3d. projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'. progress_bar: Optional progress bar. - box_extension: Extension factor for increasing the box size after projection + box_extension: Extension factor for increasing the box size after projection. Returns: Array with the volumetric segmentation @@ -132,3 +135,79 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None _update_progress() return segmentation + + +def segment_3d_from_slice( + predictor: SamPredictor, + raw: np.ndarray, + z: Optional[int] = None, + embedding_path: Optional[Union[str, os.PathLike]] = None, + projection: str = "mask", + box_extension: float = 0.0, + verbose: bool = True, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + min_object_size_z: int = 50, + max_object_size_z: Optional[int] = None, + iou_threshold: float = 0.8, +): + """Segment all objects in a volume intersecting with a specific slice. + + This function first segments the objects in the specified slice using the + automatic instance segmentation functionality. Then it segments all objects that + were found in that slice in the volume. + + Args: + predictor: The segment anything predictor. + raw: The volumetric image data. + z: The slice from which to start segmentation. + If none is given the central slice will be used. + embedding_path: The path were embeddings will be cached. + If none is given embeddings will not be cached. + projection: The projection method to use. One of 'mask', 'bounding_box' or 'points'. + box_extension: Extension factor for increasing the box size after projection. + verbose: Whether to print progress bar and other status messages. + pred_iou_thresh: The predicted iou value to filter objects in `AutomaticMaskGenerator.generate`. + stability_score_thresh: The stability score to filter objects in `AutomaticMaskGenerator.generate`. + min_object_size_z: Minimal object size in the segmented frame. + max_object_size_z: Maximal object size in the segmented frame. + iou_threshold: The IOU threshold for linking objects across slices. 
diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py
index 8d4eca20..66e6edf5 100644
--- a/micro_sam/precompute_state.py
+++ b/micro_sam/precompute_state.py
@@ -22,6 +22,7 @@ def cache_amg_state(
     image_embeddings: util.ImageEmbeddings,
     save_path: Union[str, os.PathLike],
     verbose: bool = True,
+    i: Optional[int] = None,
     **kwargs,
 ) -> instance_segmentation.AMGBase:
     """Compute and cache or load the state for the automatic mask generator.
 
@@ -32,6 +33,7 @@ def cache_amg_state(
         image_embeddings: The image embeddings.
         save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
         verbose: Whether to run the computation verbose.
+        i: The index of the slice or frame for which to cache the state.
         kwargs: The keyword arguments for the amg class.
 
     Returns:
@@ -40,7 +42,14 @@ def cache_amg_state(
     is_tiled = image_embeddings["input_size"] is None
     amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs)
 
-    save_path_amg = os.path.join(save_path, "amg_state.pickle")
+    # If 'i' is given we compute the state for that slice/frame
+    # and save it separately for each slice/frame.
+    if i is None:
+        save_path_amg = os.path.join(save_path, "amg_state.pickle")
+    else:
+        os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
+        save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")
+
     if os.path.exists(save_path_amg):
         if verbose:
             print("Load the AMG state from", save_path_amg)
@@ -52,7 +61,7 @@ def cache_amg_state(
     if verbose:
         print("Precomputing the state for instance segmentation.")
 
-    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
+    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
     amg_state = amg.get_state()
 
     # put all state onto the cpu so that the state can be deserialized without a gpu
diff --git a/micro_sam/sample_data.py b/micro_sam/sample_data.py
index bef9c71e..b7523a98 100644
--- a/micro_sam/sample_data.py
+++ b/micro_sam/sample_data.py
@@ -322,12 +322,12 @@ def sample_data_segmentation():
     return [(data, add_image_kwargs)]
 
 
-def synthetic_data(shape):
+def synthetic_data(shape, seed=None):
     """Create synthetic image data and segmentation for training."""
     ndim = len(shape)
     assert ndim in (2, 3)
     image_shape = shape if ndim == 2 else shape[1:]
-    image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15)
+    image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15, seed=seed)
 
     if image_shape[1] != image_shape[0]:
         image = resize(image, image_shape, order=0, anti_aliasing=False, preserve_range=True).astype(image.dtype)
@@ -337,3 +337,29 @@ def synthetic_data(shape):
     segmentation = label(image)
 
     return image, segmentation
+
+
+def fetch_nucleus_3d_example_data(save_directory: Union[str, os.PathLike]) -> str:
+    """Download the sample data for 3d segmentation of nuclei.
+
+    This data contains a small crop from a volume from the publication
+    "Efficient automatic 3D segmentation of cell nuclei for high-content screening"
+    https://doi.org/10.1186/s12859-022-04737-4
+
+    Args:
+        save_directory: Root folder to save the downloaded data.
+
+    Returns:
+        The path of the downloaded image.
+    """
+    save_directory = Path(save_directory)
+    os.makedirs(save_directory, exist_ok=True)
+    print("Example data directory is:", save_directory.resolve())
+    fname = "3d-nucleus-data.tif"
+    pooch.retrieve(
+        url="https://owncloud.gwdg.de/index.php/s/eW0uNCo8gedzWU4/download",
+        known_hash="4946896f747dc1c3fc82fb2e1320226d92f99d22be88ea5f9c37e3ba4e281205",
+        fname=fname,
+        path=save_directory,
+        progressbar=True,
+    )
+    return os.path.join(save_directory, fname)
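Finally, a minimal sketch exercising both `sample_data` additions, the download helper and the now-reproducible synthetic data; the `./data` path and the seed value are illustrative:

```python
import imageio.v3 as imageio

from micro_sam.sample_data import fetch_nucleus_3d_example_data, synthetic_data

# Download the 3d nucleus example volume (cached in './data' after the first call).
path = fetch_nucleus_3d_example_data("./data")
volume = imageio.imread(path)
print("Example volume of shape:", volume.shape)

# With a fixed seed the synthetic image and segmentation are now reproducible.
image, segmentation = synthetic_data(shape=(512, 512), seed=42)
```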