Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bioimageio Model Creation #227

Closed
wants to merge 24 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f2fa218
WIP Add bioimage.io model creation
anwai98 Oct 10, 2023
2073119
Update model building script
anwai98 Oct 11, 2023
fbfde8e
Update model predictor adaptor for bioimage models
anwai98 Oct 11, 2023
c62fe93
Refactor modelzoo functionality into submodule
constantinpape Oct 11, 2023
615ffc6
Add first working scripts for bioengine export
constantinpape Oct 11, 2023
09f5f76
Add input prompt transofrms to adaptor
anwai98 Oct 11, 2023
b869a33
Update numpy input saving
anwai98 Oct 11, 2023
96dc3f6
Update bioengine export script
constantinpape Oct 11, 2023
61b076e
Merge branch 'aa-modelzoo' of https://github.com/computational-cell-a…
constantinpape Oct 11, 2023
7d6bfca
Refactor modelzoo export
constantinpape Oct 12, 2023
b470257
Add tempfile for model conversion inputs
anwai98 Oct 12, 2023
de6b245
Add doc-strings to bioengine export functionality
constantinpape Oct 12, 2023
ebba719
Update modelzoo export script
constantinpape Oct 12, 2023
ee831ef
Update url in imjoy test
constantinpape Oct 13, 2023
b28c885
Merge branch 'dev' into aa-modelzoo
constantinpape Mar 14, 2024
390ce23
Update to bioimageio.spec v0.5 WIP
constantinpape Mar 14, 2024
6cccc06
Update example script
constantinpape Mar 15, 2024
b9df849
Update bioimageio export
constantinpape Mar 18, 2024
d911392
Minor fixes
constantinpape Mar 19, 2024
ec56035
Work on export
constantinpape Mar 19, 2024
2d9c88a
More modelzoo updtes
constantinpape Mar 20, 2024
a170511
Add all possible model inputs
constantinpape Mar 21, 2024
d2e4909
Merge branch 'dev' into aa-modelzoo
constantinpape Apr 9, 2024
c181e29
Bioimageio updates WIP
constantinpape Apr 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/bioimageio/export_model_for_bioengine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from micro_sam.modelzoo.bioengine_export import export_bioengine_model

export_bioengine_model("vit_b", "test-export", opset=12)
20 changes: 20 additions & 0 deletions examples/bioimageio/export_model_for_bioimageio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from micro_sam.bioimageio import export_sam_model
from micro_sam.sample_data import synthetic_data


def export_model_with_synthetic_data():
image, labels = synthetic_data(shape=(1024, 1022))

export_sam_model(
image, labels,
model_type="vit_t", name="sam-test-vit-t",
output_path="./test_export.zip",
)


def main():
export_model_with_synthetic_data()


if __name__ == "__main__":
main()
44 changes: 44 additions & 0 deletions examples/bioimageio/imjoy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
from imjoy_rpc.hypha import connect_to_server
import time

image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
"float32"
)

# SERVER_URL = 'http://127.0.0.1:9520' # "https://ai.imjoy.io"
# SERVER_URL = "https://hypha.bioimage.io"
# SERVER_URL = "https://ai.imjoy.io"
SERVER_URL = "https://hypha.bioimage.io"


async def test_backbone(triton):
config = await triton.get_config(model_name="micro-sam-vit-b-backbone")
print(config)

image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
"float32"
)

start_time = time.time()
result = await triton.execute(
inputs=[image],
model_name="micro-sam-vit-b-backbone",
)
print("Backbone", result)
embedding = result['output0__0']
print("Time taken: ", time.time() - start_time)
print("Test passed", embedding.shape)


async def run():
server = await connect_to_server(
{"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
)
triton = await server.get_service("triton-client")
await test_backbone(triton)


if __name__ == "__main__":
import asyncio
asyncio.run(run())
1 change: 1 addition & 0 deletions micro_sam/bioimageio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .model_export import export_sam_model
250 changes: 250 additions & 0 deletions micro_sam/bioimageio/bioengine_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
import os
import warnings
from typing import Optional, Union

import torch
from segment_anything.utils.onnx import SamOnnxModel

try:
import onnxruntime
onnxruntime_exists = True
except ImportError:
onnxruntime_exists = False

from ..util import get_sam_model


ENCODER_CONFIG = """name: "%s"
backend: "pytorch"
platform: "pytorch_libtorch"

max_batch_size : 1
input [
{
name: "input0__0"
data_type: TYPE_FP32
dims: [3, -1, -1]
}
]
output [
{
name: "output0__0"
data_type: TYPE_FP32
dims: [256, 64, 64]
}
]

parameters: {
key: "INFERENCE_MODE"
value: {
string_value: "true"
}
}"""


DECODER_CONFIG = """name: "%s"
backend: "onnxruntime"
platform: "onnxruntime_onnx"

parameters: {
key: "INFERENCE_MODE"
value: {
string_value: "true"
}
}

instance_group {
count: 1
kind: KIND_CPU
}"""


def _to_numpy(tensor):
return tensor.cpu().numpy()


def export_image_encoder(
model_type: str,
output_root: Union[str, os.PathLike],
export_name: Optional[str] = None,
checkpoint_path: Optional[str] = None,
) -> None:
"""Export SAM image encoder to torchscript.

The torchscript image encoder can be used for predicting image embeddings
with a backed, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner).

Args:
model_type: The SAM model type.
output_root: The output root directory where the exported model is saved.
export_name: The name of the exported model.
checkpoint_path: Optional checkpoint for loading the exported model.
"""
if export_name is None:
export_name = model_type
name = f"sam-{export_name}-encoder"

output_folder = os.path.join(output_root, name)
weight_output_folder = os.path.join(output_folder, "1")
os.makedirs(weight_output_folder, exist_ok=True)

predictor = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
encoder = predictor.model.image_encoder

encoder.eval()
input_ = torch.rand(1, 3, 1024, 1024)
traced_model = torch.jit.trace(encoder, input_)
weight_path = os.path.join(weight_output_folder, "model.pt")
traced_model.save(weight_path)

config_output_path = os.path.join(output_folder, "config.pbtxt")
with open(config_output_path, "w") as f:
f.write(ENCODER_CONFIG % name)


def export_onnx_model(
model_type,
output_root,
opset: int,
export_name: Optional[str] = None,
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
return_single_mask: bool = True,
gelu_approximate: bool = False,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
"""Export SAM prompt enocer and mask decoder to onnx.

The onnx encoder and decoder can be used for interactive segmentation in the browser.
This code is adapted from
https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py

Args:
model_type: The SAM model type.
output_root: The output root directory where the exported model is saved.
opset: The ONNX opset version.
export_name: The name of the exported model.
checkpoint_path: Optional checkpoint for loading the SAM model.
return_single_mask: Whether the mask decoder returns a single or multiple masks.
gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
does not have an efficient GeLU implementation.
use_stability_score: Whether to use the stability score instead of the predicted score.
return_extra_metrics: Whether to return a larger set of metrics.
"""
if export_name is None:
export_name = model_type
name = f"sam-{export_name}-decoder"

output_folder = os.path.join(output_root, name)
weight_output_folder = os.path.join(output_folder, "1")
os.makedirs(weight_output_folder, exist_ok=True)

_, sam = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_sam=True)
weight_path = os.path.join(weight_output_folder, "model.onnx")

onnx_model = SamOnnxModel(
model=sam,
return_single_mask=return_single_mask,
use_stability_score=use_stability_score,
return_extra_metrics=return_extra_metrics,
)

if gelu_approximate:
for n, m in onnx_model.named_modules:
if isinstance(m, torch.nn.GELU):
m.approximate = "tanh"

dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size

mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

_ = onnx_model(**dummy_inputs)

output_names = ["masks", "iou_predictions", "low_res_masks"]

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
with open(weight_path, "wb") as f:
print(f"Exporting onnx model to {weight_path}...")
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)

if onnxruntime_exists:
ort_inputs = {k: _to_numpy(v) for k, v in dummy_inputs.items()}
# set cpu provider default
providers = ["CPUExecutionProvider"]
ort_session = onnxruntime.InferenceSession(weight_path, providers=providers)
_ = ort_session.run(None, ort_inputs)
print("Model has successfully been run with ONNXRuntime.")

config_output_path = os.path.join(output_folder, "config.pbtxt")
with open(config_output_path, "w") as f:
f.write(DECODER_CONFIG % name)


def export_bioengine_model(
model_type,
output_root,
opset: int,
export_name: Optional[str] = None,
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
return_single_mask: bool = True,
gelu_approximate: bool = False,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
"""Export SAM model to a format compatible with the BioEngine.

[The bioengine](https://github.com/bioimage-io/bioengine-model-runner) enables running the
image encoder on an online backend, so that SAM can be used in an online tool, or to predict
the image embeddings via the online backend rather than on CPU.

Args:
model_type: The SAM model type.
output_root: The output root directory where the exported model is saved.
opset: The ONNX opset version.
export_name: The name of the exported model.
checkpoint_path: Optional checkpoint for loading the SAM model.
return_single_mask: Whether the mask decoder returns a single or multiple masks.
gelu_approximate: Whether to use a GeLU approximation, in case the ONNX backend
does not have an efficient GeLU implementation.
use_stability_score: Whether to use the stability score instead of the predicted score.
return_extra_metrics: Whether to return a larger set of metrics.
"""
export_image_encoder(model_type, output_root, export_name, checkpoint_path)
export_onnx_model(
model_type=model_type,
output_root=output_root,
opset=opset,
export_name=export_name,
checkpoint_path=checkpoint_path,
return_single_mask=return_single_mask,
gelu_approximate=gelu_approximate,
use_stability_score=use_stability_score,
return_extra_metrics=return_extra_metrics,
)
Loading
Loading