[alt] Add a CL-based mode (nod-ai#342)
Adds a `--mode=cl-onnx-iree` option which mirrors `--mode=onnx-iree`
(the default) but runs most stages via command-line scripts rather than
Python bindings. This seems to work better for a few reasons (a usage
sketch follows the list):

1. Easier command-line reproducers for failures: the script that fails
is printed in the corresponding log file.
2. Memory management per stage is more contained and doesn't need to
rely on Python's garbage collector.
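
For orientation, a hypothetical invocation of the new mode might look like
the following; the entry-point name and the test-selection flag are
assumptions rather than something specified in this commit:

    python run.py --mode=cl-onnx-iree -t my_test_name

In this mode each stage writes the exact shell command it ran to a
`commands/` log and the command's stdout/stderr to a `detail/` log under
the per-test artifact directory, so a failing stage can be re-run by hand.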
zjgarvey authored Sep 18, 2024
1 parent 00ba2f3 commit dbbe625
Showing 7 changed files with 218 additions and 53 deletions.
55 changes: 53 additions & 2 deletions alt_e2eshark/e2e_testing/backends.py
@@ -6,9 +6,11 @@
import abc
import onnxruntime as ort
from typing import TypeVar, List
from e2e_testing.storage import TestTensors
from e2e_testing.storage import TestTensors, get_shape_string
from e2e_testing.framework import CompiledOutput, ModelArtifact
from onnx import ModelProto
import os
from pathlib import Path

Invoker = TypeVar("Invoker")

@@ -72,7 +74,7 @@ def compile(self, module, *, save_to: str = None):
)
# log the vmfb
if save_to:
with open(save_to + "compiled_model.vmfb", "wb") as f:
with open(os.path.join(save_to, "compiled_model.vmfb"), "wb") as f:
f.write(b)
return b

@@ -94,6 +96,55 @@ def func(x):

return func

class CLIREEBackend(BackendBase):
'''This backend calls iree through the command line to compile and run MLIR modules'''
def __init__(self, *, device="local-task", hal_target_backend="llvm-cpu", extra_args : List[str] = None):
self.device = device
self.hal_target_backend = hal_target_backend
self.extra_args = []
if extra_args:
for a in extra_args:
if a[0:2] == "--":
self.extra_args.append(a)
else:
self.extra_args.append("--" + a)

def compile(self, module_path: str, *, save_to : str = None) -> str:
vmfb_path = os.path.join(save_to, "compiled_model.vmfb")
arg_string = f"--iree-hal-target-backends={self.hal_target_backend} "
for arg in self.extra_args:
arg_string += arg
arg_string += " "
command_error_dump = os.path.join(save_to, "detail", "compilation.detail.log")
commands_log = os.path.join(save_to, "commands", "compilation.commands.log")
script = f"iree-compile {module_path} {arg_string}-o {vmfb_path} 1> {command_error_dump} 2>&1"
with open(commands_log, "w") as file:
file.write(script)
# remove old vmfb if it exists
Path(vmfb_path).unlink(missing_ok=True)
os.system(script)
if not os.path.exists(vmfb_path):
error_message = f"failure executing command: \n{script}\n failed to produce a vmfb at {vmfb_path}.\n"
if os.path.exists(command_error_dump):
error_message += "Error Details:\n\n"
with open(command_error_dump, "r+") as file:
error_message += file.read()
raise FileNotFoundError(error_message)
return vmfb_path

def load(self, vmfb_path: str, *, func_name=None):
"""A bit hacky. func returns a script that would dump outputs to terminal output. Modified in config.run method"""
run_dir = Path(vmfb_path).parent
def func(x: TestTensors) -> str:
script = f"iree-run-module --module='{vmfb_path}' --device={self.device}"
if func_name:
script += f" --function='{func_name}'"
torch_inputs = x.to_torch().data
for index, input in enumerate(torch_inputs):
script += f" --input='{get_shape_string(input)}=@{run_dir}/input.{index}.bin'"
return script
return func
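
For orientation, here is a minimal sketch (not part of this diff) of how
`CLIREEBackend` might be driven directly; the file paths, the input tensor,
and the pre-existing `detail/` and `commands/` subdirectories under the run
directory are assumptions:

import torch
from e2e_testing.storage import TestTensors
from e2e_testing.backends import CLIREEBackend

backend = CLIREEBackend(device="local-task", hal_target_backend="llvm-cpu")
# compile() shells out to iree-compile and returns the path of the vmfb it wrote.
vmfb_path = backend.compile("/tmp/run/model.modified.mlir", save_to="/tmp/run")
# load() returns a function that only builds an iree-run-module command string;
# the command references input.N.bin files expected to exist in the run directory.
runner = backend.load(vmfb_path, func_name="main")
inputs = TestTensors((torch.randn(1, 3, 224, 224),))  # hypothetical input shape
script = runner(inputs)
# CLOnnxTestConfig.run appends --output flags to this script and executes it.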


class OnnxrtIreeEpBackend(BackendBase):
'''This backend uses onnxrt iree-ep to compile and run onnx models for a specified hal_target_backend'''
6 changes: 4 additions & 2 deletions alt_e2eshark/e2e_testing/framework.py
@@ -87,11 +87,13 @@ def save_processed_output(self, output: TestTensors, save_to: str, name: str):

# the following helper methods aren't meant to be overridden

def get_signature(self, *, from_inputs=True):
def get_signature(self, *, from_inputs=True, leave_dynamic=False):
"""Returns the input or output signature of self.model"""
if not os.path.exists(self.model):
self.construct_model()
return get_signature_for_onnx_model(self.model, from_inputs=from_inputs, dim_param_dict=self.dim_param_dict)
if not leave_dynamic:
self.update_dim_param_dict()
return get_signature_for_onnx_model(self.model, from_inputs=from_inputs, dim_param_dict=self.dim_param_dict, leave_dynamic=leave_dynamic)

def load_inputs(self, dir_path):
"""computes the input signature of the onnx model and loads inputs from bin files"""
8 changes: 6 additions & 2 deletions alt_e2eshark/e2e_testing/onnx_utils.py
@@ -73,7 +73,7 @@ def get_sample_inputs_for_onnx_model(model_path, dim_param_dict = None):
return sample_inputs


def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True, dim_param_dict: Optional[dict[str, int]] = None):
def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True, dim_param_dict: Optional[dict[str, int]] = None, leave_dynamic: bool = False):
"""A convenience funtion for retrieving the input or output shapes and dtypes"""
s = onnxruntime.InferenceSession(model_path, None)
if from_inputs:
@@ -83,7 +83,11 @@ def get_signature_for_onnx_model(model_path, *, from_inputs: bool = True, dim_pa
shapes = []
dtypes = []
for i in nodes:
shapes.append(i.shape)
shape = i.shape
for index, s in enumerate(shape):
if not leave_dynamic and isinstance(s, str) and s in dim_param_dict.keys():
shape[index] = dim_param_dict[s]
shapes.append(shape)
dtypes.append(dtype_from_ort_node(i))
return shapes, dtypes

3 changes: 2 additions & 1 deletion alt_e2eshark/e2e_testing/storage.py
@@ -9,6 +9,7 @@
import torch
from typing import Tuple, Optional, Dict, List, Any, Union
from pathlib import Path
import os

def get_shape_string(torch_tensor):
input_shape = list(torch_tensor.shape)
@@ -211,7 +212,7 @@ def load_from(shapes, torch_dtypes, dir_path: str, name: str = "input"):
for i in range(len(shapes)):
shape = shapes[i]
dtype = torch_dtypes[i]
t = load_raw_binary_as_torch_tensor(dir_path + name + "." + str(i) + ".bin", shape, dtype)
t = load_raw_binary_as_torch_tensor(os.path.join(dir_path, name + "." + str(i) + ".bin"), shape, dtype)
tensor_list.append(t)
return TestTensors(tuple(tensor_list))

139 changes: 135 additions & 4 deletions alt_e2eshark/e2e_testing/test_configs/onnxconfig.py
@@ -11,8 +11,10 @@
from e2e_testing.framework import TestConfig, OnnxModelInfo, Module, CompiledArtifact
from e2e_testing.storage import TestTensors
from torch_mlir.passmanager import PassManager
from typing import Tuple
from typing import Tuple, Any
from onnxruntime import InferenceSession
import os
from pathlib import Path

REDUCE_TO_LINALG_PIPELINE = [
"torch-lower-to-backend-contract",
@@ -81,7 +83,7 @@ def import_model(self, model_info: OnnxModelInfo, *, save_to: str = None) -> Tup
imp.import_all()
# log imported IR
if save_to:
with open(save_to + "model.torch_onnx.mlir", "w") as f:
with open(os.path.join(save_to, "model.torch_onnx.mlir"), "w") as f:
f.write(str(m))
return m, func_name

@@ -96,13 +98,13 @@ def preprocess_model(self, mlir_module: Module, *, save_to: str = None) -> Modul
pm0.run(mlir_module.operation)
# log torch-mlir IR
if save_to:
with open(save_to + "model.torch.mlir", "w") as f:
with open(os.path.join(save_to, "model.torch.mlir"), "w") as f:
f.write(str(mlir_module))
pm1 = PassManager.parse(self.pass_pipeline)
pm1.run(mlir_module.operation)
# log modified IR
if save_to:
with open(save_to + "model.modified.mlir", "w") as f:
with open(os.path.join(save_to, "model.modified.mlir"), "w") as f:
f.write(str(mlir_module))
return mlir_module

@@ -112,3 +114,132 @@ def compile(self, mlir_module: Module, *, save_to: str = None) -> CompiledArtifa
def run(self, artifact: CompiledArtifact, inputs: TestTensors, *, func_name="main") -> TestTensors:
func = self.backend.load(artifact, func_name=func_name)
return func(inputs)

class CLOnnxTestConfig(TestConfig):
'''This is parallel to OnnxTestConfig, but uses command-line scripts for each stage.'''
def __init__(
self, log_dir: str, backend: BackendBase, torch_mlir_pipeline: Tuple[str, ...]
):
super().__init__()
self.log_dir = log_dir
self.backend = backend
self.tensor_info_dict = dict()
if len(torch_mlir_pipeline) > 0:
self.pass_pipeline = "builtin.module(" + ",".join(torch_mlir_pipeline) + ")"
else:
self.pass_pipeline = None

def import_model(self, program: OnnxModelInfo, *, save_to: str) -> Tuple[str, str]:
if not save_to:
raise ValueError("CLOnnxTestConfig requires saving artifacts")
# setup a detail subdirectory
os.makedirs(os.path.join(save_to, "detail"), exist_ok=True)
# setup a commands subdirectory
os.makedirs(os.path.join(save_to, "commands"), exist_ok=True)
# set file paths
mlir_file = os.path.join(save_to, "model.torch_onnx.mlir")
detail_log = os.path.join(save_to, "detail", "import_model.detail.log")
commands_log = os.path.join(save_to, "commands", "import_model.commands.log")
# get a command line script
script = "python -m torch_mlir.tools.import_onnx "
script += str(program.model)
script += " -o "
script = script + mlir_file
script += f" 1> {detail_log} 2>&1"
# log the command
with open(commands_log, "w") as file:
file.write(script)
# remove old mlir_file if present
Path(mlir_file).unlink(missing_ok=True)
# run the command
os.system(script)
# check if a new mlir file was generated
if not os.path.exists(mlir_file):
error_msg = f"failure executing command: \n{script}\n failed to produce mlir file {mlir_file}.\n"
if os.path.exists(detail_log):
error_msg += "Error detail:\n\n"
with open(detail_log,"r+") as file:
error_msg += file.read()
raise FileNotFoundError(error_msg)
# store output signatures for loading the outputs of iree-run-module
self.tensor_info_dict[program.name] = program.get_signature(from_inputs=False)
# get the func name
# TODO put this as an OnnxModelInfo attr?
model = onnx.load(program.model, load_external_data=False)
func_name = model.graph.name
return mlir_file, func_name

def preprocess_model(self, mlir_module: str, *, save_to: str = None) -> Module:
# if the pass pipeline is empty, return the original module
if not self.pass_pipeline:
return mlir_module
# convert imported torch-onnx ir to torch
onnx_to_torch_pipeline = "builtin.module(func.func(convert-torch-onnx-to-torch))"
# get paths
detail_log = os.path.join(save_to, "detail", "preprocessing.detail.log")
commands_log = os.path.join(save_to, "commands", "preprocessing.commands.log")
torch_ir = os.path.join(save_to, "model.torch.mlir")
linalg_ir = os.path.join(save_to, "model.modified.mlir")
# generate scripts
script0 = f"torch-mlir-opt -pass-pipeline='{onnx_to_torch_pipeline}' {mlir_module} -o {torch_ir} 1> {detail_log} 2>&1"
script1 = f"torch-mlir-opt -pass-pipeline='{self.pass_pipeline}' {torch_ir} -o {linalg_ir} 1> {detail_log} 2>&1"
# remove old torch_ir
Path(torch_ir).unlink(missing_ok=True)
with open(commands_log, "w") as file:
file.write(script0)
file.write(script1)
# run torch-onnx-to-torch
os.system(script0)
if not os.path.exists(torch_ir):
error_msg = f"failure executing command: \n{script0}\n failed to produce mlir file {torch_ir}.\n"
if os.path.exists(detail_log):
error_msg += "Error detail:\n\n"
with open(detail_log,"r+") as file:
error_msg += file.read()
raise FileNotFoundError(error_msg)
# remove old linalg ir
Path(linalg_ir).unlink(missing_ok=True)
# run torch-to-linalg pipeline
os.system(script1)
if not os.path.exists(linalg_ir):
error_msg = f"failure executing command: \n{script1}\n failed to produce mlir file {linalg_ir}.\n"
if os.path.exists(detail_log):
error_msg += "Error detail:\n\n"
with open(detail_log,"r+") as file:
error_msg += file.read()
raise FileNotFoundError(error_msg)
return linalg_ir

def compile(self, mlir_module: str, *, save_to: str = None) -> str:
return self.backend.compile(mlir_module, save_to=save_to)

def run(self, artifact: str, inputs: TestTensors, *, func_name=None) -> TestTensors:
run_dir = Path(artifact).parent
test_name = run_dir.name
detail_log = run_dir.joinpath("detail", "compiled_inference.detail.log")
commands_log = run_dir.joinpath("commands", "compiled_inference.commands.log")
func = self.backend.load(artifact, func_name=func_name)
script = func(inputs)
num_outputs = len(self.tensor_info_dict[test_name][0])
output_files = []
for i in range(num_outputs):
output_files.append(os.path.join(run_dir, f"output.{i}.bin"))
script += f" --output=@'{output_files[i]}'"
# remove existing output files if they already exist
# we use the existence of these files to check if the inference succeeded.
Path(output_files[i]).unlink(missing_ok=True)
# dump additional error messaging to the detail log.
script += f" 1> {detail_log} 2>&1"
with open(commands_log, "w") as file:
file.write(script)
os.system(script)
for file in output_files:
if not os.path.exists(file):
error_msg = f"failure executing command: \n{script}\n failed to produce output file {file}.\n"
if os.path.exists(detail_log):
error_msg += "Error detail:\n\n"
with open(detail_log,"r+") as file:
error_msg += file.read()
raise FileNotFoundError(error_msg)
return TestTensors.load_from(self.tensor_info_dict[test_name][0], self.tensor_info_dict[test_name][1], run_dir, "output")
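
To make the control flow concrete, here is a hedged end-to-end sketch of the
four stages above; `model_info` (an `OnnxModelInfo`), `inputs` (a `TestTensors`
whose input.N.bin files already sit in the run directory), the directory
layout, and the import path of this module are assumptions, and in practice
the harness wires these stages together when `--mode=cl-onnx-iree` is
selected:

from e2e_testing.backends import CLIREEBackend
from e2e_testing.test_configs.onnxconfig import CLOnnxTestConfig, REDUCE_TO_LINALG_PIPELINE

# model_info: an OnnxModelInfo instance; inputs: a TestTensors (construction omitted here).
config = CLOnnxTestConfig("/tmp/logs", CLIREEBackend(), tuple(REDUCE_TO_LINALG_PIPELINE))
# The run directory's basename must match model_info.name so run() can find the output signature.
save_dir = "/tmp/logs/" + model_info.name
mlir_file, func_name = config.import_model(model_info, save_to=save_dir)  # python -m torch_mlir.tools.import_onnx
linalg_ir = config.preprocess_model(mlir_file, save_to=save_dir)          # torch-mlir-opt, two pipelines
vmfb_path = config.compile(linalg_ir, save_to=save_dir)                   # iree-compile
outputs = config.run(vmfb_path, inputs, func_name=func_name)              # iree-run-module, then load output.*.bin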

37 changes: 2 additions & 35 deletions alt_e2eshark/onnx_tests/models/migraphx.py
@@ -13,38 +13,6 @@
# 3. setup dim params for other misc models
# 4. reupload cadence model 1

ALL_MODELS = [
"migraphx_agentmodel__AgentModel",
"migraphx_bert__bert-large-uncased",
"migraphx_bert__bertsquad-12",
"migraphx_cadene__dpn92i1",
"migraphx_cadene__inceptionv4i16",
"migraphx_cadene__resnext101_64x4di1",
"migraphx_cadene__resnext101_64x4di16",
"migraphx_huggingface-transformers__bert_mrpc8",
"migraphx_mlperf__bert_large_mlperf",
"migraphx_mlperf__resnet50_v1",
"migraphx_onnx-misc__taau_low_res_downsample_d2s_for_infer_time_fp16_opset11",
"migraphx_onnx-model-zoo__gpt2-10",
"migraphx_ORT__bert_base_cased_1",
"migraphx_ORT__bert_base_uncased_1",
"migraphx_ORT__bert_large_uncased_1",
"migraphx_ORT__distilgpt2_1",
"migraphx_ORT__onnx_models__bert_base_cased_1_fp16_gpu",
"migraphx_ORT__onnx_models__bert_large_uncased_1_fp16_gpu",
"migraphx_ORT__onnx_models__distilgpt2_1_fp16_gpu",
"migraphx_pytorch-examples__wlang_gru",
"migraphx_pytorch-examples__wlang_lstm",
"migraphx_sd__unet__model",
"migraphx_sdxl__unet__model",
"migraphx_torchvision__densenet121i32",
"migraphx_torchvision__inceptioni1",
"migraphx_torchvision__inceptioni32",
"migraphx_torchvision__resnet50i1",
"migraphx_torchvision__resnet50i64",
]


def dim_param_constructor(dim_param_dict):
class AzureWithDimParams(AzureDownloadableModel):
def __init__(self, *args, **kwargs):
@@ -70,8 +38,7 @@ def update_dim_param_dict(self):
ORT_model_names = [
"migraphx_ORT__bert_base_cased_1", # batch_size, seq_len
"migraphx_ORT__bert_base_uncased_1", # batch_size, seq_len
# the following test currently crashes for some reason (maybe opset version related?)
# "migraphx_ORT__bert_large_uncased_1", # batch_size, seq_len
"migraphx_ORT__bert_large_uncased_1", # batch_size, seq_len
"migraphx_ORT__distilgpt2_1", # batch_size, seq_len
"migraphx_ORT__onnx_models__bert_base_cased_1_fp16_gpu", # batch_size, seq_len
"migraphx_ORT__onnx_models__bert_large_uncased_1_fp16_gpu", # batch_size, seq_len
@@ -129,7 +96,7 @@ def update_dim_param_dict(self):
"migraphx_models__whisper-tiny-decoder" : {"batch_size" : 1, "decoder_sequence_length" : 64, "encoder_sequence_length / 2" : 32},
"migraphx_models__whisper-tiny-encoder" : {"batch_size" : 1, "feature_size" : 80, "encoder_sequence_length" : 64},
# this one crashes for some reason...
# "migraphx_sdxl__unet__model" : {"batch_size" : 1, "num_channels" : 4, "height" : 512, "width" : 512, "steps" : 2, "sequence_length" : 64}
"migraphx_sdxl__unet__model" : {"batch_size" : 1, "num_channels" : 4, "height" : 512, "width" : 512, "steps" : 2, "sequence_length" : 64}
}

for key, dim_param in misc_models.items():
