Merge branch 'dev' into support-3d-iterative-prompting
anwai98 committed Sep 29, 2024
2 parents c16db47 + 22cd606 commit 93b8f2b
Showing 20 changed files with 623 additions and 229 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -20,7 +20,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ["3.10"]
python-version: ["3.11", "3.12"]

steps:
- name: Checkout
6 changes: 6 additions & 0 deletions doc/faq.md
@@ -135,6 +135,12 @@ This can happen for long running computations. You just need to wait a bit longer.
`micro_sam` performs automatic segmentation in 3D volumes by first segmenting the slices individually in 2D and then merging the segmentations across the volume, based on the overlap of objects between slices. The expected shape of your 3D RGB volume is `(Z * Y * X * 3)`. The reason: Segment Anything is designed for 3-channel inputs. When you provide `micro_sam` with 1-channel inputs, we triplicate the channel to fit this requirement; 3-channel inputs are used as-is in the expected RGB array structure.
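
A minimal sketch of the channel triplication described above, assuming `numpy` and a hypothetical single-channel volume:

import numpy as np

volume = np.zeros((16, 512, 512), dtype="uint8")  # hypothetical 1-channel volume with shape (Z, Y, X)
volume_rgb = np.stack([volume] * 3, axis=-1)  # triplicate the channel to obtain shape (Z, Y, X, 3)
assert volume_rgb.shape == (16, 512, 512, 3)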


### 15. I want to use a model stored in a different directory than the `micro_sam` cache. How can I do this?
The `micro-sam` CLIs for precomputing image embeddings and for the annotators (Annotator 2d, Annotator 3d, Annotator Tracking, Image Series Annotator) accept the argument `-c` / `--checkpoint` for passing model checkpoints. If you start a `micro-sam` annotator from the `napari` plugin menu, you can provide the path to the model checkpoint in the annotator widget (on the right), under the `Embedding Settings` drop-down, in the `custom weights path` option.

NOTE: It is important to choose the correct model type when you opt for the above recommendation, using the `-m` / `--model_type` argument or selecting it from the `Model` dropdown in `Embedding Settings`, respectively. Otherwise you will run into parameter mismatch issues.
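
For scripted workflows, custom weights can likewise be passed in Python via `util.get_sam_model`, which this commit also uses; a minimal sketch (the checkpoint path is hypothetical):

from micro_sam import util

# Load a SAM predictor from weights stored outside the micro_sam cache.
# The model_type has to match the architecture of the checkpoint.
predictor = util.get_sam_model(model_type="vit_b", checkpoint_path="/path/to/checkpoint.pt")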


## Fine-tuning questions


3 changes: 2 additions & 1 deletion environment_cpu.yaml
@@ -4,9 +4,10 @@ channels:
- conda-forge
dependencies:
- cpuonly
- nifty =1.2.1=*_4
- imagecodecs
- magicgui
- napari
- napari <0.5
- pip
- pooch
- pyqt
3 changes: 2 additions & 1 deletion environment_gpu.yaml
@@ -5,8 +5,9 @@ channels:
- conda-forge
dependencies:
- imagecodecs
- nifty =1.2.1=*_4
- magicgui
- napari
- napari <0.5
- pip
- pooch
- pyqt
2 changes: 1 addition & 1 deletion examples/annotator_2d.py
@@ -65,7 +65,7 @@ def wholeslide_annotator(use_finetuned_model):

def main():
# Whether to use the fine-tuned SAM model for light microscopy data.
use_finetuned_model = False
use_finetuned_model = True

# 2d annotator for livecell data
livecell_annotator(use_finetuned_model)
198 changes: 198 additions & 0 deletions micro_sam/automatic_segmentation.py
@@ -0,0 +1,198 @@
import os
from pathlib import Path
from typing import Dict, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio

from . import util
from .instance_segmentation import (
get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase
)
from .multi_dimensional_segmentation import automatic_3d_segmentation


def automatic_instance_segmentation(
input_path: Union[Union[os.PathLike, str], np.ndarray],
output_path: Optional[Union[os.PathLike, str]] = None,
embedding_path: Optional[Union[os.PathLike, str]] = None,
model_type: str = util._DEFAULT_MODEL,
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
key: Optional[str] = None,
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
use_amg: bool = False,
amg_kwargs: Optional[Dict] = None,
**generate_kwargs
) -> np.ndarray:
"""Run automatic segmentation for the input image.
Args:
input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
or a container file (e.g. hdf5 or zarr).
output_path: The output path where the instance segmentations will be saved.
embedding_path: The path where the embeddings are cached already / will be saved.
model_type: The SegmentAnything model to use. Will use the standard vit_l model by default.
checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
ndim: The dimensionality of the data.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method.
amg_kwargs: Optional keyword arguments for creating the AMG or AIS class.
generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
Returns:
The segmentation result.
"""
predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

if "decoder_state" in state and not use_amg: # AIS
decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"])
segmenter = get_amg(
predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None,
**({} if amg_kwargs is None else amg_kwargs)
)
else: # AMG
segmenter = get_amg(
predictor=predictor, is_tiled=tile_shape is not None, **({} if amg_kwargs is None else amg_kwargs)
)

# Load the input image file.
if isinstance(input_path, np.ndarray):
image_data = input_path
else:
image_data = util.load_image_data(input_path, key)

if ndim == 3 or image_data.ndim == 3:
if image_data.ndim != 3:
raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'")

instances = automatic_3d_segmentation(
volume=image_data,
predictor=predictor,
segmentor=segmenter,
embedding_path=embedding_path,
tile_shape=tile_shape,
halo=halo,
**generate_kwargs
)
else:
# Precompute the image embeddings.
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
input_=image_data,
save_path=embedding_path,
ndim=ndim,
tile_shape=tile_shape,
halo=halo,
)

segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
masks = segmenter.generate(**generate_kwargs)

if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels
if isinstance(segmenter, InstanceSegmentationWithDecoder):
this_shape = segmenter._foreground.shape
elif isinstance(segmenter, AMGBase):
this_shape = segmenter._original_size
else:
this_shape = image_data.shape[-2:]

instances = np.zeros(this_shape, dtype="uint32")
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)

if output_path is not None:
# Save the instance segmentation
output_path = Path(output_path).with_suffix(".tif")
imageio.imwrite(output_path, instances, compression="zlib")

return instances


def main():
"""@private"""
import argparse

available_models = list(util.get_model_names())
available_models = ", ".join(available_models)

parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
parser.add_argument(
"-i", "--input_path", required=True,
help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
"or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
)
parser.add_argument(
"-o", "--output_path", required=True,
help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
)
parser.add_argument(
"-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
)
parser.add_argument(
"--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
)
parser.add_argument(
"-k", "--key",
help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
"for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
)
parser.add_argument(
"-m", "--model_type", default=util._DEFAULT_MODEL,
help=f"The segment anything model that will be used, one of {available_models}."
)
parser.add_argument(
"-c", "--checkpoint", default=None,
help="Checkpoint from which the SAM model will be loaded loaded."
)
parser.add_argument(
"--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
)
parser.add_argument(
"--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
)
parser.add_argument(
"-n", "--ndim", type=int, default=None,
help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
)
parser.add_argument(
"--amg", action="store_true", help="Whether to use automatic mask generation with the model."
)

args, parameter_args = parser.parse_known_args()

def _convert_argval(value):
# The values for the parsed arguments need to be converted to the expected input types,
# i.e. integers and floats should be cast from the CLI strings to their original types.
try:
return int(value)
except ValueError:
return float(value)

# NOTE: The script below allows catching additional parsed arguments, which correspond to
# the automatic segmentation post-processing parameters (e.g. 'center_distance_threshold' in AIS).
generate_kwargs = {
parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
}

automatic_instance_segmentation(
input_path=args.input_path,
output_path=args.output_path,
embedding_path=args.embedding_path,
model_type=args.model_type,
checkpoint_path=args.checkpoint,
key=args.key,
ndim=args.ndim,
tile_shape=args.tile_shape,
halo=args.halo,
use_amg=args.amg,
**generate_kwargs,
)


if __name__ == "__main__":
main()
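
For reference, a minimal usage sketch of the new `automatic_instance_segmentation` function defined above (the file paths and model type are hypothetical):

from micro_sam.automatic_segmentation import automatic_instance_segmentation

# Runs AIS if the model provides a decoder state, otherwise falls back to AMG,
# and saves the resulting instance segmentation as a tif file.
segmentation = automatic_instance_segmentation(
    input_path="cells.tif",  # a single 2d image; 3d volumes are routed to automatic_3d_segmentation
    output_path="cells_segmentation.tif",
    embedding_path="embeddings.zarr",  # cache path for the image embeddings
    model_type="vit_b_lm",
)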
21 changes: 15 additions & 6 deletions micro_sam/instance_segmentation.py
@@ -12,9 +12,10 @@

import vigra
import numpy as np
from skimage.measure import regionprops

import torch

from skimage.measure import label, regionprops
from skimage.segmentation import relabel_sequential
from torchvision.ops.boxes import batched_nms, box_area

from torch_em.model import UNETR
@@ -47,6 +48,7 @@ def mask_data_to_segmentation(
with_background: bool,
min_object_size: int = 0,
max_object_size: Optional[int] = None,
label_masks: bool = True,
) -> np.ndarray:
"""Convert the output of the automatic mask generation to an instance segmentation.
@@ -57,7 +59,11 @@
object in the output will be mapped to zero (the background value).
min_object_size: The minimal size of an object in pixels.
max_object_size: The maximal size of an object in pixels.
label_masks: Whether to apply connected components to the result before removing small objects.
Returns:
The instance segmentation.
"""
@@ -81,6 +87,8 @@ def require_numpy(mask):
segmentation[require_numpy(mask["segmentation"])] = this_seg_id
seg_id = this_seg_id + 1

if label_masks:
segmentation = label(segmentation)
seg_ids, sizes = np.unique(segmentation, return_counts=True)

# In some cases objects may be smaller than previously calculated,
@@ -96,7 +104,7 @@
filter_ids = np.concatenate([filter_ids, [bg_id]])

segmentation[np.isin(segmentation, filter_ids)] = 0
vigra.analysis.relabelConsecutive(segmentation, out=segmentation)
segmentation = relabel_sequential(segmentation)[0]

return segmentation
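
To illustrate the new `label_masks` step and the switch from `vigra.analysis.relabelConsecutive` to `skimage.segmentation.relabel_sequential`, a small self-contained sketch on a toy array:

import numpy as np
from skimage.measure import label
from skimage.segmentation import relabel_sequential

seg = np.array([[1, 1, 0, 1], [0, 0, 0, 0], [2, 2, 0, 0]], dtype="uint32")
seg = label(seg)  # connected components: the two disconnected regions with id 1 get separate ids
seg[seg == 2] = 0  # filtering out an object (e.g. below min_object_size) leaves a gap in the ids
seg = relabel_sequential(seg)[0]  # relabel to consecutive ids 1..N, keeping 0 as background
print(seg)  # [[1 1 0 0] [0 0 0 0] [2 2 0 0]]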

@@ -800,7 +808,7 @@ def get_predictor_and_decoder(
model_type: str,
checkpoint_path: Union[str, os.PathLike],
device: Optional[Union[str, torch.device]] = None,
peft_kwargs: Optional[Dict] = None,
peft_kwargs: Optional[Dict] = None,
) -> Tuple[SamPredictor, DecoderAdapter]:
"""Load the SAM model (predictor) and instance segmentation decoder.
@@ -826,7 +834,9 @@
peft_kwargs=peft_kwargs,
)
if "decoder_state" not in state:
raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state")
raise ValueError(
f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state"
)
decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device)
return predictor, decoder

@@ -957,7 +967,6 @@ def to_bbox_3d(bbox):
]
return masks

# TODO find good default values (empirically)
def generate(
self,
center_distance_threshold: float = 0.5,
9 changes: 7 additions & 2 deletions micro_sam/multi_dimensional_segmentation.py
@@ -356,7 +356,6 @@ def merge_instance_segmentation_3d(
return segmentation


# TODO: Enable tiling
def automatic_3d_segmentation(
volume: np.ndarray,
predictor: SamPredictor,
@@ -365,6 +364,8 @@
with_background: bool = True,
gap_closing: Optional[int] = None,
min_z_extent: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
verbose: bool = True,
**kwargs,
) -> np.ndarray:
@@ -383,6 +384,8 @@
operation. The value is used to determine the number of iterations for the closing.
min_z_extent: Require a minimal extent in z for the segmented objects.
This can help to prevent segmentation artifacts.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
verbose: Verbosity flag.
kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
@@ -393,7 +396,9 @@
segmentation = np.zeros(volume.shape, dtype="uint32")

min_object_size = kwargs.pop("min_object_size", 0)
image_embeddings = util.precompute_image_embeddings(predictor, volume, save_path=embedding_path, ndim=3)
image_embeddings = util.precompute_image_embeddings(
predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, tile_shape=tile_shape, halo=halo,
)

for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose):
segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i)
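
A hedged sketch of the tiled 3D segmentation enabled here, mirroring the call in micro_sam/automatic_segmentation.py above (the volume and tile sizes are hypothetical):

import numpy as np
from micro_sam import util
from micro_sam.instance_segmentation import get_amg
from micro_sam.multi_dimensional_segmentation import automatic_3d_segmentation

predictor = util.get_sam_model(model_type="vit_b")
segmentor = get_amg(predictor=predictor, is_tiled=True)  # tiled mask generation, matching tile_shape below

volume = np.zeros((8, 1024, 1024), dtype="uint8")  # hypothetical volume with shape (Z, Y, X)
segmentation = automatic_3d_segmentation(
    volume=volume,
    predictor=predictor,
    segmentor=segmentor,
    embedding_path=None,  # embeddings are not cached in this sketch
    tile_shape=(512, 512),  # shape of the tiles for tiled prediction
    halo=(64, 64),  # overlap between adjacent tiles
)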
4 changes: 2 additions & 2 deletions micro_sam/precompute_state.py
@@ -235,10 +235,10 @@ def precompute_state(
a container file (e.g. hdf5 or zarr) or a folder with images files.
In case of a container file the argument `key` must be given. In case of a folder
it can be given to provide a glob pattern to subselect files from the folder.
output_path: The output path were the embeddings and other state will be saved.
output_path: The output path where the embeddings and other state will be saved.
pattern: Glob pattern to select files in a folder. The embeddings will be computed
for each of these files. To select all files in a folder pass "*".
model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
model_type: The SegmentAnything model to use. Will use the standard vit_l model by default.
checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
3 changes: 2 additions & 1 deletion micro_sam/sam_annotator/_state.py
@@ -80,6 +80,7 @@ def initialize_predictor(
prefer_decoder=True,
pbar_init=None,
pbar_update=None,
skip_load=True,
):
assert ndim in (2, 3)

@@ -128,7 +129,7 @@ def initialize_predictor(
raise RuntimeError("Require a save path to precompute the amg state")

cache_state = cache_amg_state if self.decoder is None else partial(
cache_is_state, decoder=self.decoder, skip_load=True,
cache_is_state, decoder=self.decoder, skip_load=skip_load,
)

if ndim == 2:
