From fbfde8edc83708967ad52f480c42925db881f552 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 11 Oct 2023 14:29:21 +0200 Subject: [PATCH] Update model predictor adaptor for bioimage models --- examples/model_zoo/get_bioimage_modelzoo.py | 18 ++++ micro_sam/model_zoo.py | 101 ++++++++++++++++---- micro_sam/predictor_adaptor.py | 38 ++++++++ 3 files changed, 138 insertions(+), 19 deletions(-) create mode 100644 examples/model_zoo/get_bioimage_modelzoo.py create mode 100644 micro_sam/predictor_adaptor.py diff --git a/examples/model_zoo/get_bioimage_modelzoo.py b/examples/model_zoo/get_bioimage_modelzoo.py new file mode 100644 index 00000000..bcf91e1d --- /dev/null +++ b/examples/model_zoo/get_bioimage_modelzoo.py @@ -0,0 +1,18 @@ +from micro_sam import model_zoo + + +def main(): + parser = model_zoo._get_modelzoo_parser() + args = parser.parse_args() + + model_zoo.get_modelzoo_yaml( + image_path=args.input_path, + box_prompts=None, + model_type=args.model_type, + output_path=args.output_path, + doc_path=args.doc_path + ) + + +if __name__ == "__main__": + main() diff --git a/micro_sam/model_zoo.py b/micro_sam/model_zoo.py index 600a075d..2287575b 100644 --- a/micro_sam/model_zoo.py +++ b/micro_sam/model_zoo.py @@ -1,24 +1,63 @@ import os +import argparse import numpy as np from glob import glob +from typing import List + import imageio.v2 as imageio +import torch + +from micro_sam import util + +from .predictor_adaptor import PredictorAdaptor +from .prompt_based_segmentation import _compute_box_from_mask + from bioimageio.core.build_spec import build_model -def _get_livecell_npy_path(input_dir): +def _get_model(image, model_type): + "Returns the model and predictor while initializing with the model checkpoints" + predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True) # type: ignore + image_embeddings = util.precompute_image_embeddings(predictor, image) + util.set_precomputed(predictor, image_embeddings) + return predictor, sam_model + + +def _get_livecell_npy_paths( + input_dir: str, + model_type: str +): test_img_paths = sorted(glob(os.path.join(input_dir, "images", "livecell_test_images", "*"))) - input_image = imageio.imread(test_img_paths[0]) + chosen_input = test_img_paths[0] + + input_image = imageio.imread(chosen_input) + + fname = os.path.split(chosen_input)[-1] + cell_type = fname.split("_")[0] + label_image = imageio.imread(os.path.join(input_dir, "annotations", "livecell_test_images", cell_type, fname)) + save_image_path = "./test-livecell-image.npy" np.save(save_image_path, input_image) - # TODO: probably we need the prompt inputs here as well + predictor, sam_model = _get_model(input_image, model_type) + get_instance_segmentation = PredictorAdaptor(sam_model=sam_model) - # TODO: get output paths - # outputs: model(inputs) -> outputs: converted to numpy - save_output_path = ".npy" + box_prompts = _compute_box_from_mask(label_image) + save_box_prompt_path = "./test-box-prompts.npy" + np.save(save_box_prompt_path, box_prompts) - return [save_image_path], [save_output_path] + instances = get_instance_segmentation( + input_image=torch.from_numpy(input_image)[None, None], + predictor=predictor, + image_embeddings=None, + box_prompts=torch.from_numpy(box_prompts)[None] + ) + + save_output_path = "./test-livecell-output.npy" + np.save(save_output_path, instances.squeeze().numpy()) + + return [save_image_path, save_box_prompt_path], [save_output_path] def _get_documentation(doc_path): @@ -28,11 +67,30 @@ def _get_documentation(doc_path): return doc_path -def _get_modelzoo_yaml(): - input_list, output_list = _get_livecell_npy_path("/scratch/usr/nimanwai/data/livecell") +def _get_sam_checkpoints(model_type): + checkpoint = util._get_checkpoint(model_type, None) + print(f"{model_type} is available at {checkpoint}") + return checkpoint + + +def get_modelzoo_yaml( + image_path: str, + box_prompts: List[int], + model_type: str, + output_path: str, + doc_path: str +): + # load the model and the image and prompts + # feed prompts and image to the model to get the output + # save the numpy file for the output to get the expected data + + input_list, output_list = _get_livecell_npy_paths(input_dir=image_path, model_type=model_type) + _checkpoint = _get_sam_checkpoints(model_type) + + breakpoint() build_model( - weight_uri="~/.sam_models/vit_t_mobile_sam.pth", + weight_uri=_checkpoint, test_inputs=input_list, # type: ignore test_outputs=output_list, # type: ignore input_axes=["bcyx"], @@ -42,16 +100,21 @@ def _get_modelzoo_yaml(): authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"}, {"name": "Constantin Pape", "affiliation": "Uni Goettingen"}], tags=["instance-segmentation", "segment-anything"], - license="CC-BY-4.0", # TODO: check with Constantin - documentation=_get_documentation("./doc.md"), + license="CC-BY-4.0", + documentation=_get_documentation(doc_path), cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}], - output_path="./modelzoo/my_micro_sam.zip" + output_path=output_path ) -def main(): - _get_modelzoo_yaml() - - -if __name__ == "__main__": - main() +def _get_modelzoo_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_path", type=str, + help="Path to the raw inputs' directory") + parser.add_argument("-m", "--model_type", type=str, default="vit_b", + help="Name of the model to get the SAM checkpoints") + parser.add_argument("-o", "--output_path", type=str, default="./models/sam.zip", + help="Path to the output bioimage modelzoo-format SAM model") + parser.add_argument("-d", "--doc_path", type=str, default="./documentation.md", + help="Path to the documentation") + return parser diff --git a/micro_sam/predictor_adaptor.py b/micro_sam/predictor_adaptor.py new file mode 100644 index 00000000..c6931369 --- /dev/null +++ b/micro_sam/predictor_adaptor.py @@ -0,0 +1,38 @@ +from typing import Optional + +import torch + +from segment_anything.predictor import SamPredictor + + +class PredictorAdaptor(SamPredictor): + """Wrapper around the SamPredictor to be used by BioImage.IO model format. + + This model supports the same functionality as SamPredictor and can provide mask segmentations + from box, point or mask input prompts. + """ + def __call__( + self, + input_image: torch.Tensor, + image_embeddings: Optional[torch.Tensor] = None, + box_prompts: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if self.is_image_set and image_embeddings is None: # we have embeddings set and not passed + pass # do nothing + elif self.is_image_set and image_embeddings is not None: + raise NotImplementedError # TODO: replace the image embeedings + elif image_embeddings is not None: + pass # TODO set the image embeddings + # self.features = image_embeddings + elif not self.is_image_set: + self.set_torch_image(input_image) # compute the image embeddings + + instance_segmentation, _, _ = self.predict_torch( + point_coords=None, + point_labels=None, + boxes=box_prompts, + multimask_output=False + ) + # TODO get the image embeddings via image_embeddings = self.features + # and return them + return instance_segmentation