Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bioengine integration #513

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/bioimageio/export_model_for_bioengine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from micro_sam.bioimageio.bioengine_export import export_bioengine_model

export_bioengine_model("vit_t", "test-export", opset=12)
141 changes: 141 additions & 0 deletions examples/bioimageio/hypha_data_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import json
import uuid
import mimetypes
import os
from urllib.parse import parse_qs

class HyphaDataStore:
def __init__(self):
self.storage = {}
self._svc = None
self._server = None

async def setup(self, server, service_id="data-store", visibility="public"):
self._server = server
self._svc = await server.register_service({
"id": service_id,
"type": "functions",
"config": {
"visibility": visibility,
"require_context": False
},
"get": self.http_get,
}, overwrite=True)

def get_url(self, obj_id: str):
assert self._svc, "Service not initialized, call `setup()`"
assert obj_id in self.storage, "Object not found " + obj_id
return f"{self._server.config.public_base_url}/{self._server.config.workspace}/apps/{self._svc.id.split(':')[1]}/get?id={obj_id}"

def put(self, obj_type: str, value: any, name: str, comment: str = ""):
assert self._svc, "Please call `setup()` before using the store"
obj_id = str(uuid.uuid4())
if obj_type == 'file':
data = value
assert isinstance(data, (str, bytes)), "Value must be a string or bytes"
if isinstance(data, str) and data.startswith("file://"):
# File URL examples:
# Absolute URL: `file:///home/data/myfile.png`
# Relative URL: `file://./myimage.png`, or `file://myimage.png`
with open(data.replace("file://", ""), 'rb') as fil:
data = fil.read()
mime_type, _ = mimetypes.guess_type(name)
self.storage[obj_id] = {
'type': obj_type,
'name': name,
'value': data,
'mime_type': mime_type or 'application/octet-stream',
'comment': comment
}
else:
self.storage[obj_id] = {
'type': obj_type,
'name': name,
'value': value,
'mime_type': 'application/json',
'comment': comment
}
return obj_id

def get(self, id: str):
assert self._svc, "Please call `setup()` before using the store"
obj = self.storage.get(id)
return obj

def http_get(self, scope, context=None):
query_string = scope['query_string']
id = parse_qs(query_string).get('id', [])[0]
obj = self.storage.get(id)
if obj is None:
return {'status': 404, 'headers': {}, 'body': "Not found: " + id}

if obj['type'] == 'file':
data = obj['value']
if isinstance(data, str):
if not os.path.isfile(data):
return {
"status": 404,
'headers': {'Content-Type': 'text/plain'},
"body": "File not found: " + data
}
with open(data, 'rb') as fil:
data = fil.read()
headers = {
'Content-Type': obj['mime_type'],
'Content-Length': str(len(obj['value'])),
'Content-Disposition': f'inline; filename="{obj["name"].split("/")[-1]}"'
}

return {
'status': 200,
'headers': headers,
'body': obj['value']
}
else:
return {
'status': 200,
'headers': {'Content-Type': 'application/json'},
'body': json.dumps(obj['value'])
}

def http_list(self, scope, context=None):
query_string = scope.get('query_string', b'')
kws = parse_qs(query_string).get('keyword', [])
keyword = kws[0] if kws else None
result = [value for key, value in self.storage.items() if not keyword or keyword in value['name']]
return {'status': 200, 'headers': {'Content-Type': 'application/json'}, 'body': json.dumps(result)}

def remove(self, obj_id: str):
assert self._svc, "Please call `setup()` before using the store"
if obj_id in self.storage:
del self.storage[obj_id]
return True
raise IndexError("Not found: " + obj_id)

async def test_data_store(server_url="https://ai.imjoy.io"):
from imjoy_rpc.hypha import connect_to_server, login
token = await login({"server_url": server_url})
server = await connect_to_server({"server_url": server_url, "token": token})

ds = HyphaDataStore()
# Setup would need to be completed in an ASGI compatible environment
await ds.setup(server)

# Test PUT operation
file_id = ds.put('file', 'file:///home/data.txt', 'data.txt')
binary_id = ds.put('file', b'Some binary content', 'example.bin')
json_id = ds.put('json', {'hello': 'world'}, 'example.json')

# Test GET operation
assert ds.get(file_id)['type'] == 'file'
assert ds.get(binary_id)['type'] == 'file'
assert ds.get(json_id)['type'] == 'json'

# Test GET URL generation
print("URL for getting file", ds.get_url(file_id))
print("URL for getting binary object", ds.get_url(binary_id))
print("URL for getting json object", ds.get_url(json_id))

if __name__ == "__main__":
import asyncio
asyncio.run(test_data_store())
44 changes: 44 additions & 0 deletions examples/bioimageio/run_imjoy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import numpy as np
from imjoy_rpc.hypha import connect_to_server
import time

image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
"float32"
)

# SERVER_URL = 'http://127.0.0.1:9520' # "https://ai.imjoy.io"
# SERVER_URL = "https://hypha.bioimage.io"
# SERVER_URL = "https://ai.imjoy.io"
SERVER_URL = "https://hypha.bioimage.io"


async def test_backbone(triton):
config = await triton.get_config(model_name="micro-sam-vit-b-backbone")
print(config)

image = np.random.randint(0, 255, size=(1, 3, 1024, 1024), dtype=np.uint8).astype(
"float32"
)

start_time = time.time()
result = await triton.execute(
inputs=[image],
model_name="micro-sam-vit-b-backbone",
)
print("Backbone", result)
embedding = result['output0__0']
print("Time taken: ", time.time() - start_time)
print("Test passed", embedding.shape)


async def run():
server = await connect_to_server(
{"name": "test client", "server_url": SERVER_URL, "method_timeout": 100}
)
triton = await server.get_service("triton-client")
await test_backbone(triton)


if __name__ == "__main__":
import asyncio
asyncio.run(run())
185 changes: 185 additions & 0 deletions examples/bioimageio/sam_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import os
import warnings
from functools import partial

# import urllib
import imageio.v3 as imageio
import numpy as np
import requests
import torch

from hypha_data_store import HyphaDataStore
from segment_anything import sam_model_registry, SamPredictor
from segment_anything.utils.onnx import SamOnnxModel

image_url = "https://owncloud.gwdg.de/index.php/s/fSaOJIOYjmFBjPM/download"


def get_sam_model(model_name):
models = {
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
"vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/staged/1/files/vit_b.pt",
# TODO
"vit_b_em_organelles": "",
}
model_url = models[model_name]
checkpoint_path = f"{model_name}.pt"

if not os.path.exists(checkpoint_path):
response = requests.get(model_url)
if response.status_code == 200:
with open(checkpoint_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = model_name[:5]
sam = sam_model_registry[model_type]()
ckpt = torch.load(checkpoint_path, map_location=device)
sam.load_state_dict(ckpt)
return sam


def export_onnx_model(
sam,
output_path,
opset: int,
return_single_mask: bool = True,
gelu_approximate: bool = False,
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:

onnx_model = SamOnnxModel(
model=sam,
return_single_mask=return_single_mask,
use_stability_score=use_stability_score,
return_extra_metrics=return_extra_metrics,
)

if gelu_approximate:
for n, m in onnx_model.named_modules:
if isinstance(m, torch.nn.GELU):
m.approximate = "tanh"

dynamic_axes = {
"point_coords": {1: "num_points"},
"point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size

mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
"image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
"point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
"point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([1], dtype=torch.float),
"orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

_ = onnx_model(**dummy_inputs)

output_names = ["masks", "iou_predictions", "low_res_masks"]

with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)
with open(output_path, "wb") as f:
print(f"Exporting onnx model to {output_path}...")
torch.onnx.export(
onnx_model,
tuple(dummy_inputs.values()),
f,
export_params=True,
verbose=False,
opset_version=opset,
do_constant_folding=True,
input_names=list(dummy_inputs.keys()),
output_names=output_names,
dynamic_axes=dynamic_axes,
)


def get_example_image():
image = imageio.imread(image_url)
return np.asarray(image)


def _to_image(input_):
# we require the input to be uint8
if input_.dtype != np.dtype("uint8"):
# first normalize the input to [0, 1]
input_ = input_.astype("float32") - input_.min()
input_ = input_ / input_.max()
# then bring to [0, 255] and cast to uint8
input_ = (input_ * 255).astype("uint8")
if input_.ndim == 2:
image = np.concatenate([input_[..., None]] * 3, axis=-1)
elif input_.ndim == 3 and input_.shape[-1] == 3:
image = input_
else:
raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
return image


def compute_embeddings(model_name="vit_b"):
sam = get_sam_model(model_name)
predictor = SamPredictor(sam)
image = get_example_image()
predictor.reset_image()
predictor.set_image(_to_image(image))
image_embeddings = predictor.get_image_embedding().cpu().numpy()
return image_embeddings


async def get_onnx(ds, model_name="vit_b", opset_version=12):
output_path = f"{model_name}.onnx"
if not os.path.exists(output_path):
sam = get_sam_model(model_name)
export_onnx_model(sam, output_path, opset=opset_version)

file_id = ds.put("file", f"file://{output_path}", output_path)
url = ds.get_url(file_id)
return url


async def start_server():
from imjoy_rpc.hypha import connect_to_server, login

server_url = "https://ai.imjoy.io"

token = await login({"server_url": server_url})
server = await connect_to_server({"server_url": server_url, "token": token})

# Upload to hypha.
ds = HyphaDataStore()
await ds.setup(server)

svc = await server.register_service({
"name": "Sam Server",
"id": "bioimageio-colab",
"config": {
"visibility": "public"
},
"get_onnx": partial(get_onnx, ds=ds),
"compute_embeddings": compute_embeddings,
"get_example_image": get_example_image,
"ping": lambda: "pong"
})
sid = svc['id']
# config_str = f'{{"service_id": "{sid}", "server_url": "{server_url}"}}'
# encoded_config = urllib.parse.quote(config_str, safe='/', encoding=None, errors=None)
# annotator_url = 'https://imjoy.io/lite?plugin=https://raw.githubusercontent.com/bioimage-io/bioimageio-colab/main/plugins/bioimageio-colab.imjoy.html&config=' + encoded_config
print(sid)


if __name__ == "__main__":
import asyncio

loop = asyncio.get_event_loop()
loop.create_task(start_server())

loop.run_forever()
Loading
Loading