Skip to content

Commit

Permalink
Bump torch-mlir to 140cad5 and update TorchOnnxToTorch conversion pip…
Browse files Browse the repository at this point in the history
…eline (iree-org#18867)

This commit bumps torch-mlir to
llvm/torch-mlir@140cad5.

This commit also replaces the `TorchOnnxToTorchPass` used to convert
torch-onnx IR to torch IR with `TorchOnnxToTorchBackendPipeline` which
was introduced here:
llvm/torch-mlir@fa4794d.
This pipeline consists of passes like `TorchOnnxToTorch`,
`ScalarizeShapes`, `DecomposeComplexOps`, `ShapeRefinementPipeline`,
etc., to convert the torch-onnx IR to torch IR consistent with the other
lowering paths for torch backend IR.

With the changes made in this PR, some of the tests which were failing
earlier during compilation now passes compilation, some of them even
passes inference, while there are some tests which were passing earlier
are now failing. The list is as follows:
- For
tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json:
  a.) Earlier failing compilation, now passing inference:
    - onnx/node/generated/test_convtranspose_kernel_shape
    - onnx/node/generated/test_einsum_sum

b.) Earlier failing compilation, now passing compilation but failing
inference
    - onnx/node/generated/test_convtranspose_output_shape
   
  c.) Earlier passing but now failing inference
    - onnx/node/generated/test_scan_sum
    - onnx/node/generated/test_scan9_sum

- For
tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json:
  a.) Earlier failing compilation, now passing inference:
    - onnx/node/generated/test_convtranspose_kernel_shape
    - onnx/node/generated/test_einsum_sum

b.) Earlier failing compilation, now passing compilation but failing
inference
    - onnx/node/generated/test_convtranspose_output_shape
   
  c.) Earlier passing but now failing inference
    - onnx/node/generated/test_scan_sum
    - onnx/node/generated/test_scan9_sum
    - onnx/node/generated/test_slice_default_axes
 
- For tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json:
  a.) Earlier failing compilation, now passing inference:
    - onnx/node/generated/test_convtranspose_kernel_shape

b.) Earlier failing compilation, now passing compilation but failing
inference
    - onnx/node/generated/test_convtranspose_output_shape
    - onnx/node/generated/test_einsum_sum

  c.) Earlier passing but now failing inference
    - onnx/node/generated/test_scan_sum
    - onnx/node/generated/test_scan9_sum
 
This commit
llvm/torch-mlir@55ff110
is expected to fix the newly introduced failures which will be included
in IREE in the next Torch-MLIR bump.

---------

Signed-off-by: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored Oct 23, 2024
1 parent 81c8b25 commit e3f2d47
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
15 changes: 13 additions & 2 deletions compiler/plugins/input/Torch/PluginRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,19 @@ struct TorchSession
OpPassManager &passManager, std::string_view typeMnemonic) override {
if (typeMnemonic == "onnx") {
// ONNX input is a pre-processing step to torch.
passManager.addNestedPass<func::FuncOp>(
mlir::torch::onnx_c::createTorchOnnxToTorchPass());
mlir::torch::Torch::TorchLoweringPipelineOptions torchOnnxPipelineOptions;
// The `aten.flatten.using_ints` and `aten.unflatten.int` are added to the
// list of backend legal ops so that they are not decomposed into the
// `aten.view` op during the run of `DecomposeComplexOps` pass. The issue
// with this is that the `aten.view` op eventually lowers to
// `tensor.reshape` op while there exists a direct torch->linalg lowering
// for both the flatten/unflatten ops which lowers to
// `tensor.collapse_shape/expand_shape` op, and this is a more preferred
// path for the downstream pipeline.
torchOnnxPipelineOptions.backendLegalOps = {"aten.flatten.using_ints",
"aten.unflatten.int"};
mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline(
passManager, torchOnnxPipelineOptions);
}

if (typeMnemonic == "torch" || typeMnemonic == "onnx") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@
"onnx/node/generated/test_compress_default_axis",
"onnx/node/generated/test_compress_negative_axis",
"onnx/node/generated/test_convtranspose_autopad_same",
"onnx/node/generated/test_convtranspose_kernel_shape",
"onnx/node/generated/test_convtranspose_output_shape",
"onnx/node/generated/test_cumsum_1d",
"onnx/node/generated/test_cumsum_1d_exclusive",
"onnx/node/generated/test_cumsum_1d_reverse",
Expand All @@ -127,7 +125,6 @@
"onnx/node/generated/test_dft_inverse_opset19",
"onnx/node/generated/test_dft_opset19",
"onnx/node/generated/test_edge_pad",
"onnx/node/generated/test_einsum_sum",
"onnx/node/generated/test_gridsample_bicubic",
"onnx/node/generated/test_gridsample_bicubic_align_corners_0_additional_1",
"onnx/node/generated/test_gridsample_bicubic_align_corners_1_additional_1",
Expand Down Expand Up @@ -427,6 +424,7 @@
"onnx/node/generated/test_constantofshape_float_ones",
"onnx/node/generated/test_constantofshape_int_shape_zero",
"onnx/node/generated/test_constantofshape_int_zeros",
"onnx/node/generated/test_convtranspose_output_shape",
"onnx/node/generated/test_dropout_default_mask_ratio",
"onnx/node/generated/test_gridsample_nearest",
"onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1",
Expand All @@ -447,6 +445,8 @@
"onnx/node/generated/test_reduce_min_empty_set",
"onnx/node/generated/test_reduce_sum_empty_set_non_reduced_axis_zero",
"onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
"onnx/node/generated/test_scan_sum",
"onnx/node/generated/test_scan9_sum",
"onnx/node/generated/test_shape_clip_start",
"onnx/node/generated/test_shape_end_1",
"onnx/node/generated/test_shape_start_1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,6 @@
"onnx/node/generated/test_compress_default_axis",
"onnx/node/generated/test_compress_negative_axis",
"onnx/node/generated/test_convtranspose_autopad_same",
"onnx/node/generated/test_convtranspose_kernel_shape",
"onnx/node/generated/test_convtranspose_output_shape",
"onnx/node/generated/test_cumsum_1d",
"onnx/node/generated/test_cumsum_1d_exclusive",
"onnx/node/generated/test_cumsum_1d_reverse",
Expand All @@ -131,7 +129,6 @@
"onnx/node/generated/test_dft_inverse_opset19",
"onnx/node/generated/test_dft_opset19",
"onnx/node/generated/test_edge_pad",
"onnx/node/generated/test_einsum_sum",
"onnx/node/generated/test_gridsample_bicubic",
"onnx/node/generated/test_gridsample_bicubic_align_corners_0_additional_1",
"onnx/node/generated/test_gridsample_bicubic_align_corners_1_additional_1",
Expand Down Expand Up @@ -442,6 +439,7 @@
"onnx/node/generated/test_constantofshape_float_ones",
"onnx/node/generated/test_constantofshape_int_shape_zero",
"onnx/node/generated/test_constantofshape_int_zeros",
"onnx/node/generated/test_convtranspose_output_shape",
"onnx/node/generated/test_dropout_default_mask_ratio",
"onnx/node/generated/test_eyelike_populate_off_main_diagonal",
"onnx/node/generated/test_eyelike_with_dtype",
Expand Down Expand Up @@ -490,6 +488,8 @@
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded",
"onnx/node/generated/test_resize_downsample_scales_linear_align_corners",
"onnx/node/generated/test_scan_sum",
"onnx/node/generated/test_scan9_sum",
"onnx/node/generated/test_shape",
"onnx/node/generated/test_shape_clip_end",
"onnx/node/generated/test_shape_clip_start",
Expand All @@ -501,6 +501,7 @@
"onnx/node/generated/test_shape_start_negative_1",
"onnx/node/generated/test_size",
"onnx/node/generated/test_size_example",
"onnx/node/generated/test_slice_default_axes",
"onnx/node/generated/test_split_zero_size_splits_opset13",
"onnx/node/generated/test_split_zero_size_splits_opset18",
"onnx/node/generated/test_top_k",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@
"onnx/node/generated/test_compress_default_axis",
"onnx/node/generated/test_compress_negative_axis",
"onnx/node/generated/test_convtranspose_autopad_same",
"onnx/node/generated/test_convtranspose_kernel_shape",
"onnx/node/generated/test_convtranspose_output_shape",
"onnx/node/generated/test_cumsum_1d",
"onnx/node/generated/test_cumsum_1d_exclusive",
"onnx/node/generated/test_cumsum_1d_reverse",
Expand All @@ -162,7 +160,6 @@
"onnx/node/generated/test_dft_inverse_opset19",
"onnx/node/generated/test_dft_opset19",
"onnx/node/generated/test_edge_pad",
"onnx/node/generated/test_einsum_sum",
"onnx/node/generated/test_gridsample",
"onnx/node/generated/test_gridsample_aligncorners_true",
"onnx/node/generated/test_gridsample_bicubic",
Expand Down Expand Up @@ -528,6 +525,7 @@
"onnx/node/generated/test_constantofshape_int_zeros",
"onnx/node/generated/test_convinteger_with_padding",
"onnx/node/generated/test_convinteger_without_padding",
"onnx/node/generated/test_convtranspose_output_shape",
"onnx/node/generated/test_dequantizelinear_int16",
"onnx/node/generated/test_dequantizelinear_uint16",
"onnx/node/generated/test_dropout_default_mask_ratio",
Expand All @@ -536,6 +534,7 @@
"onnx/node/generated/test_einsum_batch_diagonal",
"onnx/node/generated/test_einsum_batch_matmul",
"onnx/node/generated/test_einsum_transpose",
"onnx/node/generated/test_einsum_sum",
"onnx/node/generated/test_eyelike_with_dtype",
"onnx/node/generated/test_isinf_float16",
"onnx/node/generated/test_isnan_float16",
Expand Down Expand Up @@ -596,6 +595,8 @@
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_example_expanded",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random",
"onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded",
"onnx/node/generated/test_scan_sum",
"onnx/node/generated/test_scan9_sum",
"onnx/node/generated/test_shape_clip_start",
"onnx/node/generated/test_shape_end_1",
"onnx/node/generated/test_shape_start_1",
Expand Down
2 changes: 1 addition & 1 deletion third_party/torch-mlir
Submodule torch-mlir updated 26 files
+4 −0 CMakeLists.txt
+1 −0 build_tools/python_deploy/build_windows_ci.sh
+1 −1 externals/llvm-project
+28 −0 include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+5 −0 include/torch-mlir/Dialect/Torch/Transforms/Passes.h
+59 −22 lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+145 −1 lib/Dialect/Torch/IR/TorchOps.cpp
+24 −42 lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+36 −0 lib/Dialect/Torch/Transforms/Passes.cpp
+517 −251 lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+0 −1 lib/Dialect/TorchConversion/Transforms/Passes.cpp
+74 −27 projects/pt1/e2e_testing/xfail_sets.py
+16 −34 projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+7 −2 projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+78 −2 projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py
+9 −17 projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py
+1 −1 pytorch-hash.txt
+1 −1 pytorch-requirements.txt
+72 −0 test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
+156 −18 test/Dialect/Torch/scalarize-shapes.mlir
+110 −0 test/Dialect/Torch/torch-nary-canonicalize.mlir
+67 −0 test/Dialect/Torch/torch-onnx-to-torch-backend-pipeline.mlir
+77 −77 test/python/fx_importer/sparsity/sparse_test.py
+10 −7 test/python/fx_importer/symbolic_shape_expr_test.py
+3 −1 test/python/fx_importer/v2.3/mutation_import.py
+1 −1 torchvision-requirements.txt

0 comments on commit e3f2d47

Please sign in to comment.