diff --git a/compiler/plugins/input/Torch/PluginRegistration.cpp b/compiler/plugins/input/Torch/PluginRegistration.cpp index b9a497f12550..6f79686ad267 100644 --- a/compiler/plugins/input/Torch/PluginRegistration.cpp +++ b/compiler/plugins/input/Torch/PluginRegistration.cpp @@ -57,8 +57,19 @@ struct TorchSession OpPassManager &passManager, std::string_view typeMnemonic) override { if (typeMnemonic == "onnx") { // ONNX input is a pre-processing step to torch. - passManager.addNestedPass( - mlir::torch::onnx_c::createTorchOnnxToTorchPass()); + mlir::torch::Torch::TorchLoweringPipelineOptions torchOnnxPipelineOptions; + // The `aten.flatten.using_ints` and `aten.unflatten.int` are added to the + // list of backend legal ops so that they are not decomposed into the + // `aten.view` op during the run of `DecomposeComplexOps` pass. The issue + // with this is that the `aten.view` op eventually lowers to + // `tensor.reshape` op while there exists a direct torch->linalg lowering + // for both the flatten/unflatten ops which lowers to + // `tensor.collapse_shape/expand_shape` op, and this is a more preferred + // path for the downstream pipeline. + torchOnnxPipelineOptions.backendLegalOps = {"aten.flatten.using_ints", + "aten.unflatten.int"}; + mlir::torch::Torch::createTorchOnnxToTorchBackendPipeline( + passManager, torchOnnxPipelineOptions); } if (typeMnemonic == "torch" || typeMnemonic == "onnx") { diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json index 0bbd604384f8..a025431d7af4 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_cpu_llvm_sync.json @@ -101,8 +101,6 @@ "onnx/node/generated/test_compress_default_axis", "onnx/node/generated/test_compress_negative_axis", "onnx/node/generated/test_convtranspose_autopad_same", - "onnx/node/generated/test_convtranspose_kernel_shape", - "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_cumsum_1d", "onnx/node/generated/test_cumsum_1d_exclusive", "onnx/node/generated/test_cumsum_1d_reverse", @@ -127,7 +125,6 @@ "onnx/node/generated/test_dft_inverse_opset19", "onnx/node/generated/test_dft_opset19", "onnx/node/generated/test_edge_pad", - "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_gridsample_bicubic", "onnx/node/generated/test_gridsample_bicubic_align_corners_0_additional_1", "onnx/node/generated/test_gridsample_bicubic_align_corners_1_additional_1", @@ -427,6 +424,7 @@ "onnx/node/generated/test_constantofshape_float_ones", "onnx/node/generated/test_constantofshape_int_shape_zero", "onnx/node/generated/test_constantofshape_int_zeros", + "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_dropout_default_mask_ratio", "onnx/node/generated/test_gridsample_nearest", "onnx/node/generated/test_gridsample_nearest_align_corners_0_additional_1", @@ -447,6 +445,8 @@ "onnx/node/generated/test_reduce_min_empty_set", "onnx/node/generated/test_reduce_sum_empty_set_non_reduced_axis_zero", "onnx/node/generated/test_resize_downsample_scales_linear_align_corners", + "onnx/node/generated/test_scan_sum", + "onnx/node/generated/test_scan9_sum", "onnx/node/generated/test_shape_clip_start", "onnx/node/generated/test_shape_end_1", "onnx/node/generated/test_shape_start_1", diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json index 9901e017948c..79cc5a9a4add 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_rocm_rdna3.json @@ -105,8 +105,6 @@ "onnx/node/generated/test_compress_default_axis", "onnx/node/generated/test_compress_negative_axis", "onnx/node/generated/test_convtranspose_autopad_same", - "onnx/node/generated/test_convtranspose_kernel_shape", - "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_cumsum_1d", "onnx/node/generated/test_cumsum_1d_exclusive", "onnx/node/generated/test_cumsum_1d_reverse", @@ -131,7 +129,6 @@ "onnx/node/generated/test_dft_inverse_opset19", "onnx/node/generated/test_dft_opset19", "onnx/node/generated/test_edge_pad", - "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_gridsample_bicubic", "onnx/node/generated/test_gridsample_bicubic_align_corners_0_additional_1", "onnx/node/generated/test_gridsample_bicubic_align_corners_1_additional_1", @@ -442,6 +439,7 @@ "onnx/node/generated/test_constantofshape_float_ones", "onnx/node/generated/test_constantofshape_int_shape_zero", "onnx/node/generated/test_constantofshape_int_zeros", + "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_dropout_default_mask_ratio", "onnx/node/generated/test_eyelike_populate_off_main_diagonal", "onnx/node/generated/test_eyelike_with_dtype", @@ -490,6 +488,8 @@ "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random", "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded", "onnx/node/generated/test_resize_downsample_scales_linear_align_corners", + "onnx/node/generated/test_scan_sum", + "onnx/node/generated/test_scan9_sum", "onnx/node/generated/test_shape", "onnx/node/generated/test_shape_clip_end", "onnx/node/generated/test_shape_clip_start", @@ -501,6 +501,7 @@ "onnx/node/generated/test_shape_start_negative_1", "onnx/node/generated/test_size", "onnx/node/generated/test_size_example", + "onnx/node/generated/test_slice_default_axes", "onnx/node/generated/test_split_zero_size_splits_opset13", "onnx/node/generated/test_split_zero_size_splits_opset18", "onnx/node/generated/test_top_k", diff --git a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json index e6d9a4a4e201..8c31c26a421a 100644 --- a/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json +++ b/tests/external/iree-test-suites/onnx_ops/onnx_ops_gpu_vulkan.json @@ -136,8 +136,6 @@ "onnx/node/generated/test_compress_default_axis", "onnx/node/generated/test_compress_negative_axis", "onnx/node/generated/test_convtranspose_autopad_same", - "onnx/node/generated/test_convtranspose_kernel_shape", - "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_cumsum_1d", "onnx/node/generated/test_cumsum_1d_exclusive", "onnx/node/generated/test_cumsum_1d_reverse", @@ -162,7 +160,6 @@ "onnx/node/generated/test_dft_inverse_opset19", "onnx/node/generated/test_dft_opset19", "onnx/node/generated/test_edge_pad", - "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_gridsample", "onnx/node/generated/test_gridsample_aligncorners_true", "onnx/node/generated/test_gridsample_bicubic", @@ -528,6 +525,7 @@ "onnx/node/generated/test_constantofshape_int_zeros", "onnx/node/generated/test_convinteger_with_padding", "onnx/node/generated/test_convinteger_without_padding", + "onnx/node/generated/test_convtranspose_output_shape", "onnx/node/generated/test_dequantizelinear_int16", "onnx/node/generated/test_dequantizelinear_uint16", "onnx/node/generated/test_dropout_default_mask_ratio", @@ -536,6 +534,7 @@ "onnx/node/generated/test_einsum_batch_diagonal", "onnx/node/generated/test_einsum_batch_matmul", "onnx/node/generated/test_einsum_transpose", + "onnx/node/generated/test_einsum_sum", "onnx/node/generated/test_eyelike_with_dtype", "onnx/node/generated/test_isinf_float16", "onnx/node/generated/test_isnan_float16", @@ -596,6 +595,8 @@ "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_example_expanded", "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random", "onnx/node/generated/test_reduce_sum_square_default_axes_keepdims_random_expanded", + "onnx/node/generated/test_scan_sum", + "onnx/node/generated/test_scan9_sum", "onnx/node/generated/test_shape_clip_start", "onnx/node/generated/test_shape_end_1", "onnx/node/generated/test_shape_start_1", diff --git a/third_party/torch-mlir b/third_party/torch-mlir index 45bb17ebfe5e..140cad5659bb 160000 --- a/third_party/torch-mlir +++ b/third_party/torch-mlir @@ -1 +1 @@ -Subproject commit 45bb17ebfe5e9cdcfd2cfabf850d9dec7127c5ab +Subproject commit 140cad5659bb779bb1f5de1888566db5b5d21236