Skip to content

Commit

Permalink
Update IREE onnx import to be in sync with Torch-MLIR (iree-org#17476)
Browse files Browse the repository at this point in the history
This commit updates `iree-import-onnx` so that it behaves the same as
torch-mlir's version
(https://github.com/llvm/torch-mlir/blob/main/python/torch_mlir/tools/import_onnx/__main__.py).
Specifically, enabling the data propagation improves shape inference and
is leading to more models passing.

Related to iree-org#17021

---------

Signed-off-by: saienduri <[email protected]>
  • Loading branch information
saienduri authored and gglangg committed Jun 4, 2024
1 parent 4117023 commit 8345b9d
Showing 1 changed file with 65 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
python -m iree.compiler.tools.import_onnx ...
"""
import argparse
import os
from pathlib import Path
import sys
import tempfile

try:
import onnx
Expand All @@ -38,8 +40,8 @@
)


def main(args):
model_proto = load_onnx_model(args.input_file)
def main(args: argparse.Namespace):
model_proto = load_onnx_model(args)
context = Context()
model_info = onnx_importer.ModelInfo(model_proto)
m = model_info.create_module(context=context).operation
Expand All @@ -58,13 +60,56 @@ def main(args):
print(m.get_asm(assume_verified=not args.no_verify))


def load_onnx_model(file_path: Path) -> onnx.ModelProto:
raw_model = onnx.load(file_path)
inferred_model = onnx.shape_inference.infer_shapes(raw_model)
return inferred_model
def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto:
input_dir = os.path.dirname(os.path.abspath(args.input_file))


def parse_arguments(argv=None):
# Load the model, with possible external data coming from the default
# location, or the location specified on the command line.
if args.data_dir is None:
raw_model = onnx.load(args.input_file)
else:
raw_model = onnx.load(args.input_file, load_external_data=False)
onnx.load_external_data_for_model(raw_model, args.data_dir)

# Do shape inference two ways. First, attempt in-memory to avoid redundant
# loading and the need for writing a temporary file somewhere. If that
# fails, typically because of the 2 GB protobuf size limit, try again via
# files. See
# https://onnx.ai/onnx/repo-docs/PythonAPIOverview.html#shape-inference-a-large-onnx-model-2gb
# for details about the file-based technique.

# Run the checker to test whether the file is above the threshold for
# in-memory shape inference. If not, go ahead and do the shape inference.
try:
onnx.checker.check_model(raw_model)
inferred_model = onnx.shape_inference.infer_shapes(
raw_model, data_prop=args.data_prop
)
return inferred_model
except ValueError:
pass

# Model is too big for in-memory inference: do file-based shape inference
# to a temp file.
# Make a temp dir for all the temp files we'll be generating as a side
# effect of infering shapes. For now, the only file is a new .onnx holding
# the revised model with shapes.
with tempfile.TemporaryDirectory(dir=input_dir) as temp_dir_name:
temp_dir_path = Path(temp_dir_name)
temp_inferred_file = temp_dir_path / "temp-inferred.onnx"
onnx.shape_inference.infer_shapes_path(
args.input_file, temp_inferred_file, data_prop=args.data_prop
)

# Load the temp file and the external data.
inferred_model = onnx.load(temp_inferred_file, load_external_data=False)
data_dir = Path(input_dir if args.data_dir is None else args.data_dir)
onnx.load_external_data_for_model(inferred_model, data_dir)

return inferred_model


def parse_arguments(argv=None) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="IREE ONNX import tool")
parser.add_argument("input_file", help="ONNX protobuf input", type=Path)
parser.add_argument(
Expand All @@ -75,6 +120,18 @@ def parse_arguments(argv=None):
action="store_true",
help="Disable verification prior to printing",
)
parser.add_argument(
"--data-prop",
default=True,
action=argparse.BooleanOptionalAction,
help="Toggle data propogation for onnx shape inference",
)
parser.add_argument(
"--data-dir",
help="Path to the base directory of the data."
" Defaults to the directory of the input file.",
type=Path,
)
args = parser.parse_args(argv)
return args

Expand Down

0 comments on commit 8345b9d

Please sign in to comment.