Skip to content

Commit

Permalink
Add support for AMG
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Sep 24, 2024
1 parent f9db0c8 commit 8a7f7e9
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,27 +1152,31 @@ def automatic_instance_segmentation(
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
use_amg: bool = False,
**generate_kwargs
):
"""
"""
predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint_path)
if tile_shape is None:
segmenter = InstanceSegmentationWithDecoder(predictor=predictor, decoder=decoder)
else:
segmenter = TiledInstanceSegmentationWithDecoder(predictor=predictor, decoder=decoder)
predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)

if "decoder_state" in state and not use_amg: # AIS
predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint_path)
segmenter = get_amg(predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None)
else: # AMG
segmenter = get_amg(predictor=predictor, is_tiled=tile_shape is not None)

# Load the input image file.
if isinstance(input_path, np.ndarray):
image_data = input_path
else:
image_data = util.load_image_data(input_path, key)

# Precompute the image embeddings.
image_embeddings = util.precompute_image_embeddings(
predictor, image_data, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
predictor=predictor, input_=image_data, save_path=embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
)

segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
generate_kwargs = {} # TODO: check how we can allow users to pass parameters.
masks = segmenter.generate(**generate_kwargs)

if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels
Expand All @@ -1187,6 +1191,12 @@ def automatic_instance_segmentation(
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)

import napari
v = napari.Viewer()
v.add_image(image_data)
v.add_labels(instances)
napari.run()

breakpoint()


Expand Down Expand Up @@ -1232,11 +1242,19 @@ def main():
)
parser.add_argument(
"-n", "--ndim", type=int, default=None,
help="The number of spatial dimensions in the data. "
"Please specify this if your data has a channel dimension."
help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension."
)
parser.add_argument(
"--amg", action="store_true", help="Whether to use automatic mask generation with the model."
)

args, parameter_args = parser.parse_known_args()
# NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to
# the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS)
generate_kwargs = {
parameter_args[i].lstrip("--"): float(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2)
}

args = parser.parse_args()
automatic_instance_segmentation(
input_path=args.input_path,
embedding_path=args.embedding_path,
Expand All @@ -1246,6 +1264,8 @@ def main():
ndim=args.ndim,
tile_shape=args.tile_shape,
halo=args.halo,
use_amg=args.amg,
**generate_kwargs,
)


Expand Down

0 comments on commit 8a7f7e9

Please sign in to comment.