diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 504a5a4b..f79f7b46 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -1152,15 +1152,20 @@ def automatic_instance_segmentation( ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, + use_amg: bool = False, + **generate_kwargs ): """ """ - predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint_path) - if tile_shape is None: - segmenter = InstanceSegmentationWithDecoder(predictor=predictor, decoder=decoder) - else: - segmenter = TiledInstanceSegmentationWithDecoder(predictor=predictor, decoder=decoder) + predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) + if "decoder_state" in state and not use_amg: # AIS + predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint_path) + segmenter = get_amg(predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None) + else: # AMG + segmenter = get_amg(predictor=predictor, is_tiled=tile_shape is not None) + + # Load the input image file. if isinstance(input_path, np.ndarray): image_data = input_path else: @@ -1168,11 +1173,10 @@ def automatic_instance_segmentation( # Precompute the image embeddings. image_embeddings = util.precompute_image_embeddings( - predictor, image_data, embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo, + predictor=predictor, input_=image_data, save_path=embedding_path, ndim=ndim, tile_shape=tile_shape, halo=halo, ) segmenter.initialize(image=image_data, image_embeddings=image_embeddings) - generate_kwargs = {} # TODO: check how we can allow users to pass parameters. masks = segmenter.generate(**generate_kwargs) if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels @@ -1187,6 +1191,12 @@ def automatic_instance_segmentation( else: instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) + import napari + v = napari.Viewer() + v.add_image(image_data) + v.add_labels(instances) + napari.run() + breakpoint() @@ -1232,11 +1242,19 @@ def main(): ) parser.add_argument( "-n", "--ndim", type=int, default=None, - help="The number of spatial dimensions in the data. " - "Please specify this if your data has a channel dimension." + help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension." ) + parser.add_argument( + "--amg", action="store_true", help="Whether to use automatic mask generation with the model." + ) + + args, parameter_args = parser.parse_known_args() + # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to + # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS) + generate_kwargs = { + parameter_args[i].lstrip("--"): float(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2) + } - args = parser.parse_args() automatic_instance_segmentation( input_path=args.input_path, embedding_path=args.embedding_path, @@ -1246,6 +1264,8 @@ def main(): ndim=args.ndim, tile_shape=args.tile_shape, halo=args.halo, + use_amg=args.amg, + **generate_kwargs, )