Skip to content

Commit

Permalink
Update modelzoo export script
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Oct 12, 2023
1 parent de6b245 commit ebba719
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 41 deletions.
14 changes: 7 additions & 7 deletions micro_sam/modelzoo/bioengine_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,16 @@ def export_image_encoder(
export_name: Optional[str] = None,
checkpoint_path: Optional[str] = None,
) -> None:
"""Export the SAM image encoder to torchscript.
"""Export SAM image encoder to torchscript.
The torchscript image encoder can be used for predicting image embeddings
with a backend, e.g. with [the bioengine](https://github.com/bioimage-io/bioengine-model-runner).
Args:
model_type: The SAM model type.
output_root: The output root directory where the SAM model is saved.
output_root: The output root directory where the exported model is saved.
export_name: The name of the exported model.
checkpoint_path: Optional checkpoint for loading the SAM model.
checkpoint_path: Optional checkpoint for loading the exported model.
"""
if export_name is None:
export_name = model_type
Expand Down Expand Up @@ -113,15 +113,15 @@ def export_onnx_model(
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
"""Export the SAM prompt encoder and mask decoder to onnx.
"""Export SAM prompt encoder and mask decoder to onnx.
The onnx encoder and decoder can be used for interactive segmentation in the browser.
This code is adapted from
https://github.com/facebookresearch/segment-anything/blob/main/scripts/export_onnx_model.py
Args:
model_type: The SAM model type.
output_root: The output root directory where the SAM model is saved.
output_root: The output root directory where the exported model is saved.
opset: The ONNX opset version.
export_name: The name of the exported model.
checkpoint_path: Optional checkpoint for loading the SAM model.
Expand Down Expand Up @@ -218,15 +218,15 @@ def export_bioengine_model(
use_stability_score: bool = False,
return_extra_metrics: bool = False,
) -> None:
"""Export the SAM model to a format compatible with the BioEngine.
"""Export SAM model to a format compatible with the BioEngine.
[The bioengine](https://github.com/bioimage-io/bioengine-model-runner) enables running the
image encoder on an online backend, so that SAM can be used in an online tool, or to predict
the image embeddings via the online backend rather than on CPU.
Args:
model_type: The SAM model type.
output_root: The output root directory where the SAM model is saved.
output_root: The output root directory where the exported model is saved.
opset: The ONNX opset version.
export_name: The name of the exported model.
checkpoint_path: Optional checkpoint for loading the SAM model.
Expand Down
91 changes: 68 additions & 23 deletions micro_sam/modelzoo/bioimageio_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@ def _get_model(image, model_type, checkpoint_path):
return predictor, sam_model


def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path,
tmp_input_path, tmp_boxes_path, tmp_output_path):
def _create_test_inputs_and_outputs(
image,
labels,
model_type,
checkpoint_path,
input_path,
box_path,
mask_path,
score_path,
embed_path,
):

# For now we just generate a single box prompt here, but we could also generate more input prompts.
generator = PointAndBoxPromptGenerator(0, 0, 4, False, True)
Expand All @@ -31,70 +40,103 @@ def _create_test_inputs_and_outputs(image, labels, model_type, checkpoint_path,
_, _, box_prompts, _ = generator(masks, [bounding_boxes[1]])
box_prompts = box_prompts.numpy()

save_image_path = tmp_input_path.name
save_image_path = input_path.name
np.save(save_image_path, image[None, None])

_, sam_model = _get_model(image, model_type, checkpoint_path)
predictor = PredictorAdaptor(sam_model=sam_model)

save_box_prompt_path = tmp_boxes_path.name
save_box_prompt_path = box_path.name
np.save(save_box_prompt_path, box_prompts)

input_ = util._to_image(image).transpose(2, 0, 1)

# TODO embeddings are also expected output
instances = predictor(
masks, scores, embeddings = predictor(
input_image=torch.from_numpy(input_)[None],
image_embeddings=None,
box_prompts=torch.from_numpy(box_prompts)[None]
)

save_output_path = tmp_output_path.name
np.save(save_output_path, instances.numpy())
np.save(mask_path.name, masks.numpy())
np.save(score_path.name, scores.numpy())
np.save(embed_path.name, embeddings.numpy())

return [save_image_path, save_box_prompt_path], [save_output_path]
return [save_image_path, save_box_prompt_path], [mask_path.name, score_path.name, embed_path.name]


def _get_documentation(doc_path):
def _write_documentation(doc_path, doc):
with open(doc_path, "w") as f:
f.write("# Segment Anything for Microscopy\n")
f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
f.write("by training specialized models for microscopy data.\n")
if doc is None:
f.write("# Segment Anything for Microscopy\n")
f.write("We extend Segment Anything, a vision foundation model for image segmentation ")
f.write("by training specialized models for microscopy data.\n")
else:
f.write(doc)
return doc_path


# TODO enable over-riding the authors and citation and tags from kwargs
# TODO support RGB sample inputs
def export_bioimageio_model(
image: np.ndarray,
label_image: np.ndarray,
model_type: str,
model_name: str,
output_path: Union[str, os.PathLike],
doc_path: Optional[Union[str, os.PathLike]] = None,
doc: Optional[str] = None,
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
):
**kwargs
) -> None:
"""Export SAM model to BioImage.IO model format.
The exported model can be uploaded to [bioimage.io](https://bioimage.io/#/) and
be used in tools that support the BioImage.IO model format.
Args:
image: The image for generating test data.
label_image: The segmentation corresponding to `image`.
It is used to derive prompt inputs for the model.
model_type: The type of the SAM model.
model_name: The name of the exported model.
output_path: Where the exported model is saved.
doc: Documentation for the model.
checkpoint_path: Optional checkpoint for loading the SAM model.
kwargs: optional keyword arguments for the 'build_model' function
that converts to the modelzoo format.
"""
with (
tmp_file(suffix=".md") as tmp_doc_path,
tmp_file(suffix=".npy") as tmp_input_path,
tmp_file(suffix=".npy") as tmp_boxes_path,
tmp_file(suffix=".npy") as tmp_output_path
tmp_file(suffix=".npy") as tmp_mask_path,
tmp_file(suffix=".npy") as tmp_score_path,
tmp_file(suffix=".npy") as tmp_embed_path,
):
input_paths, result_paths = _create_test_inputs_and_outputs(
image, label_image, model_type, checkpoint_path, tmp_input_path, tmp_boxes_path, tmp_output_path
image, label_image, model_type, checkpoint_path,
input_path=tmp_input_path,
box_path=tmp_boxes_path,
mask_path=tmp_mask_path,
score_path=tmp_score_path,
embed_path=tmp_embed_path,
)
checkpoint = util._get_checkpoint(model_type, checkpoint_path=checkpoint_path)

architecture_path = os.path.join(os.path.split(__file__)[0], "predictor_adaptor.py")

if doc_path is None:
doc_path = tmp_doc_path.name
_get_documentation(doc_path)
doc_path = tmp_doc_path.name
_write_documentation(doc_path, doc)

build_model(
weight_uri=checkpoint, # type: ignore
test_inputs=input_paths,
test_outputs=result_paths,
input_axes=["bcyx"],
output_axes=["bcyx"],
input_axes=["bcyx", "bic"],
# FIXME this causes some error in build-model
# input_names=["image", "box-prompts"],
output_axes=["bcyx", "bic", "bcyx"],
# FIXME this causes some error in build-model
# output_names=["masks", "scores", "image_embeddings"],
name=model_name,
description="Finetuned Segment Anything models for Microscopy",
authors=[{"name": "Anwai Archit", "affiliation": "Uni Goettingen"},
Expand All @@ -105,5 +147,8 @@ def export_bioimageio_model(
cite=[{"text": "Archit, ..., Pape et al. Segment Anything for Microscopy",
"doi": "10.1101/2023.08.21.554208"}],
output_path=output_path, # type: ignore
architecture=architecture_path
architecture=architecture_path,
**kwargs,
)

# TODO actually test the model
28 changes: 17 additions & 11 deletions micro_sam/modelzoo/predictor_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,34 @@ def __call__(
- image_embeddings: precomputed image embeddings
- box_prompts: box prompts of dimensions C x 4
"""
if self.is_image_set and image_embeddings is None: # we have embeddings set and not passed
# We have image embeddings set and image embeddings were not passed.
if self.is_image_set and image_embeddings is None:
pass # do nothing

# We have image embeddings set and image embeddings were passed.
elif self.is_image_set and image_embeddings is not None:
raise NotImplementedError # TODO: replace the image embeddings
self.features = image_embeddings

# We don't have image embeddings set and image embeddings were passed.
elif image_embeddings is not None:
pass # TODO set the image embeddings
# self.features = image_embeddings
self.features = image_embeddings

# We don't have image embeddings set and they were not passed.
elif not self.is_image_set:
image = self.transform.apply_image_torch(input_image)
self.set_torch_image(image, original_image_size=input_image.numpy().shape[2:]) # compute the image embeddings
self.set_torch_image(image, original_image_size=input_image.numpy().shape[2:])

boxes = self.transform.apply_boxes_torch(box_prompts, original_size=input_image.numpy().shape[2:]) # type: ignore
boxes = self.transform.apply_boxes_torch(box_prompts, original_size=input_image.numpy().shape[2:])

instance_segmentation, _, _ = self.predict_torch(
masks, scores, _ = self.predict_torch(
point_coords=None,
point_labels=None,
boxes=boxes,
multimask_output=False
)

assert instance_segmentation.shape[2:] == input_image.shape[2:], f"{instance_segmentation.shape[2:]} is not as expected ({input_image.shape[2:]})"
assert masks.shape[2:] == input_image.shape[2:],\
f"{masks.shape[2:]} is not as expected ({input_image.shape[2:]})"

# TODO get the image embeddings via image_embeddings = self.features
# and return them
return instance_segmentation
image_embeddings = self.features
return masks, scores, image_embeddings

0 comments on commit ebba719

Please sign in to comment.