Skip to content

Commit

Permalink
[ONNX][TOSA] Adds ONNX to TOSA e2e tests (llvm#3358)
Browse files Browse the repository at this point in the history
- Refactors OnnxBackend to be generic and consume any Torch backend.

---------

Signed-off-by: Suraj Sudhir <[email protected]>
  • Loading branch information
sjarus authored May 17, 2024
1 parent 28193fd commit cba91a9
Show file tree
Hide file tree
Showing 6 changed files with 1,040 additions and 138 deletions.
12 changes: 8 additions & 4 deletions projects/pt1/e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import (
RefBackendLinalgOnTensorsBackend,
)
from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import (
LinalgOnTensorsOnnxBackend,
)
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import (
LinalgOnTensorsTosaBackend,
)
Expand All @@ -56,6 +53,7 @@
FX_IMPORTER_STABLEHLO_XFAIL_SET,
FX_IMPORTER_STABLEHLO_CRASHING_SET,
FX_IMPORTER_TOSA_XFAIL_SET,
ONNX_TOSA_XFAIL_SET,
)

# Import tests to register them in the global registry.
Expand All @@ -75,6 +73,7 @@ def _get_argparse():
"lazy_tensor_core",
"torchdynamo",
"onnx",
"onnx_tosa",
"fx_importer",
"fx_importer_stablehlo",
"fx_importer_tosa",
Expand All @@ -98,6 +97,7 @@ def _get_argparse():
"fx_importer": run the model through the fx importer frontend and execute the graph using Linalg-on-Tensors.
"fx_importer_stablehlo": run the model through the fx importer frontend and execute the graph using Stablehlo backend.
"fx_importer_tosa": run the model through the fx importer frontend and execute the graph using the TOSA backend.
"onnx_tosa": Import ONNX to Torch via the torch-onnx-to-torch path and execute the graph using the TOSA backend.
""",
)
parser.add_argument(
Expand Down Expand Up @@ -191,9 +191,13 @@ def main():
xfail_set = TORCHDYNAMO_XFAIL_SET
crashing_set = TORCHDYNAMO_CRASHING_SET
elif args.config == "onnx":
config = OnnxBackendTestConfig(LinalgOnTensorsOnnxBackend())
config = OnnxBackendTestConfig(RefBackendLinalgOnTensorsBackend())
xfail_set = ONNX_XFAIL_SET
crashing_set = ONNX_CRASHING_SET
elif args.config == "onnx_tosa":
config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa")
xfail_set = ONNX_TOSA_XFAIL_SET
crashing_set = set()

do_not_attempt = set(
args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
Expand Down
Loading

0 comments on commit cba91a9

Please sign in to comment.