Skip to content

Commit

Permalink
Full bioimageio integration (#569)
Browse files Browse the repository at this point in the history
Update model zoo integration
  • Loading branch information
constantinpape authored May 3, 2024
1 parent 136400d commit dbe7e9c
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 98 deletions.
10 changes: 5 additions & 5 deletions doc/finetuned_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ We currently offer the following models:
- `vit_l`: Default Segment Anything model with vit-l backbone.
- `vit_b`: Default Segment Anything model with vit-b backbone.
- `vit_t`: Segment Anything model with vit-tiny backbone. From the [Mobile SAM publication](https://arxiv.org/abs/2306.14289).
- `vit_l_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-l backbone. ([zenodo](TODO), [bioimage.io](TODO))
- `vit_l_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-l backbone. ([zenodo](TODO), [idealistic-rat on bioimage.io](TODO))
- `vit_b_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-b backbone. ([zenodo](https://zenodo.org/doi/10.5281/zenodo.11103797), [diplomatic-bug on bioimage.io](TODO))
- `vit_t_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-t backbone. ([zenodo](TODO), [bioimage.io](TODO))
- `vit_l_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-l backbone. ([zenodo](TODO), [bioimage.io](TODO))
- `vit_b_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-b backbone. ([zenodo](TODO), [bioimage.io](TODO))
- `vit_t_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-t backbone. ([zenodo](TODO), [bioimage.io](TODO))
- `vit_t_lm`: Finetuned Segment Anything model for cells and nuclei in light microscopy data with vit-t backbone. ([zenodo](TODO), [faithful-chicken on bioimage.io](TODO))
- `vit_l_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-l backbone. ([zenodo](TODO), [humorous-crab on bioimage.io](TODO))
- `vit_b_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-b backbone. ([zenodo](TODO), [noisy-ox on bioimage.io](TODO))
- `vit_t_em_organelles`: Finetuned Segment Anything model for mitochondria and nuclei in electron microscopy data with vit-t backbone. ([zenodo](TODO), [greedy-whale on bioimage.io](https://doi.org/10.5281/zenodo.11110950))

See the two figures below of the improvements through the finetuned model for LM and EM data.

Expand Down
28 changes: 13 additions & 15 deletions micro_sam/bioimageio/model_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,8 @@ def _check_model(model_description, input_paths, result_paths):
).as_single_block()
prediction = pp.predict_sample_block(sample)

assert len(prediction) == 3
predicted_mask = prediction[0]
predicted_mask = prediction.blocks["masks"].data.data
assert predicted_mask.shape == mask.shape
assert np.allclose(mask, predicted_mask)

# Run the checks with partial prompts.
Expand All @@ -262,8 +262,7 @@ def _check_model(model_description, input_paths, result_paths):
model=model_description, image=image, embeddings=embeddings, **kwargs
).as_single_block()
prediction = pp.predict_sample_block(sample)
assert len(prediction) == 3
predicted_mask = prediction[0]
predicted_mask = prediction.blocks["masks"].data.data
assert predicted_mask.shape == mask.shape


Expand Down Expand Up @@ -300,7 +299,7 @@ def export_sam_model(
spec.InputTensorDescr(
id=spec.TensorId("image"),
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
# NOTE: to support 1 and 3 channels we can add another preprocessing.
# Best solution: Have a pre-processing for this! (1C -> RGB)
spec.ChannelAxis(channel_names=[spec.Identifier(cname) for cname in "RGB"]),
Expand All @@ -316,7 +315,7 @@ def export_sam_model(
id=spec.TensorId("box_prompts"),
optional=True,
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
Expand All @@ -332,7 +331,7 @@ def export_sam_model(
id=spec.TensorId("point_prompts"),
optional=True,
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
Expand All @@ -352,7 +351,7 @@ def export_sam_model(
id=spec.TensorId("point_labels"),
optional=True,
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
Expand All @@ -371,7 +370,7 @@ def export_sam_model(
id=spec.TensorId("mask_prompts"),
optional=True,
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
spec.IndexInputAxis(
id=spec.AxisId("object"),
size=spec.ARBITRARY_SIZE
Expand All @@ -389,7 +388,7 @@ def export_sam_model(
id=spec.TensorId("embeddings"),
optional=True,
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
# NOTE: we currently have to specify all the channel names
# (It would be nice to also support size)
spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
Expand All @@ -407,7 +406,7 @@ def export_sam_model(
spec.OutputTensorDescr(
id=spec.TensorId("masks"),
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
# NOTE: we use the data dependent size here to avoid dependency on optional inputs
spec.IndexOutputAxis(
id=spec.AxisId("object"), size=spec.DataDependentSize(),
Expand Down Expand Up @@ -435,7 +434,7 @@ def export_sam_model(
spec.OutputTensorDescr(
id=spec.TensorId("scores"),
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
# NOTE: we use the data dependent size here to avoid dependency on optional inputs
spec.IndexOutputAxis(
id=spec.AxisId("object"), size=spec.DataDependentSize(),
Expand All @@ -451,7 +450,7 @@ def export_sam_model(
spec.OutputTensorDescr(
id=spec.TensorId("embeddings"),
axes=[
spec.BatchAxis(),
spec.BatchAxis(size=1),
spec.ChannelAxis(channel_names=[spec.Identifier(f"c{i}") for i in range(256)]),
spec.SpaceOutputAxis(id=spec.AxisId("y"), size=64),
spec.SpaceOutputAxis(id=spec.AxisId("x"), size=64),
Expand Down Expand Up @@ -520,7 +519,6 @@ def export_sam_model(
# config=
)

# TODO this requires the new bioimageio.core release
# _check_model(model_description, input_paths, result_paths)
_check_model(model_description, input_paths, result_paths)

save_bioimageio_package(model_description, output_path=output_path)
2 changes: 1 addition & 1 deletion scripts/model_export/.gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
adjectives.txt
animals.yaml
collection.json
do_upload.sh
do_workflow.sh
8 changes: 8 additions & 0 deletions scripts/model_export/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# BioImage.IO Scripts

Scripts for managing the micro-sam models in the BioImage.IO collection: https://github.com/bioimage-io/collection

- export_models.py: to create the modelzoo format.
- stage_models.py: to call the stage workflow for uploading new versions.
- publish_models.py: to call the publish workflow for already staged versions.
- get_dois.py: to get the zenodo DOIs for already published versions.
91 changes: 14 additions & 77 deletions scripts/model_export/export_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,19 @@

import argparse
import os
import json
import warnings
from glob import glob

import bioimageio.spec.model.v0_5 as spec
import h5py
import imageio.v3 as imageio
import numpy as np
import requests
import xxhash
import yaml

from micro_sam.bioimageio import export_sam_model
from skimage.measure import label

from models import get_id_and_emoji, MODEL_TO_NAME

BUF_SIZE = 65536 # lets read stuff in 64kb chunks!

INPUT_FOLDER = "./v2"
Expand All @@ -39,62 +37,6 @@ def create_doc(model_type, modality, version):
return doc


def download_file(url, filename):
    """Download the text content at *url* and store it in *filename*.

    The download is skipped if *filename* already exists, so cached copies
    are reused across runs. On a non-200 response the failure is reported
    on stdout, but no exception is raised (best-effort behavior).

    Args:
        url: The URL to download from.
        filename: The local path to write the downloaded text to.
    """
    if os.path.exists(filename):
        return  # Already downloaded; reuse the cached file.

    # Send HTTP GET request to the URL.
    response = requests.get(url)

    # Check if the request was successful.
    if response.status_code == 200:
        # Open a local file in write-text mode and store the response body.
        with open(filename, "w", encoding=response.encoding or "utf-8") as file:
            file.write(response.text)  # Using .text instead of .content
        # BUG FIX: the message previously printed a literal placeholder
        # instead of interpolating the actual file name.
        print(f"File '{filename}' has been downloaded successfully.")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")


def get_id_and_emoji():
    """Draw a random, unused '<adjective>-<animal>' model id and its emoji.

    Downloads the bioimage.io adjective and animal word lists as well as the
    current collection, then samples ids until one is found that is not yet
    present in the collection.

    Returns:
        Tuple of the new model id (str) and the emoji associated with the
        sampled animal.
    """
    # NOTE: fixed misspelled local name ('addjective_url').
    adjective_url = "https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/main/adjectives.txt"
    animal_url = "https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/main/animals.yaml"
    collection_url = "https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/gh-pages/collection.json"

    adjective_file = "adjectives.txt"
    download_file(adjective_url, adjective_file)
    with open(adjective_file) as f:
        adjectives = [adj.rstrip("\n") for adj in f.readlines()]

    animal_file = "animals.yaml"
    download_file(animal_url, animal_file)
    with open(animal_file) as f:
        animal_dict = yaml.safe_load(f)
    animal_names = list(animal_dict.keys())

    collection_file = "collection.json"
    download_file(collection_url, collection_file)
    with open(collection_file) as f:
        collection = json.load(f)["collection"]

    # Ids ('nicknames') that are already taken in the collection.
    # Use a set for O(1) membership tests in the sampling loop below.
    existing_ids = {
        entry["nickname"] for entry in collection if entry.get("nickname") is not None
    }

    adj, name = np.random.choice(adjectives), np.random.choice(animal_names)
    model_id = f"{adj}-{name}"
    while model_id in existing_ids:
        adj, name = np.random.choice(adjectives), np.random.choice(animal_names)
        model_id = f"{adj}-{name}"

    return model_id, animal_dict[name]


def get_data(modality):
if modality == "lm":
image_path = os.path.join(
Expand All @@ -114,12 +56,6 @@ def get_data(modality):
label_image = f["labels"][0]
label_image = label(label_image == 1)

# import napari
# v = napari.Viewer()
# v.add_image(image)
# v.add_labels(label_image)
# napari.run()

assert image.shape == label_image.shape
return image, label_image

Expand All @@ -146,19 +82,20 @@ def export_model(model_path, model_type, modality, version, email):
output_folder = os.path.join(OUTPUT_FOLDER, modality)
os.makedirs(output_folder, exist_ok=True)

export_name = f"{model_type}_{modality}"
output_path = os.path.join(output_folder, export_name)
model_name = f"{model_type}_{modality}"
output_path = os.path.join(output_folder, model_name)
if os.path.exists(output_path):
print("The model", export_name, "has already been exported.")
print("The model", model_name, "has already been exported.")
return

image, label_image = get_data(modality)
covers = get_covers(modality)
doc = create_doc(model_type, modality, version)

model_id, emoji = get_id_and_emoji()
model_id, emoji = get_id_and_emoji(model_name)
uploader = spec.Uploader(email=email)

export_name = MODEL_TO_NAME[model_name]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
export_sam_model(
Expand All @@ -178,22 +115,22 @@ def export_model(model_path, model_type, modality, version, email):
encoder_path = os.path.join(output_path + ".unzip", f"{model_type}.pt")
encoder_checksum = compute_checksum(encoder_path)
print("Encoder:")
print(export_name, f"xxh128:{encoder_checksum}")
print(model_name, f"xxh128:{encoder_checksum}")

decoder_path = os.path.join(output_path + ".unzip", f"{model_type}_decoder.pt")
decoder_checksum = compute_checksum(decoder_path)
print("Decoder:")
print(f"{export_name}_decoder", f"xxh128:{decoder_checksum}")
print(f"{model_name}_decoder", f"xxh128:{decoder_checksum}")


def export_all_models(email):
models = glob(os.path.join("./v2/**/vit*"))
def export_all_models(email, version):
    """Export all finetuned models of the given version to the modelzoo format.

    Expects checkpoints at './v<version>/<modality>/<model_type>/best.pt'.

    Args:
        email: The uploader email, passed on to the individual model exports.
        version: The model version to export (selects the './v<version>' folder).
    """
    models = glob(os.path.join(f"./v{version}/**/vit*"))
    for path in models:
        # The folder layout encodes the modality and the model type.
        modality, model_type = path.split("/")[-2:]
        model_path = os.path.join(path, "best.pt")
        assert os.path.exists(model_path), model_path
        export_model(model_path, model_type, modality, version=version, email=email)


# For testing.
Expand All @@ -206,10 +143,10 @@ def export_vit_t_lm(email):
def main():
    """CLI entry point: export all models for a given version.

    Requires the uploader email (-e/--email); the model version
    (-v/--version) defaults to 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--email", required=True)
    parser.add_argument("-v", "--version", default=2, type=int)
    args = parser.parse_args()

    export_all_models(args.email, args.version)


if __name__ == "__main__":
Expand Down
28 changes: 28 additions & 0 deletions scripts/model_export/get_dois.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

import requests
from models import MODEL_TO_ID

URL_BASE = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/%s/versions.json"


def get_zenodo_url(model_name, model_id):
    """Print the zenodo DOI for the published version '1' of a model.

    Fetches the versions.json of the model from the bioimage.io S3 bucket
    and prints the model name, id, and the DOI of its first published
    version. On a failed request an error is printed and nothing else
    happens.

    Args:
        model_name: The micro_sam model name (used only for printing).
        model_id: The bioimage.io id of the model, e.g. 'greedy-whale'.
    """
    url = URL_BASE % model_id
    response = requests.get(url)

    if response.status_code != 200:
        # BUG FIX: previously execution continued after a failed request
        # and accessed the undefined variable 'data', raising a NameError.
        print("Failed to retrieve data:", response.status_code)
        return

    data = response.json()
    doi = data["published"]["1"]["doi"]
    doi = f"https://doi.org/{doi}"

    print(model_name, ":")
    print(model_id)
    print(doi)
    print()


for model_name, model_id in MODEL_TO_ID.items():
get_zenodo_url(model_name, model_id)
Loading

0 comments on commit dbe7e9c

Please sign in to comment.