Skip to content

Commit

Permalink
Update model predictor adaptor for bioimage models
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 11, 2023
1 parent 2073119 commit fbfde8e
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 19 deletions.
18 changes: 18 additions & 0 deletions examples/model_zoo/get_bioimage_modelzoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from micro_sam import model_zoo


def main():
parser = model_zoo._get_modelzoo_parser()
args = parser.parse_args()

model_zoo.get_modelzoo_yaml(
image_path=args.input_path,
box_prompts=None,
model_type=args.model_type,
output_path=args.output_path,
doc_path=args.doc_path
)


if __name__ == "__main__":
main()
101 changes: 82 additions & 19 deletions micro_sam/model_zoo.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,63 @@
import os
import argparse
import numpy as np
from glob import glob
from typing import List

import imageio.v2 as imageio

import torch

from micro_sam import util

from .predictor_adaptor import PredictorAdaptor
from .prompt_based_segmentation import _compute_box_from_mask

from bioimageio.core.build_spec import build_model


def _get_livecell_npy_path(input_dir):
def _get_model(image, model_type):
"Returns the model and predictor while initializing with the model checkpoints"
predictor, sam_model = util.get_sam_model(model_type=model_type, return_sam=True) # type: ignore
image_embeddings = util.precompute_image_embeddings(predictor, image)
util.set_precomputed(predictor, image_embeddings)
return predictor, sam_model


def _get_livecell_npy_paths(
input_dir: str,
model_type: str
):
test_img_paths = sorted(glob(os.path.join(input_dir, "images", "livecell_test_images", "*")))
input_image = imageio.imread(test_img_paths[0])
chosen_input = test_img_paths[0]

input_image = imageio.imread(chosen_input)

fname = os.path.split(chosen_input)[-1]
cell_type = fname.split("_")[0]
label_image = imageio.imread(os.path.join(input_dir, "annotations", "livecell_test_images", cell_type, fname))

save_image_path = "./test-livecell-image.npy"
np.save(save_image_path, input_image)

# TODO: probably we need the prompt inputs here as well
predictor, sam_model = _get_model(input_image, model_type)
get_instance_segmentation = PredictorAdaptor(sam_model=sam_model)

# TODO: get output paths
# outputs: model(inputs) -> outputs: converted to numpy
save_output_path = "<RANDOM_PATH>.npy"
box_prompts = _compute_box_from_mask(label_image)
save_box_prompt_path = "./test-box-prompts.npy"
np.save(save_box_prompt_path, box_prompts)

return [save_image_path], [save_output_path]
instances = get_instance_segmentation(
input_image=torch.from_numpy(input_image)[None, None],
predictor=predictor,
image_embeddings=None,
box_prompts=torch.from_numpy(box_prompts)[None]
)

save_output_path = "./test-livecell-output.npy"
np.save(save_output_path, instances.squeeze().numpy())

return [save_image_path, save_box_prompt_path], [save_output_path]


def _get_documentation(doc_path):
Expand All @@ -28,11 +67,30 @@ def _get_documentation(doc_path):
return doc_path


def _get_modelzoo_yaml():
input_list, output_list = _get_livecell_npy_path("/scratch/usr/nimanwai/data/livecell")
def _get_sam_checkpoints(model_type):
checkpoint = util._get_checkpoint(model_type, None)
print(f"{model_type} is available at {checkpoint}")
return checkpoint


def get_modelzoo_yaml(
image_path: str,
box_prompts: List[int],
model_type: str,
output_path: str,
doc_path: str
):
# load the model and the image and prompts
# feed prompts and image to the model to get the output
# save the numpy file for the output to get the expected data

input_list, output_list = _get_livecell_npy_paths(input_dir=image_path, model_type=model_type)
_checkpoint = _get_sam_checkpoints(model_type)

breakpoint()

build_model(
weight_uri="~/.sam_models/vit_t_mobile_sam.pth",
weight_uri=_checkpoint,
test_inputs=input_list, # type: ignore
test_outputs=output_list, # type: ignore
input_axes=["bcyx"],
Expand All @@ -42,16 +100,21 @@ def _get_modelzoo_yaml():
authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"},
{"name": "Constantin Pape", "affiliation": "Uni Goettingen"}],
tags=["instance-segmentation", "segment-anything"],
license="CC-BY-4.0", # TODO: check with Constantin
documentation=_get_documentation("./doc.md"),
license="CC-BY-4.0",
documentation=_get_documentation(doc_path),
cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy", "doi": "10.1101/2023.08.21.554208"}],
output_path="./modelzoo/my_micro_sam.zip"
output_path=output_path
)


def main():
_get_modelzoo_yaml()


if __name__ == "__main__":
main()
def _get_modelzoo_parser():
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_path", type=str,
help="Path to the raw inputs' directory")
parser.add_argument("-m", "--model_type", type=str, default="vit_b",
help="Name of the model to get the SAM checkpoints")
parser.add_argument("-o", "--output_path", type=str, default="./models/sam.zip",
help="Path to the output bioimage modelzoo-format SAM model")
parser.add_argument("-d", "--doc_path", type=str, default="./documentation.md",
help="Path to the documentation")
return parser
38 changes: 38 additions & 0 deletions micro_sam/predictor_adaptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional

import torch

from segment_anything.predictor import SamPredictor


class PredictorAdaptor(SamPredictor):
"""Wrapper around the SamPredictor to be used by BioImage.IO model format.
This model supports the same functionality as SamPredictor and can provide mask segmentations
from box, point or mask input prompts.
"""
def __call__(
self,
input_image: torch.Tensor,
image_embeddings: Optional[torch.Tensor] = None,
box_prompts: Optional[torch.Tensor] = None
) -> torch.Tensor:
if self.is_image_set and image_embeddings is None: # we have embeddings set and not passed
pass # do nothing
elif self.is_image_set and image_embeddings is not None:
raise NotImplementedError # TODO: replace the image embeedings
elif image_embeddings is not None:
pass # TODO set the image embeddings
# self.features = image_embeddings
elif not self.is_image_set:
self.set_torch_image(input_image) # compute the image embeddings

instance_segmentation, _, _ = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=box_prompts,
multimask_output=False
)
# TODO get the image embeddings via image_embeddings = self.features
# and return them
return instance_segmentation

0 comments on commit fbfde8e

Please sign in to comment.