diff --git a/examples/bioimageio/export_model_for_bioengine.py b/examples/bioimageio/export_model_for_bioengine.py
new file mode 100644
index 00000000..0b3f9762
--- /dev/null
+++ b/examples/bioimageio/export_model_for_bioengine.py
@@ -0,0 +1,3 @@
+from micro_sam.bioimageio.bioengine_export import export_bioengine_model
+
+export_bioengine_model("vit_b", "test-export", opset=12)
diff --git a/examples/bioimageio/export_model_for_bioimageio.py b/examples/bioimageio/export_model_for_bioimageio.py
new file mode 100644
index 00000000..5769ef7c
--- /dev/null
+++ b/examples/bioimageio/export_model_for_bioimageio.py
@@ -0,0 +1,20 @@
+from micro_sam.bioimageio import export_sam_model
+from micro_sam.sample_data import synthetic_data
+
+
+def export_model_with_synthetic_data():
+    image, labels = synthetic_data(shape=(1024, 1022))
+
+    export_sam_model(
+        image, labels,
+        model_type="vit_t", name="sam-test-vit-t",
+        output_path="./test_export.zip",
+    )
+
+
+def main():
+    export_model_with_synthetic_data()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/bioimageio/imjoy_test.py b/examples/bioimageio/imjoy_test.py
new file mode 100644
index 00000000..2bd500b5
--- /dev/null
+++ b/examples/bioimageio/imjoy_test.py
@@ -0,0 +1,44 @@
+import numpy as np
+from imjoy_rpc.hypha import connect_to_server
+import time
+
+image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
+    "float32"
+)
+
+# SERVER_URL = 'http://127.0.0.1:9520'  # "https://ai.imjoy.io"
+# SERVER_URL = "https://hypha.bioimage.io"
+# SERVER_URL = "https://ai.imjoy.io"
+SERVER_URL = "https://hypha.bioimage.io"
+
+
+async def test_backbone(triton):
+    config = await triton.get_config(model_name="micro-sam-vit-b-backbone")
+    print(config)
+
+    image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
+        "float32"
+    )
+
+    start_time = time.time()
+    result = await triton.execute(
+        inputs=[image],
+        model_name="micro-sam-vit-b-backbone",
+    )
+    print("Backbone", result)
+    embedding = result['output0__0']
+    print("Time taken: ", time.time() - start_time)
+    print("Test passed", embedding.shape)
+
+
+async def run():
+    server = await connect_to_server(
+        {"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
+    )
+    triton = await server.get_service("triton-client")
+    await test_backbone(triton)
+
+
+if __name__ == "__main__":
+    import asyncio
+    asyncio.run(run())
diff --git a/micro_sam/bioimageio/__init__.py b/micro_sam/bioimageio/__init__.py
new file mode 100644
index 00000000..f3534f83
--- /dev/null
+++ b/micro_sam/bioimageio/__init__.py
@@ -0,0 +1 @@
+from .model_export import export_sam_model
diff --git a/micro_sam/bioimageio/bioengine_export.py b/micro_sam/bioimageio/bioengine_export.py
new file mode 100644
index 00000000..9559f97d
--- /dev/null
+++ b/micro_sam/bioimageio/bioengine_export.py
@@ -0,0 +1,250 @@
+import os
+import warnings
+from typing import Optional, Union
+
+import torch
+from segment_anything.utils.onnx import SamOnnxModel
+
+try:
+    import onnxruntime
+    onnxruntime_exists = True
+except ImportError:
+    onnxruntime_exists = False
+
+from ..util import get_sam_model
+
+
+ENCODER_CONFIG = """name: "%s"
+backend: "pytorch"
+platform: "pytorch_libtorch"
+
+max_batch_size : 1
+input [
+  {
+    name: "input0__0"
+    data_type: TYPE_FP32
+    dims: [3, -1, -1]
+  }
+]
+output [
+  {
+    name: "output0__0"
+    data_type: TYPE_FP32
+    dims: [256, 64, 64]
+  }
+]
+
+parameters: {
+  key: "INFERENCE_MODE"
+  value: {
+    string_value: "true"
+  }
+}"""
+
+
+DECODER_CONFIG = """name: "%s"
+backend: "onnxruntime"
+platform: "onnxruntime_onnx"
+
+parameters: {
+  key: "INFERENCE_MODE"
+  value: {
+    string_value: "true"
+  }
+}
+
+instance_group {
+  count: 1
+  kind: KIND_CPU
+}"""
+
+
+def _to_numpy(tensor):
+    return tensor.cpu().numpy()
+
+
+def export_image_encoder(
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    export_name: Optional[str] = None,
+    checkpoint_path: Optional[str] = None,
+) -> None:
+    """Export SAM image encoder to TorchScript.
+
+    The TorchScript image encoder can be used for predicting image embeddings
+    with a backend, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner).
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the exported model is saved.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+    """
+    if export_name is None:
+        export_name = model_type
+    name = f"sam-{export_name}-encoder"
+
+    output_folder = os.path.join(output_root, name)
+    weight_output_folder = os.path.join(output_folder, "1")
+    os.makedirs(weight_output_folder, exist_ok=True)
+
+    predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
+    encoder = predictor.model.image_encoder
+
+    encoder.eval()
+    input_ = torch.rand(1, 3, 1024, 1024)
+    traced_model = torch.jit.trace(encoder, input_)
+    weight_path = os.path.join(weight_output_folder, "model.pt")
+    traced_model.save(weight_path)
+
+    config_output_path = os.path.join(output_folder, "config.pbtxt")
+    with open(config_output_path, "w") as f:
+        f.write(ENCODER_CONFIG % name)
+
+
+def export_onnx_model(
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    opset: int,
+    export_name: Optional[str] = None,
+    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
+    return_single_mask: bool = True,
+    gelu_approximate: bool = False,
+    use_stability_score: bool = False,
+    return_extra_metrics: bool = False,
+) -> None:
+    """Export SAM prompt encoder and mask decoder to ONNX.
+
+    The ONNX encoder and decoder can be used for interactive segmentation in the browser.
+    This code is adapted from
+    https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the exported model is saved.
+        opset: The ONNX opset version.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+        return_single_mask: Whether the mask decoder returns a single or multiple masks.
+        gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
+            does not have an efficient GeLU implementation.
+        use_stability_score: Whether to use the stability score instead of the predicted score.
+        return_extra_metrics: Whether to return a larger set of metrics.
+    """
+    if export_name is None:
+        export_name = model_type
+    name = f"sam-{export_name}-decoder"
+
+    output_folder = os.path.join(output_root, name)
+    weight_output_folder = os.path.join(output_folder, "1")
+    os.makedirs(weight_output_folder, exist_ok=True)
+
+    _, sam = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True)
+    weight_path = os.path.join(weight_output_folder, "model.onnx")
+
+    onnx_model = SamOnnxModel(
+        model=sam,
+        return_single_mask=return_single_mask,
+        use_stability_score=use_stability_score,
+        return_extra_metrics=return_extra_metrics,
+    )
+
+    if gelu_approximate:
+        for n, m in onnx_model.named_modules():
+            if isinstance(m, torch.nn.GELU):
+                m.approximate = "tanh"
+
+    dynamic_axes = {
+        "point_coords": {1: "num_points"},
+        "point_labels": {1: "num_points"},
+    }
+
+    embed_dim = sam.prompt_encoder.embed_dim
+    embed_size = sam.prompt_encoder.image_embedding_size
+
+    mask_input_size = [4 * x for x in embed_size]
+    dummy_inputs = {
+        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
+        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
+        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
+        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
+        "has_mask_input": torch.tensor([1], dtype=torch.float),
+        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
+    }
+
+    _ = onnx_model(**dummy_inputs)
+
+    output_names = ["masks", "iou_predictions", "low_res_masks"]
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
+        warnings.filterwarnings("ignore", category=UserWarning)
+        with open(weight_path, "wb") as f:
+            print(f"Exporting onnx model to {weight_path}...")
+            torch.onnx.export(
+                onnx_model,
+                tuple(dummy_inputs.values()),
+                f,
+                export_params=True,
+                verbose=False,
+                opset_version=opset,
+                do_constant_folding=True,
+                input_names=list(dummy_inputs.keys()),
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+            )
+
+    if onnxruntime_exists:
+        ort_inputs = {k: _to_numpy(v) for k, v in dummy_inputs.items()}
+        # set cpu provider default
+        providers = ["CPUExecutionProvider"]
+        ort_session = onnxruntime.InferenceSession(weight_path, providers=providers)
+        _ = ort_session.run(None, ort_inputs)
+        print("Model has successfully been run with ONNXRuntime.")
+
+    config_output_path = os.path.join(output_folder, "config.pbtxt")
+    with open(config_output_path, "w") as f:
+        f.write(DECODER_CONFIG % name)
+
+
+def export_bioengine_model(
+    model_type: str,
+    output_root: Union[str, os.PathLike],
+    opset: int,
+    export_name: Optional[str] = None,
+    checkpoint_path: Optional[Union[str, os.PathLike]] = None,
+    return_single_mask: bool = True,
+    gelu_approximate: bool = False,
+    use_stability_score: bool = False,
+    return_extra_metrics: bool = False,
+) -> None:
+    """Export SAM model to a format compatible with the BioEngine.
+
+    [The bioengine](https://github.com/bioimage-io/bioengine-model-runner) enables running the
+    image encoder on an online backend, so that SAM can be used in an online tool, or to predict
+    the image embeddings via the online backend rather than on CPU.
+
+    Args:
+        model_type: The SAM model type.
+        output_root: The output root directory where the exported model is saved.
+        opset: The ONNX opset version.
+        export_name: The name of the exported model.
+        checkpoint_path: Optional checkpoint for loading the SAM model.
+        return_single_mask: Whether the mask decoder returns a single or multiple masks.
+ gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend + does not have an efficient GeLU implementation. + use_stability_score: Whether to use the stability score instead of the predicted score. + return_extra_metrics: Whether to return a larger set of metrics. + """ + export_image_encoder(model_type, output_root, export_name, checkpoint_path) + export_onnx_model( + model_type=model_type, + output_root=output_root, + opset=opset, + export_name=export_name, + checkpoint_path=checkpoint_path, + return_single_mask=return_single_mask, + gelu_approximate=gelu_approximate, + use_stability_score=use_stability_score, + return_extra_metrics=return_extra_metrics, + ) diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py new file mode 100644 index 00000000..b2d77423 --- /dev/null +++ b/micro_sam/bioimageio/model_export.py @@ -0,0 +1,490 @@ +import os +import tempfile + +from pathlib import Path +from typing import Optional, Union + +import bioimageio.core +import bioimageio.spec.model.v0_5 as spec +import matplotlib.pyplot as plt +import numpy as np +import torch +import xarray + +from bioimageio.spec import save_bioimageio_package +from bioimageio.core.digest_spec import create_sample_for_model + + +from .. import util +from ..prompt_generators import PointAndBoxPromptGenerator +from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box +from ..prompt_based_segmentation import _compute_logits_from_mask +from .predictor_adaptor import PredictorAdaptor + +DEFAULTS = { + "authors": [ + spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"), + spec.Author(name="Constantin Pape", affiliation="University Goettingen", github_user="constantinpape"), + ], + "description": "Finetuned Segment Anything Model for Microscopy", + "cite": [ + spec.CiteEntry(text="Archit et al. Segment Anything for Microscopy", doi=spec.Doi("10.1101/2023.08.21.554208")), + ], + "tags": ["segment-anything", "instance-segmentation"], + # FIXME these are details for the uploader we should remove here + "uploader": spec.Uploader(email="constantin.pape@informatik.uni-goettinge.de"), + "id": "acclaimed-angelfish", + "id_emoji": "🐠", +} + + +def _create_test_inputs_and_outputs( + image, + labels, + model_type, + checkpoint_path, + tmp_dir, +): + # For now we just generate a single box prompt here, but we could also generate more input prompts. 
+ generator = PointAndBoxPromptGenerator( + n_positive_points=1, + n_negative_points=2, + dilation_strength=2, + get_point_prompts=True, + get_box_prompts=True, + ) + centers, bounding_boxes = util.get_centers_and_bounding_boxes(labels) + masks = util.segmentation_to_one_hot(labels.astype("int64"), segmentation_ids=[1, 2]) # type: ignore + point_prompts, point_labels, box_prompts, _ = generator(masks, [bounding_boxes[1], bounding_boxes[2]]) + + box_prompts = box_prompts.numpy()[None] + point_prompts = point_prompts.numpy()[None] + point_labels = point_labels.numpy()[None] + + # Generate logits from the two + mask_prompts = np.stack( + [ + _compute_logits_from_mask(labels == 1), + _compute_logits_from_mask(labels == 2), + ] + )[None] + + predictor = PredictorAdaptor(model_type=model_type) + predictor.load_state_dict(torch.load(checkpoint_path)) + + input_ = util._to_image(image).transpose(2, 0, 1)[None] + image_path = os.path.join(tmp_dir, "input.npy") + np.save(image_path, input_) + + masks, scores, embeddings = predictor( + image=torch.from_numpy(input_), + embeddings=None, + box_prompts=torch.from_numpy(box_prompts), + point_prompts=torch.from_numpy(point_prompts), + point_labels=torch.from_numpy(point_labels), + mask_prompts=torch.from_numpy(mask_prompts), + ) + + box_prompt_path = os.path.join(tmp_dir, "box_prompts.npy") + point_prompt_path = os.path.join(tmp_dir, "point_prompts.npy") + point_label_path = os.path.join(tmp_dir, "point_labels.npy") + mask_prompt_path = os.path.join(tmp_dir, "mask_prompts.npy") + np.save(box_prompt_path, box_prompts) + np.save(point_prompt_path, point_prompts) + np.save(point_label_path, point_labels) + np.save(mask_prompt_path, mask_prompts) + + mask_path = os.path.join(tmp_dir, "mask.npy") + score_path = os.path.join(tmp_dir, "scores.npy") + embed_path = os.path.join(tmp_dir, "embeddings.npy") + np.save(mask_path, masks.numpy()) + np.save(score_path, scores.numpy()) + np.save(embed_path, embeddings.numpy()) + + inputs = { + "image": image_path, + "box_prompts": box_prompt_path, + "point_prompts": point_prompt_path, + "point_labels": point_label_path, + "mask_prompts": mask_prompt_path, + } + outputs = { + "mask": mask_path, + "score": score_path, + "embeddings": embed_path + } + return inputs, outputs + + +# TODO url with documentation for the modelzoo interface, and just add it to defaults +def _write_documentation(doc_path, doc): + with open(doc_path, "w") as f: + if doc is None: + f.write("# Segment Anything for Microscopy\n") + f.write("We extend Segment Anything, a vision foundation model for image segmentation ") + f.write("by training specialized models for microscopy data.\n") + else: + f.write(doc) + return doc_path + + +def _get_checkpoint(model_type, checkpoint_path): + if checkpoint_path is None: + model_registry = util.models() + checkpoint_path = model_registry.fetch(model_type) + return checkpoint_path + + +def _write_dependencies(dependency_file, require_mobile_sam): + content = """name: sam +channels: + - pytorch + - conda-forge +dependencies: + - segment-anything""" + if require_mobile_sam: + content += """ + - pip: + - git+https://github.com/ChaoningZhang/MobileSAM.git""" + with open(dependency_file, "w") as f: + f.write(content) + + +def _generate_covers(input_paths, result_paths, tmp_dir): + image = np.load(input_paths["image"]).squeeze() + prompts = np.load(input_paths["box_prompts"]) + mask = np.load(result_paths["mask"]) + + # create the image overlay + if image.ndim == 2: + overlay = np.stack([image, image, image]).transpose((1, 
2, 0)) + elif image.shape[0] == 3: + overlay = image.transpose((1, 2, 0)) + else: + overlay = image + overlay = _enhance_image(overlay.astype("float32")) + + # overlay the mask as outline + overlay = _overlay_outline(overlay, mask[0, 0, 0], outline_dilation=2) + + # overlay the bounding box prompt + prompt = prompts[0, 0][[1, 0, 3, 2]] + prompt = np.array([prompt[:2], prompt[2:]]) + overlay = _overlay_box(overlay, prompt, outline_dilation=4) + + # write the cover image + fig, ax = plt.subplots(1) + ax.axis("off") + ax.imshow(overlay.astype("uint8")) + cover_path = os.path.join(tmp_dir, "cover.jpeg") + plt.savefig(cover_path, bbox_inches="tight") + plt.close() + + covers = [cover_path] + return covers + + +def _check_model(model_description, input_paths, result_paths): + # Load inputs. + image = xarray.DataArray(np.load(input_paths["image"]), dims=tuple("bcyx")) + embeddings = xarray.DataArray(np.load(result_paths["embeddings"]), dims=tuple("bcyx")) + box_prompts = xarray.DataArray(np.load(input_paths["box_prompts"]), dims=tuple("bic")) + point_prompts = xarray.DataArray(np.load(input_paths["point_prompts"]), dims=tuple("biic")) + point_labels = xarray.DataArray(np.load(input_paths["point_labels"]), dims=tuple("bic")) + mask_prompts = xarray.DataArray(np.load(input_paths["mask_prompts"]), dims=tuple("bicyx")) + + # Load outputs. + mask = np.load(result_paths["mask"]) + + with bioimageio.core.create_prediction_pipeline(model_description) as pp: + + # Check with all prompts. We only check the result for this setting, + # because this was used to generate the test data. + sample = create_sample_for_model( + model=model_description, + image=image, + box_prompts=box_prompts, + point_prompts=point_prompts, + point_labels=point_labels, + mask_prompts=mask_prompts, + embeddings=embeddings, + ).as_single_block() + prediction = pp.predict_sample_block(sample) + + assert len(prediction) == 3 + predicted_mask = prediction[0] + assert np.allclose(mask, predicted_mask) + + # Run the checks with partial prompts. + prompt_kwargs = [ + # With boxes. + {"box_prompts": box_prompts}, + # With point prompts. + {"point_prompts": point_prompts, "point_labels": point_labels}, + # With masks. + {"mask_prompts": mask_prompts}, + # With boxes and points. + {"box_prompts": box_prompts, "point_prompts": point_prompts, "point_labels": point_labels}, + # With boxes and masks. + {"box_prompts": box_prompts, "mask_prompts": mask_prompts}, + # With points and masks. + {"mask_prompts": mask_prompts, "point_prompts": point_prompts, "point_labels": point_labels}, + ] + + for kwargs in prompt_kwargs: + sample = create_sample_for_model( + model=model_description, image=image, embeddings=embeddings, **kwargs + ).as_single_block() + prediction = pp.predict_sample_block(sample) + assert len(prediction) == 3 + predicted_mask = prediction[0] + assert predicted_mask.shape == mask.shape + + +def export_sam_model( + image: np.ndarray, + label_image: np.ndarray, + model_type: str, + name: str, + output_path: Union[str, os.PathLike], + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + **kwargs +) -> None: + """Export SAM model to BioImage.IO model format. + + The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and + be used in tools that support the BioImage.IO model format. + + Args: + image: The image for generating test data. + label_image: The segmentation correspoding to `image`. + It is used to derive prompt inputs for the model. + model_type: The type of the SAM model. 
+ name: The name of the exported model. + output_path: Where the exported model is saved. + checkpoint_path: Optional checkpoint for loading the SAM model. + """ + with tempfile.TemporaryDirectory() as tmp_dir: + checkpoint_path = _get_checkpoint(model_type, checkpoint_path=checkpoint_path) + input_paths, result_paths = _create_test_inputs_and_outputs( + image, label_image, model_type, checkpoint_path, tmp_dir, + ) + input_descriptions = [ + # First input: the image data. + spec.InputTensorDescr( + id=spec.TensorId("image"), + axes=[ + spec.BatchAxis(), + # NOTE: to support 1 and 3 channels we can add another preprocessing. + # Best solution: Have a pre-processing for this! (1C -> RGB) + spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]), + spec.SpaceInputAxis(id=spec.AxisId("y"), size=spec.ARBITRARY_SIZE), + spec.SpaceInputAxis(id=spec.AxisId("x"), size=spec.ARBITRARY_SIZE), + ], + test_tensor=spec.FileDescr(source=input_paths["image"]), + data=spec.IntervalOrRatioDataDescr(type="uint8") + ), + + # Second input: the box prompts (optional) + spec.InputTensorDescr( + id=spec.TensorId("box_prompts"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + # TODO double check the axis names + spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "hwxy"]), + ], + test_tensor=spec.FileDescr(source=input_paths["box_prompts"]), + data=spec.IntervalOrRatioDataDescr(type="int64") + ), + + # Third input: the point prompt coordinates (optional) + spec.InputTensorDescr( + id=spec.TensorId("point_prompts"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + spec.IndexInputAxis( + id=spec.AxisId("point"), + size=spec.ARBITRARY_SIZE + ), + spec.ChannelAxis(channel_names=[spec.Identifier(bname) for bname in "xy"]), + ], + test_tensor=spec.FileDescr(source=input_paths["point_prompts"]), + data=spec.IntervalOrRatioDataDescr(type="int64") + ), + + # Fourth input: the point prompt labels (optional) + spec.InputTensorDescr( + id=spec.TensorId("point_labels"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + spec.IndexInputAxis( + id=spec.AxisId("point"), + size=spec.ARBITRARY_SIZE + ), + ], + test_tensor=spec.FileDescr(source=input_paths["point_labels"]), + data=spec.IntervalOrRatioDataDescr(type="int64") + ), + + # Fifth input: the mask prompts (optional) + spec.InputTensorDescr( + id=spec.TensorId("mask_prompts"), + optional=True, + axes=[ + spec.BatchAxis(), + spec.IndexInputAxis( + id=spec.AxisId("object"), + size=spec.ARBITRARY_SIZE + ), + spec.ChannelAxis(channel_names=["channel"]), + spec.SpaceInputAxis(id=spec.AxisId("y"), size=256), + spec.SpaceInputAxis(id=spec.AxisId("x"), size=256), + ], + test_tensor=spec.FileDescr(source=input_paths["mask_prompts"]), + data=spec.IntervalOrRatioDataDescr(type="float32") + ), + + # Sixth input: the image embeddings (optional) + spec.InputTensorDescr( + id=spec.TensorId("embeddings"), + optional=True, + axes=[ + spec.BatchAxis(), + # NOTE: we currently have to specify all the channel names + # (It would be nice to also support size) + spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]), + spec.SpaceInputAxis(id=spec.AxisId("y"), size=64), + spec.SpaceInputAxis(id=spec.AxisId("x"), size=64), + ], + test_tensor=spec.FileDescr(source=result_paths["embeddings"]), + 
data=spec.IntervalOrRatioDataDescr(type="float32") + ), + + ] + + output_descriptions = [ + # First output: The mask predictions. + spec.OutputTensorDescr( + id=spec.TensorId("masks"), + axes=[ + spec.BatchAxis(), + # NOTE: we use the data dependent size here to avoid dependency on optional inputs + spec.IndexOutputAxis( + id=spec.AxisId("object"), size=spec.DataDependentSize(), + ), + # NOTE: this could be a 3 once we use multi-masking + spec.ChannelAxis(channel_names=[spec.Identifier("mask")]), + spec.SpaceOutputAxis( + id=spec.AxisId("y"), + size=spec.SizeReference( + tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("y"), + ) + ), + spec.SpaceOutputAxis( + id=spec.AxisId("x"), + size=spec.SizeReference( + tensor_id=spec.TensorId("image"), axis_id=spec.AxisId("x"), + ) + ) + ], + data=spec.IntervalOrRatioDataDescr(type="uint8"), + test_tensor=spec.FileDescr(source=result_paths["mask"]) + ), + + # The score predictions + spec.OutputTensorDescr( + id=spec.TensorId("scores"), + axes=[ + spec.BatchAxis(), + # NOTE: we use the data dependent size here to avoid dependency on optional inputs + spec.IndexOutputAxis( + id=spec.AxisId("object"), size=spec.DataDependentSize(), + ), + # NOTE: this could be a 3 once we use multi-masking + spec.ChannelAxis(channel_names=[spec.Identifier("mask")]), + ], + data=spec.IntervalOrRatioDataDescr(type="float32"), + test_tensor=spec.FileDescr(source=result_paths["score"]) + ), + + # The image embeddings + spec.OutputTensorDescr( + id=spec.TensorId("embeddings"), + axes=[ + spec.BatchAxis(), + spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]), + spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64), + spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64), + ], + data=spec.IntervalOrRatioDataDescr(type="float32"), + test_tensor=spec.FileDescr(source=result_paths["embeddings"]) + ) + ] + + architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py") + architecture = spec.ArchitectureFromFileDescr( + source=Path(architecture_path), + callable="PredictorAdaptor", + kwargs={"model_type": model_type} + ) + + dependency_file = os.path.join(tmp_dir, "environment.yaml") + _write_dependencies(dependency_file, require_mobile_sam=model_type.startswith("vit_t")) + + weight_descriptions = spec.WeightsDescr( + pytorch_state_dict=spec.PytorchStateDictWeightsDescr( + source=Path(checkpoint_path), + architecture=architecture, + pytorch_version=spec.Version(torch.__version__), + dependencies=spec.EnvironmentFileDescr(source=dependency_file), + ) + ) + + doc_path = os.path.join(tmp_dir, "documentation.md") + _write_documentation(doc_path, kwargs.get("documentation", None)) + + covers = _generate_covers(input_paths, result_paths, tmp_dir) + + model_description = spec.ModelDescr( + name=name, + inputs=input_descriptions, + outputs=output_descriptions, + weights=weight_descriptions, + description=kwargs.get("description", DEFAULTS["description"]), + authors=kwargs.get("authors", DEFAULTS["authors"]), + cite=kwargs.get("cite", DEFAULTS["cite"]), + license=spec.LicenseId("CC-BY-4.0"), + documentation=Path(doc_path), + git_repo=spec.HttpUrl("https://github.com/computational-cell-analytics/micro-sam"), + tags=kwargs.get("tags", DEFAULTS["tags"]), + covers=covers, + uploader=kwargs.get("uploader", DEFAULTS["uploader"]), + id=kwargs.get("id", DEFAULTS["id"]), + id_emoji=kwargs.get("id_emoji", DEFAULTS["id_emoji"]), + # TODO attach the decoder weights if given + # Can be list of files??? 
+ # attachments=[spec.FileDescr(source=file_path) for file_path in attachment_files] + # TODO write the config + # dict with yaml values, key must be a str + # micro_sam: ... + # config= + ) + + # _check_model(model_description, input_paths, result_paths) + + save_bioimageio_package(model_description, output_path=output_path) diff --git a/micro_sam/bioimageio/predictor_adaptor.py b/micro_sam/bioimageio/predictor_adaptor.py new file mode 100644 index 00000000..c4a53e64 --- /dev/null +++ b/micro_sam/bioimageio/predictor_adaptor.py @@ -0,0 +1,122 @@ +import warnings +from typing import Optional, Tuple + +import torch +from torch import nn + +from segment_anything.predictor import SamPredictor + +try: + # Avoid import warnings from mobile_sam + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from mobile_sam import sam_model_registry +except ImportError: + from segment_anything import sam_model_registry + + +class PredictorAdaptor(nn.Module): + """Wrapper around the SamPredictor. + + This model supports the same functionality as SamPredictor and can provide mask segmentations + from box, point or mask input prompts. + + Args: + model_type: The type of the model for the image encoder. + Can be one of 'vit_b', 'vit_l', 'vit_h' or 'vit_t'. + For 'vit_t' support the 'mobile_sam' package has to be installed. + """ + def __init__(self, model_type: str) -> None: + super().__init__() + sam_model = sam_model_registry[model_type]() + self.sam = SamPredictor(sam_model) + + def load_state_dict(self, state): + self.sam.model.load_state_dict(state) + + @torch.no_grad() + def forward( + self, + image: torch.Tensor, + box_prompts: Optional[torch.Tensor] = None, + point_prompts: Optional[torch.Tensor] = None, + point_labels: Optional[torch.Tensor] = None, + mask_prompts: Optional[torch.Tensor] = None, + embeddings: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + + Args: + image: torch inputs of dimensions B x C x H x W + box_prompts: box coordinates of dimensions B x OBJECTS x 4 + point_prompts: point coordinates of dimension B x OBJECTS x POINTS x 2 + point_labels: point labels of dimension B x OBJECTS x POINTS + mask_prompts: mask prompts of dimension B x OBJECTS x 256 x 256 + embeddings: precomputed image embeddings B x 256 x 64 x 64 + + Returns: + """ + batch_size = image.shape[0] + if batch_size != 1: + raise ValueError + + # We have image embeddings set and image embeddings were not passed. + if self.sam.is_image_set and embeddings is None: + pass # do nothing + + # The embeddings are passed, so we set them. + elif embeddings is not None: + self.sam.features = embeddings + self.sam.orig_h, self.sam.orig_w = image.shape[2:] + self.sam.input_h, self.sam.input_w = self.sam.transform.apply_image_torch(image).shape[2:] + self.sam.is_image_set = True + + # We don't have image embeddings set and they were not passed. + elif not self.sam.is_image_set: + image = self.sam.transform.apply_image_torch(image) + self.sam.set_torch_image(image, original_image_size=image.numpy().shape[2:]) + self.sam.orig_h, self.sam.orig_w = self.sam.original_size + self.sam.input_h, self.sam.input_w = self.sam.input_size + + assert self.sam.is_image_set, "The predictor has not yet been initialized." + + # Ensure input size and original size are set. 
+ self.sam.input_size = (self.sam.input_h, self.sam.input_w) + self.sam.original_size = (self.sam.orig_h, self.sam.orig_w) + + if box_prompts is None: + boxes = None + else: + boxes = self.sam.transform.apply_boxes_torch(box_prompts, original_size=self.sam.original_size) + + if point_prompts is None: + point_coords = None + else: + assert point_labels is not None + point_coords = self.sam.transform.apply_coords_torch(point_prompts, original_size=self.sam.original_size)[0] + point_labels = point_labels[0] + + if mask_prompts is None: + mask_input = None + else: + mask_input = mask_prompts[0] + + masks, scores, _ = self.sam.predict_torch( + point_coords=point_coords, + point_labels=point_labels, + boxes=boxes, + mask_input=mask_input, + multimask_output=False + ) + + assert masks.shape[2:] == image.shape[2:], \ + f"{masks.shape[2:]} is not as expected ({image.shape[2:]})" + + # Ensure batch axis. + if masks.ndim == 4: + masks = masks[None] + assert scores.ndim == 2 + scores = scores[None] + + embeddings = self.sam.get_image_embedding() + return masks.to(dtype=torch.uint8), scores, embeddings diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 46bfa10d..572f15ff 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -136,7 +136,7 @@ def generate_data_for_model_comparison( # -# Visual evaluation accroding to metrics +# Visual evaluation according to metrics # diff --git a/test/test_bioimageio/test_model_export.py b/test/test_bioimageio/test_model_export.py new file mode 100644 index 00000000..4fbd882c --- /dev/null +++ b/test/test_bioimageio/test_model_export.py @@ -0,0 +1,37 @@ +import os +import unittest + +from shutil import rmtree + +import micro_sam.util as util +from micro_sam.sample_data import synthetic_data + + +class TestModelExport(unittest.TestCase): + tmp_folder = "tmp" + model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b" + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) + + def tearDown(self): + rmtree(self.tmp_folder) + + def test_model_export(self): + from micro_sam.bioimageio import export_sam_model + image, labels = synthetic_data(shape=(1024, 1022)) + + export_path = os.path.join(self.tmp_folder, "test_export.zip") + export_sam_model( + image, labels, + model_type=self.model_type, name="test-export", + output_path=export_path, + ) + + self.assertTrue(os.path.exists(export_path)) + + # TODO more tests: run prediction with models for different prompt settings + + +if __name__ == "__main__": + unittest.main()
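
Usage sketch (not part of the patch above): the snippet below shows how an archive written by `export_sam_model` could be loaded and run with bioimageio.core, following the same calls used in `_check_model` (`create_prediction_pipeline`, `create_sample_for_model`, `predict_sample_block`). The output path and the box prompt values are placeholders, and the `load_description` import is an assumption based on the bioimageio.spec 0.5 API targeted by this patch.

import numpy as np
import xarray

import bioimageio.core
from bioimageio.core.digest_spec import create_sample_for_model
from bioimageio.spec import load_description  # assumed bioimageio.spec >= 0.5 API

from micro_sam.bioimageio import export_sam_model
from micro_sam.sample_data import synthetic_data

# Export a model from synthetic data, as in examples/bioimageio/export_model_for_bioimageio.py.
image, labels = synthetic_data(shape=(1024, 1022))
export_sam_model(
    image, labels, model_type="vit_t", name="sam-test-vit-t", output_path="./test_export.zip",
)

# Load the exported package.
model_description = load_description("./test_export.zip")

# Build the inputs: an RGB uint8 image with (b, c, y, x) axes and one box prompt with (b, object, 4) axes.
input_image = xarray.DataArray(np.stack(3 * [image])[None].astype("uint8"), dims=tuple("bcyx"))
box_prompts = xarray.DataArray(np.array([[[64, 64, 256, 256]]], dtype="int64"), dims=tuple("bic"))

# Run prediction; the first output is the mask, as in `_check_model`.
with bioimageio.core.create_prediction_pipeline(model_description) as pp:
    sample = create_sample_for_model(
        model=model_description, image=input_image, box_prompts=box_prompts,
    ).as_single_block()
    prediction = pp.predict_sample_block(sample)

masks = prediction[0]
print(masks.shape)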