Skip to content

Commit

Permalink
Implement automatic 3d segmentation for a given slice
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Nov 1, 2023
1 parent b8fcaa5 commit 0df19b3
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 9 deletions.
20 changes: 19 additions & 1 deletion examples/use_as_library/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import napari

from micro_sam import instance_segmentation, util
from micro_sam.multi_dimensional_segmentation import segment_3d_from_slice


def cell_segmentation():
Expand Down Expand Up @@ -124,9 +125,26 @@ def cell_segmentation_with_tiling():
napari.run()


def segmentation_in_3d():
"""
"""
from micro_sam.sample_data import synthetic_data

shape = (5, 512, 512)
data, _ = synthetic_data(shape)
predictor = util.get_sam_model(model_type="vit_t")
seg = segment_3d_from_slice(predictor, data, embedding_path="./tmp_embeddings.zarr", verbose=True)

v = napari.Viewer()
v.add_image(data)
v.add_labels(seg)
napari.run()


def main():
cell_segmentation()
# cell_segmentation()
# cell_segmentation_with_tiling()
segmentation_in_3d()


if __name__ == "__main__":
Expand Down
65 changes: 61 additions & 4 deletions micro_sam/multi_dimensional_segmentation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""
Multi-dimensional segmentation with segment anything.
"""Multi-dimensional segmentation with segment anything.
"""

from typing import Any, Optional
import os
from typing import Any, Optional, Union

import numpy as np
from segment_anything.predictor import SamPredictor
from tqdm import tqdm

from . import util
from .instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation
from .precompute_state import cache_amg_state
from .prompt_based_segmentation import segment_from_mask


Expand All @@ -21,7 +24,7 @@ def segment_mask_in_volume(
iou_threshold: float,
projection: str,
progress_bar: Optional[Any] = None,
box_extension: int = 0,
box_extension: float = 0.0,
) -> np.ndarray:
"""Segment an object mask in in volumetric data.
Expand Down Expand Up @@ -132,3 +135,57 @@ def segment_range(z_start, z_stop, increment, stopping_criterion, threshold=None
_update_progress()

return segmentation


def segment_3d_from_slice(
predictor: SamPredictor,
raw: np.ndarray,
z: Optional[int] = None,
embedding_path: Optional[Union[str, os.PathLike]] = None,
projection: str = "mask",
box_extension: float = 0.0,
verbose: bool = True,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
min_object_size_z: int = 50,
iou_threshold: float = 0.8,
precompute_amg_state: bool = True,
):
"""
Args:
predictor:
Returns:
The
"""
# Perform automatic instance segmentation.
image_embeddings = util.precompute_image_embeddings(predictor, raw, save_path=embedding_path, ndim=3)

if z is None:
z = raw.shape[0] // 2

if precompute_amg_state and (embedding_path is not None):
amg = cache_amg_state(predictor, raw[z], image_embeddings, embedding_path, verbose=verbose, i=z)
else:
amg = AutomaticMaskGenerator(predictor)
amg.initialize(raw[z], image_embeddings, i=z, verbose=verbose)

seg_z = amg.generate(pred_iou_thresh=pred_iou_thresh, stability_score_thresh=stability_score_thresh)
seg_z = mask_data_to_segmentation(
seg_z, shape=raw.shape[1:], with_background=True, min_object_size=min_object_size_z
)

seg_ids = np.unique(seg_z)[1:]
segmentation = np.zeros(raw.shape, dtype=seg_z.dtype)
for seg_id in tqdm(seg_ids, desc="Segment objects in 3d", disable=not verbose):
this_seg = np.zeros_like(segmentation)
this_seg[z][seg_z == seg_id] = 1
this_seg = segment_mask_in_volume(
this_seg, predictor, image_embeddings,
segmented_slices=np.array([z]), stop_lower=False, stop_upper=False,
iou_threshold=iou_threshold, projection=projection, box_extension=box_extension,
)
segmentation[this_seg > 0] = seg_id

return segmentation
13 changes: 11 additions & 2 deletions micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def cache_amg_state(
image_embeddings: util.ImageEmbeddings,
save_path: Union[str, os.PathLike],
verbose: bool = True,
i: Optional[int] = None,
**kwargs,
) -> instance_segmentation.AMGBase:
"""Compute and cache or load the state for the automatic mask generator.
Expand All @@ -32,6 +33,7 @@ def cache_amg_state(
image_embeddings: The image embeddings.
save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
verbose: Whether to run the computation verbose.
i: The index for which to cache the state.
kwargs: The keyword arguments for the amg class.
Returns:
Expand All @@ -40,7 +42,14 @@ def cache_amg_state(
is_tiled = image_embeddings["input_size"] is None
amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs)

save_path_amg = os.path.join(save_path, "amg_state.pickle")
# If i is given we compute the state for a given slice/frame.
# And we have to save the state for slices/frames separately.
if i is None:
save_path_amg = os.path.join(save_path, "amg_state.pickle")
else:
os.makedirs(os.path.join(save_path, "amg_state"), exist_ok=True)
save_path_amg = os.path.join(save_path, "amg_state", f"state-{i}.pkl")

if os.path.exists(save_path_amg):
if verbose:
print("Load the AMG state from", save_path_amg)
Expand All @@ -52,7 +61,7 @@ def cache_amg_state(
if verbose:
print("Precomputing the state for instance segmentation.")

amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose, i=i)
amg_state = amg.get_state()

# put all state onto the cpu so that the state can be deserialized without a gpu
Expand Down
4 changes: 2 additions & 2 deletions micro_sam/sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,12 @@ def sample_data_segmentation():
return [(data, add_image_kwargs)]


def synthetic_data(shape):
def synthetic_data(shape, seed=None):
"""Create synthetic image data and segmentation for training."""
ndim = len(shape)
assert ndim in (2, 3)
image_shape = shape if ndim == 2 else shape[1:]
image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15)
image = binary_blobs(length=image_shape[0], blob_size_fraction=0.05, volume_fraction=0.15, seed=seed)

if image_shape[1] != image_shape[0]:
image = resize(image, image_shape, order=0, anti_aliasing=False, preserve_range=True).astype(image.dtype)
Expand Down

0 comments on commit 0df19b3

Please sign in to comment.