Add onnx export script for segment anything v2 (#22119)
### Description
Add ONNX export script for segment anything v2 (SAM2).

### Limitations
* Does not support video; only images are supported right now.
* The decoder does not support batch inference.

### Credits
The demo is based on the [SAM2 notebook](https://github.com/facebookresearch/segment-anything-2/blob/main/notebooks/image_predictor_example.ipynb), modified to run with ORT.

The decoder export is inspired by https://github.com/vietanhdev/samexporter.

### Demo
Example output of the demo:

![sam2_demo](https://github.com/user-attachments/assets/9a9fa360-8c20-482e-9935-a7aba9cf15de)

### Motivation and Context
To support optimization of SAM2 image segmentation.
tianleiwu authored Sep 18, 2024
1 parent 05acfb9 commit a9740d6
Showing 13 changed files with 1,796 additions and 5 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -476,6 +476,9 @@ file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py"
)
+file(GLOB onnxruntime_python_transformers_models_sam2_src CONFIGURE_DEPENDS
+  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/sam2/*.py"
+)
file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS
  "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py"
)
@@ -547,6 +550,7 @@ add_custom_command(
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2
+  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
  COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/whisper
@@ -656,6 +660,9 @@ add_custom_command(
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_models_phi2_src}
      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/phi2/
+  COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_transformers_models_sam2_src}
+      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/sam2/
  COMMAND ${CMAKE_COMMAND} -E copy
      ${onnxruntime_python_transformers_models_stable_diffusion_src}
      $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion/
11 changes: 7 additions & 4 deletions onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -1,13 +1,16 @@
import copy
import logging
from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import numpy
import torch

from onnxruntime import InferenceSession, RunOptions

+# Type alias
+ShapeDict = Mapping[str, Union[Tuple, List[int]]]

logger = logging.getLogger(__name__)


@@ -262,7 +265,7 @@ def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
            )
            self.output_tensors[self.buffer_sharing[name]] = tensor

-    def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]):
+    def allocate_buffers(self, shape_dict: ShapeDict):
        """Allocate tensors for I/O Binding"""
        if self.enable_cuda_graph:
            for name, shape in shape_dict.items():
@@ -346,7 +349,7 @@ def __init__(
        self,
        ort_session: InferenceSession,
        device: torch.device,
-        shape_dict: Dict[str, Union[Tuple[int], List[int]]],
+        shape_dict: ShapeDict,
        enable_gpu_graph: bool = False,
        gpu_graph_id: int = -1,
        stream: int = 0,
@@ -406,7 +409,7 @@ def __init__(self, ort_session: InferenceSession, device: torch.device, stream:

    def get_binding(
        self,
-        shape_dict: Dict[str, Union[Tuple[int], List[int]]],
+        shape_dict: ShapeDict,
        use_cuda_graph: bool = False,
        buffer_sharing: Optional[Dict[str, str]] = None,
    ) -> GpuBinding:
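
For context, the new `ShapeDict` alias accepts either a tuple or a list of ints for each named shape. A tiny illustration (the tensor names below are hypothetical, not taken from the SAM2 models):

```python
from typing import List, Mapping, Tuple, Union

ShapeDict = Mapping[str, Union[Tuple, List[int]]]

# Both tuple and list shapes satisfy the alias.
shape_dict: ShapeDict = {
    "image": (1, 3, 1024, 1024),
    "image_embeddings": [1, 256, 64, 64],
}
```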
65 changes: 65 additions & 0 deletions onnxruntime/python/tools/transformers/models/sam2/README.md
@@ -0,0 +1,65 @@
# SAM2 ONNX Model Export

## Setup Environment
It is recommended to set up a machine with Python 3.10, 3.11, or 3.12. Then install [PyTorch 2.4.1](https://pytorch.org/) and ONNX Runtime 1.19.2.

### CPU Only
To install the CPU-only versions of PyTorch and ONNX Runtime for exporting and running ONNX models, use the following commands:
```
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
python3 -m pip install onnxruntime onnx opencv-python matplotlib
```

### GPU
If your machine has an NVIDIA GPU, you can install the CUDA versions of PyTorch and ONNX Runtime for exporting and running ONNX models:

```
python3 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
python3 -m pip install onnxruntime-gpu onnx opencv-python matplotlib
```

onnxruntime-gpu requires CUDA 12.x, cuDNN 9.x, and other dependencies (such as MSVC Runtime on Windows). For more information, see the [installation guide](https://onnxruntime.ai/docs/install/#python-installs).
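
You can verify which execution providers are available after installation (a quick sanity check using the standard onnxruntime Python API):

```python
import onnxruntime as ort

# "CUDAExecutionProvider" appears in this list only when onnxruntime-gpu and its
# CUDA/cuDNN dependencies are installed correctly; otherwise only CPU is listed.
print(ort.get_available_providers())
```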

## Download Checkpoints

Clone the SAM2 git repository and download the checkpoints:
```bash
git clone https://github.com/facebookresearch/segment-anything-2.git
cd segment-anything-2
python3 -m pip install -e .
cd checkpoints
sh ./download_ckpts.sh
```

On Windows, you can replace `sh ./download_ckpts.sh` with the following commands:
```bash
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt > sam2_hiera_tiny.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt > sam2_hiera_small.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt > sam2_hiera_base_plus.pt
curl https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt > sam2_hiera_large.pt
```

## Export ONNX
To export ONNX models, run the `convert_to_onnx.py` script and point `--sam2_dir` at the segment-anything-2 directory created by the `git clone` command above:
```bash
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2
```

The exported ONNX models are written to the `sam2_onnx_models` subdirectory by default. You can change the output directory using the `--output_dir` option.

If you want the model to output multiple masks, append the `--multimask_output` option.

To see all parameters, run the following command:
```bash
python3 convert_to_onnx.py -h
```
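
After export, you can sanity-check a model by loading it with ONNX Runtime and inspecting its inputs and outputs. A minimal sketch (the file name below is illustrative; use the actual encoder file written to the output directory):

```python
import onnxruntime as ort

# Illustrative path; substitute the image encoder file produced under --output_dir.
session = ort.InferenceSession("sam2_onnx_models/sam2_hiera_large_image_encoder.onnx")

for tensor in session.get_inputs():
    print("input:", tensor.name, tensor.shape, tensor.type)
for tensor in session.get_outputs():
    print("output:", tensor.name, tensor.shape, tensor.type)
```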

## Run Demo
The exported ONNX models can run on CPU. The demo writes its output to `sam2_demo.png`.
```bash
curl https://raw.githubusercontent.com/facebookresearch/segment-anything-2/main/notebooks/images/truck.jpg > truck.jpg
python3 convert_to_onnx.py --sam2_dir path/to/segment-anything-2 --demo
```

## Limitations
- The exported `image_decoder` model does not support batch mode for now.
12 changes: 12 additions & 0 deletions onnxruntime/python/tools/transformers/models/sam2/__init__.py
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os.path
import sys

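# Allow scripts in this directory to import sibling modules directly (e.g. sam2_utils).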
sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)
195 changes: 195 additions & 0 deletions onnxruntime/python/tools/transformers/models/sam2/convert_to_onnx.py
@@ -0,0 +1,195 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import argparse
import os
import pathlib
import sys

import torch
from image_decoder import export_decoder_onnx, test_decoder_onnx
from image_encoder import export_image_encoder_onnx, test_image_encoder_onnx
from mask_decoder import export_mask_decoder_onnx, test_mask_decoder_onnx
from prompt_encoder import export_prompt_encoder_onnx, test_prompt_encoder_onnx
from sam2_demo import run_demo, show_all_images
from sam2_utils import build_sam2_model, get_decoder_onnx_path, get_image_encoder_onnx_path, setup_logger


def parse_arguments():
    parser = argparse.ArgumentParser(description="Export SAM2 models to ONNX")

    parser.add_argument(
        "--model_type",
        required=False,
        type=str,
        choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
        default="sam2_hiera_large",
        help="The model type to export",
    )

    parser.add_argument(
        "--components",
        required=False,
        nargs="+",
        choices=["image_encoder", "mask_decoder", "prompt_encoder", "image_decoder"],
        default=["image_encoder", "image_decoder"],
        help="Type of ONNX models to export. "
        "Note that image_decoder is a combination of prompt_encoder and mask_decoder",
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        help="The output directory for the ONNX models",
        default="sam2_onnx_models",
    )

    parser.add_argument(
        "--dynamic_batch_axes",
        required=False,
        default=False,
        action="store_true",
        help="Export image_encoder with dynamic batch axes",
    )

    parser.add_argument(
        "--multimask_output",
        required=False,
        default=False,
        action="store_true",
        help="Export mask_decoder or image_decoder with multimask_output",
    )

    parser.add_argument(
        "--disable_dynamic_multimask_via_stability",
        required=False,
        action="store_true",
        help="Disable mask_decoder dynamic_multimask_via_stability, and output the first mask only. "
        "This option will be ignored when multimask_output is True.",
    )

    parser.add_argument(
        "--sam2_dir",
        required=False,
        type=str,
        default="./segment-anything-2",
        help="The directory of the segment-anything-2 git repository",
    )

    parser.add_argument(
        "--overwrite",
        required=False,
        default=False,
        action="store_true",
        help="Overwrite the ONNX model file if it exists.",
    )

    parser.add_argument(
        "--demo",
        required=False,
        default=False,
        action="store_true",
        help="Run demo with the exported ONNX models.",
    )

    parser.add_argument(
        "--verbose",
        required=False,
        default=False,
        action="store_true",
        help="Print verbose information",
    )

    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()

    checkpoints_dir = os.path.join(args.sam2_dir, "checkpoints")
    sam2_config_dir = os.path.join(args.sam2_dir, "sam2_configs")
    if not os.path.exists(args.sam2_dir):
        raise FileNotFoundError(f"{args.sam2_dir} does not exist. Please specify --sam2_dir correctly.")

    if not os.path.exists(checkpoints_dir):
        raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.")

    if not os.path.exists(sam2_config_dir):
        raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.")

    if not os.path.exists(os.path.join(checkpoints_dir, f"{args.model_type}.pt")):
        raise FileNotFoundError(
            f"{checkpoints_dir}/{args.model_type}.pt does not exist. Please download checkpoints under the directory."
        )

    if args.sam2_dir not in sys.path:
        sys.path.append(args.sam2_dir)

    pathlib.Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    sam2_model = build_sam2_model(checkpoints_dir, args.model_type, device="cpu")

    for component in args.components:
        if component == "image_encoder":
            onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type)
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose)
            test_image_encoder_onnx(sam2_model, onnx_model_path, dynamic_batch_axes=False)

        elif component == "mask_decoder":
            onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_mask_decoder.onnx")
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_mask_decoder_onnx(
                    sam2_model,
                    onnx_model_path,
                    args.multimask_output,
                    not args.disable_dynamic_multimask_via_stability,
                    args.verbose,
                )
            test_mask_decoder_onnx(
                sam2_model,
                onnx_model_path,
                args.multimask_output,
                not args.disable_dynamic_multimask_via_stability,
            )
        elif component == "prompt_encoder":
            onnx_model_path = os.path.join(args.output_dir, f"{args.model_type}_prompt_encoder.onnx")
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_prompt_encoder_onnx(sam2_model, onnx_model_path)
            test_prompt_encoder_onnx(sam2_model, onnx_model_path)
        elif component == "image_decoder":
            onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, args.multimask_output)
            if args.overwrite or not os.path.exists(onnx_model_path):
                export_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)
            test_decoder_onnx(sam2_model, onnx_model_path, args.multimask_output)

    if args.demo:
        # Export required ONNX models for demo if not already exported.
        onnx_model_path = get_image_encoder_onnx_path(args.output_dir, args.model_type)
        if not os.path.exists(onnx_model_path):
            export_image_encoder_onnx(sam2_model, onnx_model_path, args.dynamic_batch_axes, args.verbose)

        onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, True)
        if not os.path.exists(onnx_model_path):
            export_decoder_onnx(sam2_model, onnx_model_path, True)

        onnx_model_path = get_decoder_onnx_path(args.output_dir, args.model_type, False)
        if not os.path.exists(onnx_model_path):
            export_decoder_onnx(sam2_model, onnx_model_path, False)

        ort_image_files = run_demo(checkpoints_dir, args.model_type, engine="ort", onnx_directory=args.output_dir)
        print("demo output files for ONNX Runtime:", ort_image_files)

        # Get results from torch engine to compare.
        torch_image_files = run_demo(checkpoints_dir, args.model_type, engine="torch", onnx_directory=args.output_dir)
        print("demo output files for PyTorch:", torch_image_files)

        show_all_images(ort_image_files, torch_image_files)


if __name__ == "__main__":
    setup_logger(verbose=False)
    with torch.no_grad():
        main()