From 21d79c88c72d21d754e6d76ced6af1d9bb47070f Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 21:39:34 -0700 Subject: [PATCH 1/5] Add support for reducelogsum op --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 25 ++++++++++ setup.py | 2 + .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 50 +++++++++++++++++++ 3 files changed, 77 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6c86ecb92789..ba8f71a6d234 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -907,6 +907,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); + patterns.onOp("ReduceLogSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + + auto reducedSumBool = reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, true); + + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data); + return success(); + }); patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/setup.py b/setup.py index d5e3d055c68b..5193a9b4fab1 100644 --- a/setup.py +++ b/setup.py @@ -120,6 +120,8 @@ def cmake_build(self, cmake_build_dir): f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", f"-DTORCH_MLIR_ENABLE_LTC={'ON' if TORCH_MLIR_ENABLE_LTC else 'OFF'}", f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}", + f"-DCMAKE_C_COMPILER=clang", + f"-DCMAKE_CXX_COMPILER=clang++", ] if LLVM_INSTALL_DIR: cmake_config_args += [ diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 43849fbbd06e..6a8757134da1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -843,6 +843,56 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // ----- +// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example +func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example +func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example +func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + // CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> From 5db13045155cc3d342f82fb942a08a5dc9e0af20 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 21:46:56 -0700 Subject: [PATCH 2/5] Revert setup.py changes Revert "Add support for reducelogsum op" This reverts commit 21d79c88c72d21d754e6d76ced6af1d9bb47070f. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 25 ---------- setup.py | 2 - .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 50 ------------------- 3 files changed, 77 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ba8f71a6d234..6c86ecb92789 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -907,31 +907,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); - patterns.onOp("ReduceLogSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); - - auto reducedSumBool = reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); - - if (failed(reducedSumBool)) - return rewriter.notifyMatchFailure( - binder.op, - "Failed to perform sum operation on square of operand"); - - rewriter.replaceOpWithNewOp( - binder.op, resultType, data); - return success(); - }); patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/setup.py b/setup.py index 5193a9b4fab1..d5e3d055c68b 100644 --- a/setup.py +++ b/setup.py @@ -120,8 +120,6 @@ def cmake_build(self, cmake_build_dir): f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", f"-DTORCH_MLIR_ENABLE_LTC={'ON' if TORCH_MLIR_ENABLE_LTC else 'OFF'}", f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}", - f"-DCMAKE_C_COMPILER=clang", - f"-DCMAKE_CXX_COMPILER=clang++", ] if LLVM_INSTALL_DIR: cmake_config_args += [ diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 6a8757134da1..43849fbbd06e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -843,56 +843,6 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // ----- -// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example -func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> - return %0 : !torch.vtensor<[1,1,1],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example -func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> - return %0 : !torch.vtensor<[3,2,1],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example -func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> - // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> - return %0 : !torch.vtensor<[3,2],f32> -} - -// ----- - // CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> From fecb34f70ede6045466bebeec8dd6ec2eec8bebf Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 21:51:12 -0700 Subject: [PATCH 3/5] Add support for reducelogsum op --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 25 ++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 50 +++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6c86ecb92789..ba8f71a6d234 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -907,6 +907,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); + patterns.onOp("ReduceLogSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + + auto reducedSumBool = reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, true); + + if (failed(reducedSumBool)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data); + return success(); + }); patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 43849fbbd06e..6a8757134da1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -843,6 +843,56 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // ----- +// CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example +func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_keep_dims_example +func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_do_not_keepdims_example +func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2,2],f32>, %arg1:!torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceLogSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + // CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> From b3c0859236025d426350a2c289cd104cfefe1367 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 25 Apr 2024 09:56:21 -0700 Subject: [PATCH 4/5] Clang formatting --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index f50bae62fb53..7d34f14b4e70 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -974,9 +974,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "noop_with_empty_axes", 0)) return failure(); - auto reducedSumBool = reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); + auto reducedSumBool = + reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, true); if (failed(reducedSumBool)) return rewriter.notifyMatchFailure( From 570e8bf30f5fad14d8fb1c3494d601b469f0d486 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Thu, 25 Apr 2024 11:48:46 -0700 Subject: [PATCH 5/5] Fix lit test issues --- test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 1262fd7fb983..e3519a89a73e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -944,9 +944,10 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // CHECK-LABEL: func.func @test_reduce_log_sum_default_axes_keepdims_example func.func @test_reduce_log_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f32> -> !torch.vtensor<[1,1,1],f32> // CHECK: return %[[LOG]] : !torch.vtensor<[1,1,1],f32> @@ -963,8 +964,8 @@ func.func @test_reduce_log_sum_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list - // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f32> -> !torch.vtensor<[3,2,1],f32> // CHECK: return %[[LOG]] : !torch.vtensor<[3,2,1],f32> @@ -981,8 +982,8 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2 // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list - // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %arg0, %[[DIMS]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f32> -> !torch.vtensor<[3,2],f32> // CHECK: return %[[LOG]] : !torch.vtensor<[3,2],f32>