diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 574cffbb716b3..7c8aae1bce104 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -476,6 +476,9 @@ file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPEND file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py" ) +file(GLOB onnxruntime_python_transformers_models_sam2_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/sam2/*.py" +) file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" ) @@ -547,6 +550,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/phi2 + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/sam2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper @@ -656,6 +660,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_phi2_src} $/onnxruntime/transformers/models/phi2/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_sam2_src} + $/onnxruntime/transformers/models/sam2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_stable_diffusion_src} $/onnxruntime/transformers/models/stable_diffusion/ diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 2375104ac96f5..0fa038d5cfc62 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,13 +1,16 @@ import copy import logging from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Tuple, Union import numpy import torch from onnxruntime import InferenceSession, RunOptions +# Type alias +ShapeDict = Mapping[str, Union[Tuple, List[int]]] + logger = logging.getLogger(__name__) @@ -262,7 +265,7 @@ def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor): ) self.output_tensors[self.buffer_sharing[name]] = tensor - def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): + def allocate_buffers(self, shape_dict: ShapeDict): """Allocate tensors for I/O Binding""" if self.enable_cuda_graph: for name, shape in shape_dict.items(): @@ -346,7 +349,7 @@ def __init__( self, ort_session: InferenceSession, device: torch.device, - shape_dict: Dict[str, Union[Tuple[int], List[int]]], + shape_dict: ShapeDict, enable_gpu_graph: bool = False, gpu_graph_id: int = -1, stream: int = 0, @@ -406,7 +409,7 @@ def __init__(self, ort_session: InferenceSession, device: torch.device, stream: def get_binding( self, - shape_dict: Dict[str, Union[Tuple[int], List[int]]], + shape_dict: ShapeDict, use_cuda_graph: bool = False, buffer_sharing: Optional[Dict[str, str]] = None, ) -> GpuBinding: diff --git a/onnxruntime/python/tools/transformers/models/sam2/README.md 
b/onnxruntime/python/tools/transformers/models/sam2/README.md new file mode 100644 index 0000000000000..6ae2b35ba248c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/README.md @@ -0,0 +1,65 @@ +# SAM2 ONNX Model Export + +## Setup Environment +It is recommended to set up a machine with Python 3.10, 3.11 or 3.12. Then install [PyTorch 2.4.1](https://pytorch.org/) and ONNX Runtime 1.19.2. + +### CPU Only +To install the CPU-only version of PyTorch and ONNX Runtime for exporting and running ONNX models, use the following commands: +``` +python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu +python3 -m pip install onnxruntime onnx opencv-python matplotlib +``` + +### GPU +If your machine has an NVIDIA GPU, you can install the CUDA version of PyTorch and ONNX Runtime for exporting and running ONNX models: + +``` +python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 +python3 -m pip install onnxruntime-gpu onnx opencv-python matplotlib +``` + +onnxruntime-gpu requires CUDA 12.x, cuDNN 9.x, and other dependencies (such as MSVC Runtime on Windows). For more information, see the [installation guide](https://onnxruntime.ai/docs/install/#python-installs). + +## Download Checkpoints + +Clone the SAM2 git repository and download the checkpoints: +```bash +git clone https://github.com/facebookresearch/segment-anything-2.git +cd segment-anything-2 +python3 -m pip install -e . +cd checkpoints +sh ./download_ckpts.sh +``` + +On Windows, you can replace `sh ./download_ckpts.sh` with the following commands: +```bash +curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt > sam2_hiera_tiny.pt +curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt > sam2_hiera_small.pt +curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt > sam2_hiera_base_plus.pt +curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt > sam2_hiera_large.pt +``` + +## Export ONNX +To export ONNX models, run the convert_to_onnx.py script and specify the segment-anything-2 directory created by the above git clone command: +```bash +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 +``` + +The exported ONNX models can be found in the sam2_onnx_models sub-directory. You can change the output directory using the `--output_dir` option. + +If you want the model to output multiple masks, append the `--multimask_output` option. + +To see all parameters, run the following command: +```bash +python3 convert_to_onnx.py -h +``` + +## Run Demo +The exported ONNX models can run on a CPU. The demo will output sam2_demo.png. +```bash +curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg +python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo +``` + +## Limitations +- The exported image_decoder model does not support batch mode for now. diff --git a/onnxruntime/python/tools/transformers/models/sam2/__init__.py b/onnxruntime/python/tools/transformers/models/sam2/__init__.py new file mode 100644 index 0000000000000..815be385d7dd4 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License.
+# -------------------------------------------------------------------------- +import os.path +import sys + +sys.path.append(os.path.dirname(__file__)) + +transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) +if transformers_dir not in sys.path: + sys.path.append(transformers_dir) diff --git a/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py new file mode 100644 index 0000000000000..9b629f5c40802 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py @@ -0,0 +1,195 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import argparse +import os +import pathlib +import sys + +import torch +from image_decoder import export_decoder_onnx, test_decoder_onnx +from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx +from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx +from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx +from sam2_demo import run_demo, show_all_images +from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX") + + parser.add_argument( + "--model_type", + required=False, + type=str, + choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"], + default="sam2_hiera_large", + help="The model type to export", + ) + + parser.add_argument( + "--components", + required=False, + nargs="+", + choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"], + default=["image_encoder", "image_decoder"], + help="Type of ONNX models to export. " + "Note that image_decoder is a combination of prompt_encoder and mask_decoder", + ) + + parser.add_argument( + "--output_dir", + type=str, + help="The output directory for the ONNX models", + default="sam2_onnx_models", + ) + + parser.add_argument( + "--dynamic_batch_axes", + required=False, + default=False, + action="store_true", + help="Export image_encoder with dynamic batch axes", + ) + + parser.add_argument( + "--multimask_output", + required=False, + default=False, + action="store_true", + help="Export mask_decoder or image_decoder with multimask_output", + ) + + parser.add_argument( + "--disable_dynamic_multimask_via_stability", + required=False, + action="store_true", + help="Disable mask_decoder dynamic_multimask_via_stability, and output the first mask only. "
+ "This option will be ignored when multimask_output is True", + ) + + parser.add_argument( + "--sam2_dir", + required=False, + type=str, + default="./segment-anything-2", + help="The directory of segment-anything-2 git repository", + ) + + parser.add_argument( + "--overwrite", + required=False, + default=False, + action="store_true", + help="Overwrite onnx model file if exists.", + ) + + parser.add_argument( + "--demo", + required=False, + default=False, + action="store_true", + help="Run demo with the exported ONNX models.", + ) + + parser.add_argument( + "--verbose", + required=False, + default=False, + action="store_true", + help="Print verbose information", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + + checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints") + sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs") + if not os.path.exists(args.sam2_dir): + raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(checkpoints_dir): + raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(sam2_config_dir): + raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.") + + if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")): + raise FileNotFoundError( + f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory." + ) + + if args.sam2_dir not in sys.path: + sys.path.append(args.sam2_dir) + + pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu") + + for component in args.components: + if component == "image_encoder": + onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) + if args.overwrite or not os.path.exists(onnx_model_path): + export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) + test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False) + + elif component == "mask_decoder": + onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx") + if args.overwrite or not os.path.exists(onnx_model_path): + export_mask_decoder_onnx( + sam2_model, + onnx_model_path, + args.multimask_output, + not args.disable_dynamic_multimask_via_stability, + args.verbose, + ) + test_mask_decoder_onnx( + sam2_model, + onnx_model_path, + args.multimask_output, + not args.disable_dynamic_multimask_via_stability, + ) + elif component == "prompt_encoder": + onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx") + if args.overwrite or not os.path.exists(onnx_model_path): + export_prompt_encoder_onnx(sam2_model, onnx_model_path) + test_prompt_encoder_onnx(sam2_model, onnx_model_path) + elif component == "image_decoder": + onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output) + if args.overwrite or not os.path.exists(onnx_model_path): + export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) + test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output) + + if args.demo: + # Export required ONNX models for demo if not already exported. 
+ onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type) + if not os.path.exists(onnx_model_path): + export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose) + + onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True) + if not os.path.exists(onnx_model_path): + export_decoder_onnx(sam2_model, onnx_model_path, True) + + onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False) + if not os.path.exists(onnx_model_path): + export_decoder_onnx(sam2_model, onnx_model_path, False) + + ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir) + print("demo output files for ONNX Runtime:", ort_image_files) + + # Get results from torch engine to compare. + torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir) + print("demo output files for PyTorch:", torch_image_files) + + show_all_images(ort_image_files, torch_image_files) + + +if __name__ == "__main__": + setup_logger(verbose=False) + with torch.no_grad(): + main() diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py new file mode 100644 index 0000000000000..0f7a1099461bc --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/image_decoder.py @@ -0,0 +1,249 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +import warnings + +import torch +import torch.nn.functional as F +from image_encoder import SAM2ImageEncoder, random_sam2_input_image +from mask_decoder import SAM2MaskDecoder +from prompt_encoder import SAM2PromptEncoder +from sam2.modeling.sam2_base import SAM2Base +from sam2_utils import compare_tensors_with_tolerance +from torch import nn + +logger = logging.getLogger(__name__) + + +class SAM2ImageDecoder(nn.Module): + def __init__( + self, + sam_model: SAM2Base, + multimask_output: bool, + dynamic_multimask_via_stability: bool = True, + return_logits: bool = False, + mask_threshold: float = 0.0, + ) -> None: + super().__init__() + self.prompt_encoder = SAM2PromptEncoder(sam_model) + self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability) + self.return_logits = return_logits + self.mask_threshold = mask_threshold + + @torch.no_grad() + def forward( + self, + image_features_0: torch.Tensor, + image_features_1: torch.Tensor, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + input_masks: torch.Tensor, + has_input_masks: torch.Tensor, + original_image_size: torch.Tensor, + ): + """ + Decode masks from image features and prompts. Batched images are not supported. H=W=1024. + + Args: + image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder. + image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder. + image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder. + point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel + coordinate in (x, y) format of the P input points in image of size 1024x1024.
+ point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means + positive (foreground), 0 means negative (background), -1 means padding, + 2 (box left upper corner), 3 (box right bottom corner). + input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model. + Typically coming from a previous iteration. + has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise. + original_image_size (torch.Tensor): [2]. original image size H_o, W_o. + + Returns: + masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size. + iou_predictions (torch.Tensor): [1, M]. scores for M masks. + low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks. + """ + sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder( + point_coords, point_labels, input_masks, has_input_masks + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings + ) + + # Interpolate the low resolution masks back to the original image size. + masks = F.interpolate( + low_res_masks, + (original_image_size[0], original_image_size[1]), + mode="bilinear", + align_corners=False, # Note that align_corners=True yields fewer mismatches when comparing ORT and PyTorch. + ) + + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not self.return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + +def export_decoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, + multimask_output: bool = False, + verbose: bool = False, +): + batch_size = 1 + image = random_sam2_input_image(batch_size) + sam2_encoder = SAM2ImageEncoder(sam2_model).cpu() + image_features_0, image_features_1, image_embeddings = sam2_encoder(image) + + logger.info("image_features_0.shape: %s", image_features_0.shape) + logger.info("image_features_1.shape: %s", image_features_1.shape) + logger.info("image_embeddings.shape: %s", image_embeddings.shape) + + sam2_decoder = SAM2ImageDecoder( + sam2_model, + multimask_output=multimask_output, + dynamic_multimask_via_stability=True, + ).cpu() + + num_labels = 2 + num_points = 3 + point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32) + input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float) + has_input_masks = torch.ones(1, dtype=torch.float) + original_image_size = torch.tensor([1200, 1800], dtype=torch.int32) + + example_inputs = ( + image_features_0, + image_features_1, + image_embeddings, + point_coords, + point_labels, + input_masks, + has_input_masks, + original_image_size, + ) + + logger.info("point_coords.shape: %s", point_coords.shape) + logger.info("point_labels.shape: %s", point_labels.shape) + logger.info("input_masks.shape: %s", input_masks.shape) + logger.info("has_input_masks.shape: %s", has_input_masks.shape) + logger.info("original_image_size.shape: %s", original_image_size.shape) + + if verbose: + masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs) + logger.info("masks.shape: %s", masks.shape) + logger.info("iou_predictions.shape: %s", iou_predictions.shape) + logger.info("low_res_masks.shape: %s", low_res_masks.shape) + + input_names = [ + "image_features_0", + "image_features_1", + "image_embeddings", + "point_coords", + "point_labels", + "input_masks", + "has_input_masks",
"original_image_size", + ] + + output_names = ["masks", "iou_predictions", "low_res_masks"] + + dynamic_axes = { + "point_coords": {0: "num_labels", 1: "num_points"}, + "point_labels": {0: "num_labels", 1: "num_points"}, + "input_masks": {0: "num_labels"}, + "has_input_masks": {0: "num_labels"}, + "masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"}, + "low_res_masks": {0: "num_labels"}, + "iou_predictions": {0: "num_labels"}, + } + + with warnings.catch_warnings(): + if not verbose: + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=UserWarning) + + torch.onnx.export( + sam2_decoder, + example_inputs, + onnx_model_path, + export_params=True, + opset_version=16, + do_constant_folding=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + ) + + logger.info("decoder onnx model saved to %s", onnx_model_path) + + +def test_decoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, + multimask_output=False, +): + + batch_size = 1 + image = random_sam2_input_image(batch_size) + sam2_encoder = SAM2ImageEncoder(sam2_model).cpu() + image_features_0, image_features_1, image_embeddings = sam2_encoder(image) + + sam2_image_decoder = SAM2ImageDecoder( + sam2_model, + multimask_output=multimask_output, + dynamic_multimask_via_stability=True, + ).cpu() + + num_labels = 1 + num_points = 5 + point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32) + input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float) + has_input_masks = torch.zeros(1, dtype=torch.float) + original_image_size = torch.tensor([1500, 1500], dtype=torch.int32) + + example_inputs = ( + image_features_0, + image_features_1, + image_embeddings, + point_coords, + point_labels, + input_masks, + has_input_masks, + original_image_size, + ) + + masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs) + + import onnxruntime + + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + + model_inputs = ort_session.get_inputs() + input_names = [model_inputs[i].name for i in range(len(model_inputs))] + logger.info("input_names: %s", input_names) + + model_outputs = ort_session.get_outputs() + output_names = [model_outputs[i].name for i in range(len(model_outputs))] + logger.info("output_names: %s", output_names) + inputs = {model_inputs[i].name: example_inputs[i].numpy() for i in range(len(model_inputs))} + outputs = ort_session.run(output_names, inputs) + + for i, output_name in enumerate(output_names): + logger.info(f"{output_name}.shape: %s", outputs[i].shape) + + ort_masks, ort_iou_predictions, ort_low_res_masks = outputs + if ( + compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float()) + and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions)) + and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks)) + ): + print("onnx model has been verified:", onnx_model_path) + else: + print("onnx model verification failed:", onnx_model_path) diff --git a/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py new file mode 100644 index 0000000000000..ec05e5f5b0f6c --- /dev/null +++ 
b/onnxruntime/python/tools/transformers/models/sam2/image_encoder.py @@ -0,0 +1,164 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +import warnings + +import torch +from sam2.modeling.sam2_base import SAM2Base +from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image +from torch import nn + +import onnxruntime + +logger = logging.getLogger(__name__) + + +class SAM2ImageEncoder(nn.Module): + def __init__(self, sam_model: SAM2Base) -> None: + super().__init__() + self.model = sam_model + self.image_encoder = sam_model.image_encoder + self.no_mem_embed = sam_model.no_mem_embed + + def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Encodes images into features. + + Only supports H=W=1024. If you want to use different image sizes like 512x512, + see https://github.com/facebookresearch/segment-anything-2/issues/138. + + Args: + image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width. + + Returns: + image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0 + image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1 + image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride + """ + backbone_out = self.image_encoder(image) + + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) + backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) + + # Prepare and flatten visual features. + feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :] + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + + # flatten NxCxHxW to HWxNxC + # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
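+ # The flatten/permute below is undone by the reshape a few lines further down, which restores the NCHW layout for the exported outputs.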
+ vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + + vision_feats[-1] = vision_feats[-1] + self.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).reshape(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1]) + ][::-1] + + return feats[0], feats[1], feats[2] + + +def export_image_encoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, + dynamic_batch_axes: bool = False, + verbose: bool = False, +): + image = random_sam2_input_image() + + sam2_encoder = SAM2ImageEncoder(sam2_model).cpu() + image_features_0, image_features_1, image_embeddings = sam2_encoder(image) + logger.info("image.shape: %s", image.shape) + logger.info("image_features_0.shape: %s", image_features_0.shape) + logger.info("image_features_1.shape: %s", image_features_1.shape) + logger.info("image_embeddings.shape: %s", image_embeddings.shape) + + dynamic_axes = None + if dynamic_batch_axes: + dynamic_axes = { + "image": {0: "batch_size"}, + "image_features_0": {0: "batch_size"}, + "image_features_1": {0: "batch_size"}, + "image_embeddings": {0: "batch_size"}, + } + + with warnings.catch_warnings(): + if not verbose: + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=UserWarning) + torch.onnx.export( + sam2_encoder, + image, + onnx_model_path, + export_params=True, + opset_version=17, + do_constant_folding=True, + input_names=["image"], + output_names=["image_features_0", "image_features_1", "image_embeddings"], + dynamic_axes=dynamic_axes, + ) + + print("encoder onnx model saved to", onnx_model_path) + + +def test_image_encoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, + dynamic_batch_axes=False, +): + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + + model_inputs = ort_session.get_inputs() + input_names = [model_inputs[i].name for i in range(len(model_inputs))] + logger.info("input_names: %s", input_names) + + model_outputs = ort_session.get_outputs() + output_names = [model_outputs[i].name for i in range(len(model_outputs))] + logger.info("output_names: %s", output_names) + + batch_sizes = [1, 2] if dynamic_batch_axes else [1] + for batch_size in batch_sizes: + image = random_sam2_input_image(batch_size) + + sam2_encoder = SAM2ImageEncoder(sam2_model).cpu() + image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone()) + + logger.info("image.shape: %s", image.shape) + logger.info("image_features_0.shape: %s", image_features_0.shape) + logger.info("image_features_1.shape: %s", image_features_1.shape) + logger.info("image_embeddings.shape: %s", image_embeddings.shape) + + outputs = ort_session.run(output_names, {"image": image.numpy()}) + for i, output_name in enumerate(output_names): + logger.info("output %s shape %s", output_name, outputs[i].shape) + ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs + + # ONNX Runtime and PyTorch have about 0.75% mismatched elements, but this does not seem to impact segmentation results.
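+ # The checks below therefore pass mismatch_percentage_tolerance=1 to compare_tensors_with_tolerance, allowing roughly that share of elements to differ before verification fails.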
+ if ( + compare_tensors_with_tolerance( + "image_features_0", + image_features_0, + torch.tensor(ort_image_features_0), + mismatch_percentage_tolerance=1, + ) + and compare_tensors_with_tolerance( + "image_features_1", + image_features_1, + torch.tensor(ort_image_features_1), + mismatch_percentage_tolerance=1, + ) + and compare_tensors_with_tolerance( + "image_embeddings", + image_embeddings, + torch.tensor(ort_image_embeddings), + mismatch_percentage_tolerance=1, + ) + ): + print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}") + else: + print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}") diff --git a/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py new file mode 100644 index 0000000000000..56473c002d4ae --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/mask_decoder.py @@ -0,0 +1,208 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +import warnings + +import torch +from image_encoder import SAM2ImageEncoder, random_sam2_input_image +from prompt_encoder import SAM2PromptEncoder +from sam2.modeling.sam2_base import SAM2Base +from torch import nn + +logger = logging.getLogger(__name__) + + +class SAM2MaskDecoder(nn.Module): + def __init__( + self, + sam_model: SAM2Base, + multimask_output: bool, + dynamic_multimask_via_stability: bool = True, + ) -> None: + super().__init__() + self.mask_decoder = sam_model.sam_mask_decoder + self.prompt_encoder = sam_model.sam_prompt_encoder + self.model = sam_model + self.multimask_output = multimask_output + self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + + @torch.no_grad() + def forward( + self, + image_features_0: torch.Tensor, + image_features_1: torch.Tensor, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_embeddings: torch.Tensor, + dense_embeddings: torch.Tensor, + ): + """ + Decode masks from image and prompt embeddings. Only supports H=W=1024. + + Args: + image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder. + image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder. + image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder. + image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding. + sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes. + dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks. + + Returns: + low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks. + iou_predictions (torch.Tensor): [1, M]. scores for M masks.
+ """ + low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + repeat_image=sparse_embeddings.shape[0] > 1, # batch mode + high_res_features=[image_features_0, image_features_1], + ) + + if self.multimask_output: + low_res_masks = low_res_masks[:, 1:, :, :] + iou_predictions = iou_predictions[:, 1:] + elif self.dynamic_multimask_via_stability: + # When outputting a single mask, if the stability score from the current single-mask + # output (based on output token 0) falls below a threshold, we instead select from + # multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score. + low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability( + low_res_masks, iou_predictions + ) + else: + low_res_masks = low_res_masks[:, 0:1, :, :] + iou_predictions = iou_predictions[:, 0:1] + + return low_res_masks, iou_predictions + + +def export_mask_decoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, + multimask_output: bool, + dynamic_multimask_via_stability: bool = True, + verbose=False, +): + sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu() + + image = random_sam2_input_image() + sam2_encoder = SAM2ImageEncoder(sam2_model).cpu() + image_features_0, image_features_1, image_embeddings = sam2_encoder(image) + logger.info("image_features_0.shape: %s", image_features_0.shape) + logger.info("image_features_1.shape: %s", image_features_1.shape) + logger.info("image_embeddings.shape: %s", image_embeddings.shape) + + # encode an random prompt + num_labels = 2 + num_points = 3 + point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float) + input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float) + has_input_masks = torch.ones(1, dtype=torch.float) + + sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder( + point_coords, point_labels, input_masks, has_input_masks + ) + + logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape) + logger.info("dense_embeddings.shape: %s", dense_embeddings.shape) + logger.info("image_pe.shape: %s", image_pe.shape) + + sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability) + inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings) + low_res_masks, iou_predictions = sam2_mask_decoder(*inputs) + logger.info("low_res_masks.shape: %s", low_res_masks.shape) + logger.info("iou_predictions.shape: %s", iou_predictions.shape) + + with warnings.catch_warnings(): + if not verbose: + warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) + warnings.filterwarnings("ignore", category=UserWarning) + torch.onnx.export( + sam2_mask_decoder, + inputs, + onnx_model_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=[ + "image_features_0", + "image_features_1", + "image_embeddings", + "image_pe", + "sparse_embeddings", + "dense_embeddings", + ], + output_names=["low_res_masks", "iou_predictions"], + dynamic_axes={ + "sparse_embeddings": {0: "num_labels", 1: "num_points+1"}, + "dense_embeddings": {0: "num_labels"}, + "low_res_masks": {0: "num_labels"}, + "iou_predictions": {0: "num_labels"}, + }, + ) + + print("mask decoder onnx model saved to", 
onnx_model_path) + + +def test_mask_decoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, + multimask_output: bool, + dynamic_multimask_via_stability: bool, +): + sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu() + + image = random_sam2_input_image() + sam2_encoder = SAM2ImageEncoder(sam2_model).cpu() + image_features_0, image_features_1, image_embeddings = sam2_encoder(image) + + num_labels = 1 + num_points = 5 + point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float) + input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float) + has_input_masks = torch.ones(1, dtype=torch.float) + + sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder( + point_coords, point_labels, input_masks, has_input_masks + ) + + sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability) + inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings) + low_res_masks, iou_predictions = sam2_mask_decoder(*inputs) + + import onnxruntime + + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + + model_inputs = ort_session.get_inputs() + input_names = [model_inputs[i].name for i in range(len(model_inputs))] + logger.info("input_names: %s", input_names) + + model_outputs = ort_session.get_outputs() + output_names = [model_outputs[i].name for i in range(len(model_outputs))] + logger.info("output_names: %s", output_names) + + outputs = ort_session.run( + output_names, + { + "image_features_0": image_features_0.numpy(), + "image_features_1": image_features_1.numpy(), + "image_embeddings": image_embeddings.numpy(), + "image_pe": image_pe.numpy(), + "sparse_embeddings": sparse_embeddings.numpy(), + "dense_embeddings": dense_embeddings.numpy(), + }, + ) + + for i, output_name in enumerate(output_names): + logger.info("output %s shape: %s", output_name, outputs[i].shape) + + ort_low_res_masks, ort_iou_predictions = outputs + torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4) + torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4) + print(f"onnx model has been verified: {onnx_model_path}") diff --git a/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py new file mode 100644 index 0000000000000..883c51858346c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/prompt_encoder.py @@ -0,0 +1,189 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +import torch +from sam2.modeling.sam2_base import SAM2Base +from sam2_utils import compare_tensors_with_tolerance +from torch import nn + +logger = logging.getLogger(__name__) + + +class SAM2PromptEncoder(nn.Module): + def __init__(self, sam_model: SAM2Base): + super().__init__() + self.prompt_encoder = sam_model.sam_prompt_encoder + self.model = sam_model + + @torch.no_grad() + def forward( + self, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + input_masks: torch.Tensor, + has_input_masks: torch.Tensor, + ): + """Encode prompts.
+ + Args: + point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel + coordinate in (x, y) format of the P input points in image of size 1024x1024. + point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means + positive (foreground), 0 means negative (background), -1 means padding, + 2 (box left upper corner), 3 (box right bottom corner). + input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model. + Typically coming from a previous iteration. + has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise. + Returns: + sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes. + dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks. + image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding. + """ + sparse_embeddings = self._embed_points(point_coords, point_labels) + dense_embeddings = self._embed_masks(input_masks, has_input_masks) + image_pe = self.prompt_encoder.get_dense_pe() + + return sparse_embeddings, dense_embeddings, image_pe + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + + padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device) + padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device) + point_coords = torch.cat([point_coords, padding_point], dim=1) + point_labels = torch.cat([point_labels, padding_label], dim=1) + + # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0). + point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size + point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size + + point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1) + + for i in range(self.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor: + mask_embedding = self.prompt_encoder.mask_downscaling(input_masks) + no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape) + mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding + logger.info("mask_embedding.shape: %s", mask_embedding.shape) + return mask_embedding + + +def export_prompt_encoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, +): + sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu() + + num_labels = 2 + num_points = 3 + point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32) + input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float) + has_input_masks = torch.ones(1, dtype=torch.float) + + sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder( + point_coords, point_labels, input_masks, has_input_masks + ) + + logger.info("point_coords.shape: %s", point_coords.shape) + 
logger.info("point_labels.shape: %s", point_labels.shape) + logger.info("input_masks.shape: %s", input_masks.shape) + logger.info("has_input_masks.shape: %s", has_input_masks.shape) + + logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape) + logger.info("dense_embeddings.shape: %s", dense_embeddings.shape) + logger.info("image_pe.shape: %s", image_pe.shape) + + torch.onnx.export( + sam2_prompt_encoder, + (point_coords, point_labels, input_masks, has_input_masks), + onnx_model_path, + export_params=True, + opset_version=18, + do_constant_folding=True, + input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"], + output_names=["sparse_embeddings", "dense_embeddings", "image_pe"], + dynamic_axes={ + "point_coords": {0: "num_labels", 1: "num_points"}, + "point_labels": {0: "num_labels", 1: "num_points"}, + "input_masks": {0: "num_labels"}, + "sparse_embeddings": {0: "num_labels", 1: "num_points+1"}, + "dense_embeddings": {0: "num_labels"}, + }, + ) + + print("prompt encoder onnx model saved to ", onnx_model_path) + + +def test_prompt_encoder_onnx( + sam2_model: SAM2Base, + onnx_model_path: str, +): + sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu() + + num_labels = 1 + num_points = 5 + point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float) + point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32) + input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float) + has_input_masks = torch.ones(1, dtype=torch.float) + + sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder( + point_coords, point_labels, input_masks, has_input_masks + ) + + import onnxruntime + + ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers()) + + model_inputs = ort_session.get_inputs() + input_names = [model_inputs[i].name for i in range(len(model_inputs))] + logger.info("input_names: %s", input_names) + + model_outputs = ort_session.get_outputs() + output_names = [model_outputs[i].name for i in range(len(model_outputs))] + logger.info("output_names: %s", output_names) + + outputs = ort_session.run( + output_names, + { + "point_coords": point_coords.numpy(), + "point_labels": point_labels.numpy(), + "input_masks": input_masks.numpy(), + "has_input_masks": has_input_masks.numpy(), + }, + ) + + for i, output_name in enumerate(output_names): + logger.info("output %s shape: %s", output_name, outputs[i].shape) + + ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs + if ( + compare_tensors_with_tolerance( + "sparse_embeddings", + sparse_embeddings, + torch.tensor(ort_sparse_embeddings), + mismatch_percentage_tolerance=0.2, + ) + and compare_tensors_with_tolerance( + "dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2 + ) + and compare_tensors_with_tolerance( + "image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2 + ) + ): + print(f"onnx model has been verified: {onnx_model_path}") + else: + print(f"onnx model verification failed: {onnx_model_path}") diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py new file mode 100644 index 0000000000000..e2cd93ae2157d --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_demo.py @@ -0,0 +1,293 @@ +# ------------------------------------------------------------------------- +# Copyright 
(c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os + +import matplotlib.image as mpimg +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.patches import Rectangle +from PIL import Image +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor +from sam2_utils import build_sam2_model + + +def show_mask(mask, ax, random_color=False, borders=True): + if random_color: + color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) + else: + color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6]) + h, w = mask.shape[-2:] + mask = mask.astype(np.uint8) + mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + if borders: + import cv2 + + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + # Try to smooth contours + contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours] + mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2) + ax.imshow(mask_image) + + +def show_points(coords, labels, ax, marker_size=375): + pos_points = coords[labels == 1] + neg_points = coords[labels == 0] + ax.scatter( + pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25 + ) + ax.scatter( + neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25 + ) + + +def show_box(box, ax): + x0, y0 = box[0], box[1] + w, h = box[2] - box[0], box[3] - box[1] + ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)) + + +def show_masks( + image, + masks, + scores, + point_coords=None, + box_coords=None, + input_labels=None, + borders=True, + output_image_file_prefix=None, + image_files=None, +): + for i, (mask, score) in enumerate(zip(masks, scores)): + plt.figure(figsize=(10, 10)) + plt.imshow(image) + show_mask(mask, plt.gca(), borders=borders) + if point_coords is not None: + assert input_labels is not None + show_points(point_coords, input_labels, plt.gca()) + + if box_coords is not None: + show_box(box_coords, plt.gca()) + + if len(scores) > 1: + plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) + + plt.axis("off") + if output_image_file_prefix: + filename = f"{output_image_file_prefix}_{i}.png" + if os.path.exists(filename): + os.remove(filename) + plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0) + if isinstance(image_files, list): + image_files.append(filename) + plt.show(block=False) + plt.close() + + +def get_predictor( + checkpoint_dir: str, + device: torch.device, + model_type="sam2_hiera_large", + engine="torch", + onnx_directory="sam2_onnx_models", +): + sam2_model = build_sam2_model(checkpoint_dir, model_type, device=device) + if engine == "torch": + predictor = SAM2ImagePredictor(sam2_model) + else: + predictor = SAM2ImageOnnxPredictor(sam2_model, onnx_directory=onnx_directory, model_type=model_type) + return predictor + + +def run_demo( + checkpoint_dir: str, + model_type="sam2_hiera_large", + engine="torch", + onnx_directory="sam2_onnx_models", + enable_batch=False, +): + use_gpu = torch.cuda.is_available() + device = torch.device("cuda" if use_gpu else "cpu") + + if use_gpu: + if engine == "torch": + # Turn on tfloat32 for Ampere GPUs.
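+ # TF32 uses reduced-precision matmul/convolution on GPUs with compute capability 8.0 or higher for faster float32 math.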
+ if torch.cuda.get_device_properties(0).major >= 8: + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + elif engine == "ort": + import onnxruntime + + assert use_gpu == ("CUDAExecutionProvider" in onnxruntime.get_available_providers()) + + np.random.seed(3) + image = Image.open("truck.jpg") + image = np.array(image.convert("RGB")) + + predictor = get_predictor(checkpoint_dir, device, model_type, engine, onnx_directory=onnx_directory) + + predictor.set_image(image) + prefix = f"sam2_demo_{engine}_" + + # The model returns masks, quality predictions for those masks, + # and low resolution mask logits that can be passed to the next iteration of prediction. + # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where + # scores gives the model's own estimation of the quality of these masks. + # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True + # even if only a single mask is desired; + input_point = np.array([[500, 375]]) + input_label = np.array([1]) + masks, scores, logits = predictor.predict( + point_coords=input_point, + point_labels=input_label, + multimask_output=True, + ) + + sorted_ind = np.argsort(scores)[::-1] + masks = masks[sorted_ind] + scores = scores[sorted_ind] + logits = logits[sorted_ind] + + image_files = [] + show_masks( + image, + masks, + scores, + point_coords=input_point, + input_labels=input_label, + borders=True, + output_image_file_prefix=prefix + "multimask", + image_files=image_files, + ) + + # Multiple points. + input_point = np.array([[500, 375], [1125, 625]]) + input_label = np.array([1, 1]) + mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask + masks, scores, _ = predictor.predict( + point_coords=input_point, + point_labels=input_label, + mask_input=mask_input[None, :, :], + multimask_output=False, + ) + show_masks( + image, + masks, + scores, + point_coords=input_point, + input_labels=input_label, + output_image_file_prefix=prefix + "multi_points", + image_files=image_files, + ) + + # Specify a window and a background point. 
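+ # Label 1 marks a point to keep and label 0 marks a point to exclude; the best mask from the previous prediction is passed back in as mask_input.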
+ input_point = np.array([[500, 375], [1125, 625]]) + input_label = np.array([1, 0]) + mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask + masks, scores, _ = predictor.predict( + point_coords=input_point, + point_labels=input_label, + mask_input=mask_input[None, :, :], + multimask_output=False, + ) + show_masks( + image, + masks, + scores, + point_coords=input_point, + input_labels=input_label, + output_image_file_prefix=prefix + "background_point", + image_files=image_files, + ) + + # Take a box as input + input_box = np.array([425, 600, 700, 875]) + masks, scores, _ = predictor.predict( + point_coords=None, + point_labels=None, + box=input_box[None, :], + multimask_output=False, + ) + show_masks( + image, + masks, + scores, + box_coords=input_box, + output_image_file_prefix=prefix + "box", + image_files=image_files, + ) + + # Combining points and boxes + input_box = np.array([425, 600, 700, 875]) + input_point = np.array([[575, 750]]) + input_label = np.array([0]) + + masks, scores, logits = predictor.predict( + point_coords=input_point, + point_labels=input_label, + box=input_box, + multimask_output=False, + ) + show_masks( + image, + masks, + scores, + box_coords=input_box, + point_coords=input_point, + input_labels=input_label, + output_image_file_prefix=prefix + "box_and_point", + image_files=image_files, + ) + + # TODO: support batched prompt inputs + if enable_batch: + input_boxes = np.array( + [ + [75, 275, 1725, 850], + [425, 600, 700, 875], + [1375, 550, 1650, 800], + [1240, 675, 1400, 750], + ] + ) + masks, scores, _ = predictor.predict( + point_coords=None, + point_labels=None, + box=input_boxes, + multimask_output=False, + ) + plt.figure(figsize=(10, 10)) + plt.imshow(image) + for mask in masks: + show_mask(mask.squeeze(0), plt.gca(), random_color=True) + for box in input_boxes: + show_box(box, plt.gca()) + plt.axis("off") + plt.show() + plt.savefig(prefix + "batch_prompt.png") + image_files.append(prefix + "batch_prompt.png") + return image_files + + +def show_all_images(left_images, right_images): + # Show images in two rows since display screen is horizontal in most cases. + fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80)) + for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images)): + left_img = mpimg.imread(left_img_path) + right_img = mpimg.imread(right_img_path) + + axes[0, i].imshow(left_img) + axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10) + axes[0, i].axis("off") + axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0]) + + axes[1, i].imshow(right_img) + axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10) + axes[1, i].axis("off") + axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0]) + + plt.tight_layout() + plt.savefig("sam2_demo.png", format="png", bbox_inches="tight", dpi=1000) + plt.show() diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py new file mode 100644 index 0000000000000..36b87f0ffbd90 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_image_onnx_predictor.py @@ -0,0 +1,283 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License.
+# -------------------------------------------------------------------------- + +import logging +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2_utils import decoder_shape_dict, encoder_shape_dict, get_decoder_onnx_path, get_image_encoder_onnx_path + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +logger = logging.getLogger(__name__) + + +def create_ort_session( + onnx_path: str, + session_options=None, + provider="CUDAExecutionProvider", + enable_cuda_graph=False, + use_tf32=True, +) -> InferenceSession: + if provider == "CUDAExecutionProvider": + device_id = torch.cuda.current_device() + provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + provider_options["use_tf32"] = int(use_tf32) + providers = [(provider, provider_options), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + print(f"Using providers: {providers}") + return InferenceSession(onnx_path, session_options, providers=providers) + + +def create_session( + onnx_path: str, session_options=None, provider="CUDAExecutionProvider", device="cuda", enable_cuda_graph=False +) -> CudaSession: + ort_session = create_ort_session( + onnx_path, session_options, provider, enable_cuda_graph=enable_cuda_graph, use_tf32=True + ) + cuda_session = CudaSession(ort_session, device=torch.device(device), enable_cuda_graph=enable_cuda_graph) + return cuda_session + + +class SAM2ImageOnnxPredictor(SAM2ImagePredictor): + def __init__( + self, + sam_model: SAM2Base, + onnx_directory: str = "sam2_onnx_models", + model_type: str = "sam2_hiera_large", + onnx_dtype: torch.dtype = torch.float32, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-2 to compute the image embedding for an image, and then allows mask prediction given prompts. + + Arguments: + sam_model (SAM2Base): The model to use for mask prediction. + onnx_directory (str): The path of the directory that contains encoder and decoder onnx models. + onnx_dtype (torch.dtype): The data type to use for ONNX inputs. + mask_threshold (float): The threshold to convert mask logits to binary masks. Default is 0.0. + max_hole_area (float): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (float): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__( + sam_model, mask_threshold=mask_threshold, max_hole_area=max_hole_area, max_sprinkle_area=max_sprinkle_area + ) + + print(self.device) + if torch.cuda.is_available(): + provider = "CUDAExecutionProvider" + device = "cuda" + else: + provider = "CPUExecutionProvider" + device = "cpu" + + # This model is exported by image_encoder.py. + onnx_path = get_image_encoder_onnx_path(onnx_directory, model_type) + + self.encoder_session = create_session( + onnx_path, + session_options=None, + provider=provider, + device=device, + enable_cuda_graph=False, + ) + self.onnx_dtype = onnx_dtype + + # This model is exported by image_decoder.py. It outputs only one mask.
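+ # A second session below loads the multimask variant, so _predict can serve both multimask_output settings without re-creating sessions.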
+        onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=False)
+        self.decoder_session = create_session(
+            onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+
+        # This model is exported by image_decoder.py. It outputs multiple (3) masks.
+        onnx_path = get_decoder_onnx_path(onnx_directory, model_type, multimask_output=True)
+        self.decoder_session_multi_out = create_session(
+            onnx_path,
+            session_options=None,
+            provider=provider,
+            device=device,
+            enable_cuda_graph=False,
+        )
+
+    @torch.no_grad()
+    def set_image(self, image: Union[np.ndarray, Image]):
+        """
+        Calculates the image embeddings for the provided image.
+
+        Arguments:
+          image (np.ndarray or PIL Image): The input image to embed in RGB format.
+            The image should be in HWC format if np.ndarray, or WHC format if PIL Image,
+            with pixel values in [0, 255].
+        """
+        self.reset_predictor()
+        # Transform the image to the form expected by the model.
+        if isinstance(image, np.ndarray):
+            # For numpy array image, we assume (HxWxC) format.
+            self._orig_hw = [image.shape[:2]]
+        elif isinstance(image, Image):
+            w, h = image.size
+            self._orig_hw = [(h, w)]
+        else:
+            raise NotImplementedError("Image format not supported")
+
+        input_image = self._transforms(image)
+        input_image = input_image[None, ...].to(self.device)
+
+        assert (
+            len(input_image.shape) == 4 and input_image.shape[1] == 3
+        ), f"input_image must be of size 1x3xHxW, got {input_image.shape}"
+
+        # Compute image embeddings for the provided image.
+        io_shapes = encoder_shape_dict(batch_size=1, height=input_image.shape[2], width=input_image.shape[3])
+        self.encoder_session.allocate_buffers(io_shapes)
+
+        feed_dict = {"image": input_image.to(self.onnx_dtype).to(self.device)}
+
+        for key, value in feed_dict.items():
+            logger.debug(f"{key}: {value.shape}, {value.dtype}")
+        logger.debug(f"encoder onnx: {self.encoder_session.ort_session._model_path}")
+
+        ort_outputs = self.encoder_session.infer(feed_dict)
+
+        self._features = {
+            "image_embed": ort_outputs["image_embeddings"],
+            "high_res_feats": [ort_outputs[f"image_features_{i}"] for i in range(2)],
+        }
+        self._is_image_set = True
+        logger.info("Image embeddings computed.")
+
+    @torch.no_grad()
+    def _predict(
+        self,
+        point_coords: Optional[torch.Tensor],
+        point_labels: Optional[torch.Tensor],
+        boxes: Optional[torch.Tensor] = None,
+        mask_input: Optional[torch.Tensor] = None,
+        multimask_output: bool = True,
+        return_logits: bool = False,
+        img_idx: int = -1,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Predict masks for the given input prompts, using the currently set image.
+        Input prompts are batched torch tensors and are expected to already be
+        transformed to the input frame using SAM2Transforms.
+
+        Arguments:
+          point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+            model. Each point is in (X,Y) in pixels.
+          point_labels (torch.Tensor or None): A BxN array of labels for the
+            point prompts. 1 indicates a foreground point and 0 indicates a
+            background point.
+          boxes (torch.Tensor or None): A Bx4 array given a box prompt to the
+            model, in XYXY format.
+          mask_input (torch.Tensor): A low resolution mask input to the model, typically
+            coming from a previous prediction iteration. Has form Bx1xHxW, where
+            for SAM, H=W=256. Masks returned by a previous iteration of the
+            predict method do not need further transformation.
+          multimask_output (bool): If true, the model will return three masks.
+            For ambiguous input prompts (such as a single click), this will often
+            produce better masks than a single prediction. If only a single
+            mask is needed, the model's predicted quality score can be used
+            to select the best mask. For non-ambiguous prompts, such as multiple
+            input prompts, multimask_output=False can give better results.
+          return_logits (bool): If true, returns un-thresholded mask logits
+            instead of a binary mask.
+
+        Returns:
+          (torch.Tensor): The output masks in BxCxHxW format, where C is the
+            number of masks, and (H, W) is the original image size.
+          (torch.Tensor): An array of shape BxC containing the model's
+            predictions for the quality of each mask.
+          (torch.Tensor): An array of shape BxCxHxW, where C is the number
+            of masks and H=W=256. These low res logits can be passed to
+            a subsequent iteration as mask input.
+        """
+        assert not return_logits  # onnx model is exported for returning bool masks.
+
+        if not self._is_image_set:
+            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+        if point_coords is not None:
+            concat_points = (point_coords, point_labels)
+        else:
+            concat_points = None
+
+        # Embed prompts
+        if boxes is not None:
+            box_coords = boxes.reshape(-1, 2, 2)
+            box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device)
+            box_labels = box_labels.repeat(boxes.size(0), 1)
+            # we merge "boxes" and "points" into a single "concat_points" input (where
+            # boxes are added at the beginning) to sam_prompt_encoder
+            if concat_points is not None:
+                concat_coords = torch.cat([box_coords, concat_points[0]], dim=1)
+                concat_labels = torch.cat([box_labels, concat_points[1]], dim=1)
+                concat_points = (concat_coords, concat_labels)
+            else:
+                concat_points = (box_coords, box_labels)
+
+        assert concat_points is not None
+        num_labels = concat_points[0].shape[0]
+        shape_dict = decoder_shape_dict(
+            original_image_height=self._orig_hw[img_idx][0],
+            original_image_width=self._orig_hw[img_idx][1],
+            num_labels=num_labels,
+            max_points=concat_points[0].shape[1],
+            num_masks=3 if multimask_output else 1,
+        )
+        if multimask_output:
+            decoder_session = self.decoder_session_multi_out
+        else:
+            decoder_session = self.decoder_session
+
+        decoder_session.allocate_buffers(shape_dict)
+
+        image_features_0 = self._features["high_res_feats"][0][img_idx].unsqueeze(0)
+        image_features_1 = self._features["high_res_feats"][1][img_idx].unsqueeze(0)
+        image_embeddings = self._features["image_embed"][img_idx].unsqueeze(0)
+
+        if mask_input is None:
+            input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float, device=self.device)
+            has_input_masks = torch.zeros(num_labels, dtype=torch.float, device=self.device)
+        else:
+            input_masks = mask_input[img_idx].unsqueeze(0).repeat(num_labels, 1, 1, 1)
+            has_input_masks = torch.ones(num_labels, dtype=torch.float, device=self.device)
+
+        feed_dict = {
+            "image_embeddings": image_embeddings.contiguous().to(dtype=torch.float32).to(self.device),
+            "image_features_0": image_features_0.contiguous().to(dtype=torch.float32).to(self.device),
+            "image_features_1": image_features_1.contiguous().to(dtype=torch.float32).to(self.device),
+            "point_coords": concat_points[0].to(dtype=torch.float32).to(self.device),
+            "point_labels": concat_points[1].to(dtype=torch.int32).to(self.device),
+            "input_masks": input_masks.to(dtype=torch.float32).to(self.device),
+            "has_input_masks": has_input_masks.to(dtype=torch.float32).to(self.device),
+            "original_image_size": torch.tensor(self._orig_hw[img_idx], dtype=torch.int32, device=self.device),
+        }
+
+        for key, value in feed_dict.items():
+            logger.debug(f"{key}: {value.shape}, {value.dtype}")
+        logger.debug(f"decoder onnx: {decoder_session.ort_session._model_path}")
+
+        ort_outputs = decoder_session.infer(feed_dict)
+
+        masks = ort_outputs["masks"]
+        iou_predictions = ort_outputs["iou_predictions"]
+        low_res_masks = ort_outputs["low_res_masks"]
+
+        return torch.Tensor(masks), torch.Tensor(iou_predictions), torch.Tensor(low_res_masks)
diff --git a/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py
new file mode 100644
index 0000000000000..cf88eb42213f2
--- /dev/null
+++ b/onnxruntime/python/tools/transformers/models/sam2/sam2_utils.py
@@ -0,0 +1,122 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+import os
+
+import torch
+from sam2.build_sam import build_sam2
+from sam2.modeling.sam2_base import SAM2Base
+
+logger = logging.getLogger(__name__)
+
+
+def get_model_cfg(model_type) -> str:
+    assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
+    if model_type == "sam2_hiera_tiny":
+        model_cfg = "sam2_hiera_t.yaml"
+    elif model_type == "sam2_hiera_small":
+        model_cfg = "sam2_hiera_s.yaml"
+    elif model_type == "sam2_hiera_base_plus":
+        model_cfg = "sam2_hiera_b+.yaml"
+    else:
+        model_cfg = "sam2_hiera_l.yaml"
+    return model_cfg
+
+
+def build_sam2_model(checkpoint_dir: str, model_type: str, device="cpu") -> SAM2Base:
+    sam2_checkpoint = os.path.join(checkpoint_dir, f"{model_type}.pt")
+    model_cfg = get_model_cfg(model_type)
+    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
+    return sam2_model
+
+
+def get_decoder_onnx_path(dir: str, model_type, multimask_output) -> str:
+    return os.path.join(dir, f"{model_type}_decoder" + ("_multi" if multimask_output else "") + ".onnx")
+
+
+def get_image_encoder_onnx_path(dir: str, model_type) -> str:
+    return os.path.join(dir, f"{model_type}_image_encoder.onnx")
+
+
+def encoder_shape_dict(batch_size: int, height: int, width: int):
+    assert height == 1024 and width == 1024, "Only 1024x1024 images are supported."
+    return {
+        "image": [batch_size, 3, height, width],
+        "image_features_0": [batch_size, 32, height // 4, width // 4],
+        "image_features_1": [batch_size, 64, height // 8, width // 8],
+        "image_embeddings": [batch_size, 256, height // 16, width // 16],
+    }
+
+
+def decoder_shape_dict(
+    original_image_height: int,
+    original_image_width: int,
+    num_labels: int = 1,
+    max_points: int = 16,
+    num_masks: int = 1,
+) -> dict:
+    height: int = 1024
+    width: int = 1024
+    return {
+        "image_features_0": [1, 32, height // 4, width // 4],
+        "image_features_1": [1, 64, height // 8, width // 8],
+        "image_embeddings": [1, 256, height // 16, width // 16],
+        "point_coords": [num_labels, max_points, 2],
+        "point_labels": [num_labels, max_points],
+        "input_masks": [num_labels, 1, height // 4, width // 4],
+        "has_input_masks": [num_labels],
+        "original_image_size": [2],
+        "masks": [num_labels, num_masks, original_image_height, original_image_width],
+        "iou_predictions": [num_labels, num_masks],
+        "low_res_masks": [num_labels, num_masks, height // 4, width // 4],
+    }
+
+
+def compare_tensors_with_tolerance(
+    name: str,
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    atol=5e-3,
+    rtol=1e-4,
+    mismatch_percentage_tolerance=0.1,
+) -> bool:
+    assert tensor1.shape == tensor2.shape
+    a = tensor1.clone().float()
+    b = tensor2.clone().float()
+
+    differences = torch.abs(a - b)
+    mismatch_count = (differences > (rtol * torch.max(torch.abs(a), torch.abs(b)) + atol)).sum().item()
+
+    total_elements = a.numel()
+    mismatch_percentage = (mismatch_count / total_elements) * 100
+
+    passed = mismatch_percentage < mismatch_percentage_tolerance
+
+    log_func = logger.error if not passed else logger.info
+    log_func(
+        "%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
+        name,
+        mismatch_percentage,
+        mismatch_count,
+        total_elements,
+        "passed" if passed else "failed",
+        mismatch_percentage_tolerance,
+    )
+
+    return passed
+
+
+def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
+    image = torch.randn(batch_size, 3, image_height, image_width).cpu()
+    return image
+
+
+def setup_logger(verbose=True):
+    if verbose:
+        logging.basicConfig(format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
+        logging.getLogger().setLevel(logging.INFO)
+    else:
+        logging.basicConfig(format="%(message)s")
+        logging.getLogger().setLevel(logging.WARNING)
diff --git a/setup.py b/setup.py
index ac8f465851484..f471e9cd0e652 100644
--- a/setup.py
+++ b/setup.py
@@ -494,8 +494,9 @@ def finalize_options(self):
         "onnxruntime.transformers.models.llama",
         "onnxruntime.transformers.models.longformer",
         "onnxruntime.transformers.models.phi2",
-        "onnxruntime.transformers.models.t5",
+        "onnxruntime.transformers.models.sam2",
         "onnxruntime.transformers.models.stable_diffusion",
+        "onnxruntime.transformers.models.t5",
         "onnxruntime.transformers.models.whisper",
     ]
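For reference, here is a minimal usage sketch (not part of the patch) showing how the classes added above fit together: `build_sam2_model` from sam2_utils.py builds the PyTorch SAM2 model, and `SAM2ImageOnnxPredictor` then runs the exported encoder and decoder ONNX models for a single point prompt, mirroring the demo code. The checkpoint directory, ONNX output directory, image path, and point coordinates are illustrative placeholders taken from the conventions used elsewhere in this change.

```python
# Sketch: drive the exported SAM2 ONNX models through SAM2ImageOnnxPredictor.
# Assumes convert_to_onnx.py has already written the encoder/decoder models to
# ./sam2_onnx_models and that the SAM2 checkpoints were downloaded to
# ./segment-anything-2/checkpoints; truck.jpg is a placeholder input image.
import numpy as np
from PIL import Image

from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
from sam2_utils import build_sam2_model, setup_logger

setup_logger(verbose=False)

model_type = "sam2_hiera_large"
sam2_model = build_sam2_model("segment-anything-2/checkpoints", model_type, device="cpu")

predictor = SAM2ImageOnnxPredictor(
    sam2_model,
    onnx_directory="sam2_onnx_models",
    model_type=model_type,
)

# set_image runs the encoder ONNX model to compute image embeddings.
image = np.array(Image.open("truck.jpg").convert("RGB"))  # HWC, RGB, uint8
predictor.set_image(image)

# One foreground click; with multimask_output=False the single-mask decoder is used.
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=False,
)
print(masks.shape, scores.shape)
```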