From b4c5b9c964abe970eb677c4c62908ba76235a932 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Sun, 21 Apr 2024 21:04:20 -0700 Subject: [PATCH 01/17] Rebase to commit adding reducelogsumexp tests --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 51 ++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 146 +++++++++++++++++- 2 files changed, 194 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d218d23559a3..442bf46e9ff0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1002,6 +1002,57 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*storeValue=*/data, keepDims, noop_with_empty_axes, false); }); + patterns.onOp( + "ReduceLogSumExp", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + // out = Log(reduce_sum(exp(data))) + Value castDType = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 11)); + Value noneVal = rewriter.create(binder.getLoc()); + Value constFalse = + rewriter.create(binder.getLoc(), false); + auto size = data.getType() + .dyn_cast() + .getOptionalSizes(); + auto f32ResultType = rewriter.getType( + size, rewriter.getF32Type()); + Value dataCast = rewriter.create( + binder.getLoc(), f32ResultType, data, castDType, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + + Value dataExp = rewriter.create(binder.getLoc(), + resultType, dataCast); + auto reducedSum = reducedSumImpl(binder, rewriter, dataExp, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, true); + + if (failed(reducedSum)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); + + Value finalResult = rewriter.create( + binder.op, resultType, reducedSum); + + Value resultDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), resultType.getDtype()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, finalResult, resultDtype, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + return success(); + }); patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index a31350ee17bd..d1dd20d32ef1 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -911,6 +911,140 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2 // ----- +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example +func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded +func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE_1:.+]] = torch.constant.bool false + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example +func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example +func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[NONE_0:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE_1:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> + %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> + return %0 : !torch.vtensor<[3,2,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1341,9 +1475,15 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, // ----- -// CHECK-LABEL: func.func @test_sinh +// CHECK-LABEL: func.func @test_sinh_example func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { - // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[C1:.+]] = torch.constant.int 1 + // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[C2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } @@ -2005,4 +2145,4 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK: return %[[LOSS]], %[[PROB]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> %0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>) return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> -} +} \ No newline at end of file From f894e1cb908f5d27c9e6dbfd0dee4a2c6bb2272f Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Tue, 23 Apr 2024 03:10:36 +0000 Subject: [PATCH 02/17] Fixing reviewed issues --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 442bf46e9ff0..e7c2490bc5ae 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1017,25 +1017,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // out = Log(reduce_sum(exp(data))) Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 11)); + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 7)); Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); auto size = data.getType() .dyn_cast() .getOptionalSizes(); - auto f32ResultType = rewriter.getType( + auto f64ResultType = rewriter.getType( size, rewriter.getF32Type()); Value dataCast = rewriter.create( - binder.getLoc(), f32ResultType, data, castDType, + binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - Value dataExp = rewriter.create(binder.getLoc(), - resultType, dataCast); - auto reducedSum = reducedSumImpl(binder, rewriter, dataExp, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); + Value dataExp = rewriter.create( + binder.getLoc(), f64ResultType, dataCast); + auto reducedSum = reducedSumImpl( + binder, rewriter, dataExp, f64ResultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSum)) return rewriter.notifyMatchFailure( @@ -1043,7 +1043,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Failed to perform sum operation on square of operand"); Value finalResult = rewriter.create( - binder.op, resultType, reducedSum); + binder.getLoc(), f64ResultType, reducedSum); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); From 904593318ccaf8cefc7cbfe00ac3db8105949e7f Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Tue, 23 Apr 2024 19:35:13 +0000 Subject: [PATCH 03/17] Fix log op issue --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index e7c2490bc5ae..85cca7e7b82b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1015,7 +1015,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( 0)) return failure(); - // out = Log(reduce_sum(exp(data))) + // out = Log(reducesum(exp(data))) Value castDType = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 7)); Value noneVal = rewriter.create(binder.getLoc()); @@ -1033,17 +1033,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value dataExp = rewriter.create( binder.getLoc(), f64ResultType, dataCast); - auto reducedSum = reducedSumImpl( + auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ResultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, true); + /*storeValue=*/reducedSumExp, keepDims, noop_with_empty_axes, true); - if (failed(reducedSum)) + if (failed(reducedSumBool)) return rewriter.notifyMatchFailure( binder.op, "Failed to perform sum operation on square of operand"); - Value finalResult = rewriter.create( - binder.getLoc(), f64ResultType, reducedSum); + Value finalResult = rewriter.create( + binder.getLoc(), f64ResultType, reducedSumExp); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); From 5e2ebd4b6622b7730e80a8784607827d7a101934 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 02:47:53 +0000 Subject: [PATCH 04/17] Fixing undeclared variable issue --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 85cca7e7b82b..4d866f69e554 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1035,7 +1035,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), f64ResultType, dataCast); auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ResultType, - /*storeValue=*/reducedSumExp, keepDims, noop_with_empty_axes, true); + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSumBool)) return rewriter.notifyMatchFailure( @@ -1043,7 +1043,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Failed to perform sum operation on square of operand"); Value finalResult = rewriter.create( - binder.getLoc(), f64ResultType, reducedSumExp); + binder.getLoc(), f64ResultType, data); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); From ae8171b5af8b2e28b89231c7e1e6c4d35c8e487a Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 13:36:03 -0700 Subject: [PATCH 05/17] Change F32 to F64 --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 4d866f69e554..45d9ab6f40d8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1017,7 +1017,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // out = Log(reducesum(exp(data))) Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 7)); + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); @@ -1025,7 +1025,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .dyn_cast() .getOptionalSizes(); auto f64ResultType = rewriter.getType( - size, rewriter.getF32Type()); + size, rewriter.getF64Type()); Value dataCast = rewriter.create( binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, From c99196e02c3630fee25a3621ea02c3f115bcfe8b Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Sun, 21 Apr 2024 21:04:20 -0700 Subject: [PATCH 06/17] adding reducelogsumexp tests --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 45d9ab6f40d8..ef1de466f51f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1054,6 +1054,57 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( + "ReduceLogSumExp", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + // out = Log(reduce_sum(exp(data))) + Value castDType = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 11)); + Value noneVal = rewriter.create(binder.getLoc()); + Value constFalse = + rewriter.create(binder.getLoc(), false); + auto size = data.getType() + .dyn_cast() + .getOptionalSizes(); + auto f32ResultType = rewriter.getType( + size, rewriter.getF32Type()); + Value dataCast = rewriter.create( + binder.getLoc(), f32ResultType, data, castDType, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + + Value dataExp = rewriter.create(binder.getLoc(), + resultType, dataCast); + auto reducedSum = reducedSumImpl(binder, rewriter, dataExp, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, true); + + if (failed(reducedSum)) + return rewriter.notifyMatchFailure( + binder.op, + "Failed to perform sum operation on square of operand"); + + Value finalResult = rewriter.create( + binder.op, resultType, reducedSum); + + Value resultDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), resultType.getDtype()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, finalResult, resultDtype, + /*non_blocking=*/constFalse, /*copy=*/constFalse, + /*memory_format=*/noneVal); + return success(); + }); + patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; From 6aa9626ec867339140ee3ef36aa83cc9367ee080 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Tue, 23 Apr 2024 03:10:36 +0000 Subject: [PATCH 07/17] Fixing reviewed issues --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ef1de466f51f..5d28480ea082 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1068,25 +1068,25 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // out = Log(reduce_sum(exp(data))) Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 11)); + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 7)); Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); auto size = data.getType() .dyn_cast() .getOptionalSizes(); - auto f32ResultType = rewriter.getType( + auto f64ResultType = rewriter.getType( size, rewriter.getF32Type()); Value dataCast = rewriter.create( - binder.getLoc(), f32ResultType, data, castDType, + binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - Value dataExp = rewriter.create(binder.getLoc(), - resultType, dataCast); - auto reducedSum = reducedSumImpl(binder, rewriter, dataExp, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, true); + Value dataExp = rewriter.create( + binder.getLoc(), f64ResultType, dataCast); + auto reducedSum = reducedSumImpl( + binder, rewriter, dataExp, f64ResultType, + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSum)) return rewriter.notifyMatchFailure( @@ -1094,7 +1094,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Failed to perform sum operation on square of operand"); Value finalResult = rewriter.create( - binder.op, resultType, reducedSum); + binder.getLoc(), f64ResultType, reducedSum); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); From 4ce202b0d12ac0322b3b93e6e35d0e678f052442 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Tue, 23 Apr 2024 19:35:13 +0000 Subject: [PATCH 08/17] Fix log op issue --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 5d28480ea082..1c788c8f10b7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1066,7 +1066,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( 0)) return failure(); - // out = Log(reduce_sum(exp(data))) + // out = Log(reducesum(exp(data))) Value castDType = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 7)); Value noneVal = rewriter.create(binder.getLoc()); @@ -1084,17 +1084,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value dataExp = rewriter.create( binder.getLoc(), f64ResultType, dataCast); - auto reducedSum = reducedSumImpl( + auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ResultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, true); + /*storeValue=*/reducedSumExp, keepDims, noop_with_empty_axes, true); - if (failed(reducedSum)) + if (failed(reducedSumBool)) return rewriter.notifyMatchFailure( binder.op, "Failed to perform sum operation on square of operand"); - Value finalResult = rewriter.create( - binder.getLoc(), f64ResultType, reducedSum); + Value finalResult = rewriter.create( + binder.getLoc(), f64ResultType, reducedSumExp); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); From ce6e627811b1a46424266faebb452136ae6224be Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 02:47:53 +0000 Subject: [PATCH 09/17] Fixing undeclared variable issue --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 1c788c8f10b7..125e6307d39a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1086,7 +1086,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), f64ResultType, dataCast); auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ResultType, - /*storeValue=*/reducedSumExp, keepDims, noop_with_empty_axes, true); + /*storeValue=*/data, keepDims, noop_with_empty_axes, true); if (failed(reducedSumBool)) return rewriter.notifyMatchFailure( @@ -1094,7 +1094,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "Failed to perform sum operation on square of operand"); Value finalResult = rewriter.create( - binder.getLoc(), f64ResultType, reducedSumExp); + binder.getLoc(), f64ResultType, data); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); From a2f78b8889627cb7394267693a05fb6da47465f4 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 24 Apr 2024 13:36:03 -0700 Subject: [PATCH 10/17] Change F32 to F64 --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 125e6307d39a..6b0fea8329cb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1068,7 +1068,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // out = Log(reducesum(exp(data))) Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float32Type*/ 7)); + binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); @@ -1076,7 +1076,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .dyn_cast() .getOptionalSizes(); auto f64ResultType = rewriter.getType( - size, rewriter.getF32Type()); + size, rewriter.getF64Type()); Value dataCast = rewriter.create( binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, From d12ea1f439faf8f238d0ecd7be524fcc52b9a944 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Mon, 29 Apr 2024 12:27:49 -0700 Subject: [PATCH 11/17] Remove duplicate code --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 51 ------------------- 1 file changed, 51 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6b0fea8329cb..45d9ab6f40d8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1054,57 +1054,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); patterns.onOp( - "ReduceLogSumExp", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); - - // out = Log(reducesum(exp(data))) - Value castDType = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); - Value noneVal = rewriter.create(binder.getLoc()); - Value constFalse = - rewriter.create(binder.getLoc(), false); - auto size = data.getType() - .dyn_cast() - .getOptionalSizes(); - auto f64ResultType = rewriter.getType( - size, rewriter.getF64Type()); - Value dataCast = rewriter.create( - binder.getLoc(), f64ResultType, data, castDType, - /*non_blocking=*/constFalse, /*copy=*/constFalse, - /*memory_format=*/noneVal); - - Value dataExp = rewriter.create( - binder.getLoc(), f64ResultType, dataCast); - auto reducedSumBool = reducedSumImpl( - binder, rewriter, dataExp, f64ResultType, - /*storeValue=*/data, keepDims, noop_with_empty_axes, true); - - if (failed(reducedSumBool)) - return rewriter.notifyMatchFailure( - binder.op, - "Failed to perform sum operation on square of operand"); - - Value finalResult = rewriter.create( - binder.getLoc(), f64ResultType, data); - - Value resultDtype = Torch::getDtypeIntValueForType( - rewriter, binder.getLoc(), resultType.getDtype()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, finalResult, resultDtype, - /*non_blocking=*/constFalse, /*copy=*/constFalse, - /*memory_format=*/noneVal); - return success(); - }); - patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; From 3cbf9802042374e6ec7ebebeb69403cdb3c41f9f Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 8 May 2024 11:15:15 -0700 Subject: [PATCH 12/17] Fixing op issues and lit tests --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 117 ++++++++++-------- 1 file changed, 64 insertions(+), 53 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 45d9ab6f40d8..0b830c86b903 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -966,42 +966,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, data); return success(); }); - patterns.onOp("ReduceSum", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); - - return reducedSumImpl(binder, rewriter, data, resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); - }); - patterns.onOp("ReduceSumSquare", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - int64_t keepDims, noop_with_empty_axes; - if (binder.tensorOperandAtIndex(data, 0) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); - - Value dataSquare = rewriter.create( - binder.getLoc(), data.getType(), data, data); - - return reducedSumImpl(binder, rewriter, dataSquare, - resultType, - /*storeValue=*/data, keepDims, - noop_with_empty_axes, false); - }); patterns.onOp( "ReduceLogSumExp", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1030,21 +994,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - Value dataExp = rewriter.create( binder.getLoc(), f64ResultType, dataCast); + auto f64ReduceType = rewriter.getType( + resultType.getOptionalSizes(), rewriter.getF64Type()); auto reducedSumBool = reducedSumImpl( - binder, rewriter, dataExp, f64ResultType, + binder, rewriter, dataExp, f64ReduceType, /*storeValue=*/data, keepDims, noop_with_empty_axes, true); - if (failed(reducedSumBool)) return rewriter.notifyMatchFailure( binder.op, "Failed to perform sum operation on square of operand"); - Value finalResult = rewriter.create( binder.getLoc(), f64ResultType, data); - Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); rewriter.replaceOpWithNewOp( @@ -1053,6 +1015,42 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*memory_format=*/noneVal); return success(); }); + patterns.onOp("ReduceSum", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + + return reducedSumImpl(binder, rewriter, data, resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, false); + }); + patterns.onOp("ReduceSumSquare", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + int64_t keepDims, noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + + Value dataSquare = rewriter.create( + binder.getLoc(), data.getType(), data, data); + + return reducedSumImpl(binder, rewriter, dataSquare, + resultType, + /*storeValue=*/data, keepDims, + noop_with_empty_axes, false); + }); patterns.onOp( "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1490,18 +1488,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); - patterns.onOp("Sinh", 9, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp( + "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand); - return success(); - }); + // 1/2 * (exp(x) – exp(-x)) + Value x = rewriter.create(binder.getLoc(), resultType, + operand); + Value neg = rewriter.create(binder.getLoc(), + resultType, operand); + Value y = + rewriter.create(binder.getLoc(), resultType, neg); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value z = rewriter.create( + binder.getLoc(), resultType, x, y, cstOne); + Value cstTwo = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, z, cstTwo); + return success(); + }); // split with fixed-size parts // Arguments: @@ -2688,4 +2699,4 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, {loss, logProb}); return success(); }); -} +} \ No newline at end of file From a111b00ea6b12b8dccf1b9bc4742ef5b9a2ab4c1 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 8 May 2024 14:22:37 -0700 Subject: [PATCH 13/17] removing unrelated changes --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 35 ++++++------------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 10 ++---- 2 files changed, 13 insertions(+), 32 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 0b830c86b903..7b5565d64798 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1488,31 +1488,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); - patterns.onOp( - "Sinh", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value operand; - if (binder.tensorOperand(operand) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp("Sinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); - // 1/2 * (exp(x) – exp(-x)) - Value x = rewriter.create(binder.getLoc(), resultType, - operand); - Value neg = rewriter.create(binder.getLoc(), - resultType, operand); - Value y = - rewriter.create(binder.getLoc(), resultType, neg); - Value cstOne = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(1)); - Value z = rewriter.create( - binder.getLoc(), resultType, x, y, cstOne); - Value cstTwo = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(2)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, z, cstTwo); - return success(); - }); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); // split with fixed-size parts // Arguments: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index d1dd20d32ef1..ad77d47bc1a9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1475,15 +1475,9 @@ func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, // ----- -// CHECK-LABEL: func.func @test_sinh_example +// CHECK-LABEL: func.func @test_sinh func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { - // CHECK: %[[X:.+]] = torch.aten.exp %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[NEG:.+]] = torch.aten.neg %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[Y:.+]] = torch.aten.exp %[[NEG]] : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> - // CHECK: %[[C1:.+]] = torch.constant.int 1 - // CHECK: %[[SUB:.+]] = torch.aten.sub.Tensor %[[X]], %[[Y]], %[[C1]] : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: %[[C2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.div.Scalar %[[SUB]], %[[C2]] : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } From 3b22809dbcd834522fa3923fce46067ee2f041a9 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 8 May 2024 14:28:45 -0700 Subject: [PATCH 14/17] add newline at end of file --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 7b5565d64798..02d186712e09 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2686,4 +2686,4 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, {loss, logProb}); return success(); }); -} \ No newline at end of file +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index ad77d47bc1a9..501cce655308 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2139,4 +2139,4 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1: // CHECK: return %[[LOSS]], %[[PROB]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> %0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>) return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32> -} \ No newline at end of file +} From 94850f91559611630b3590f9ea091f83cc2c17d5 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Wed, 8 May 2024 15:45:35 -0700 Subject: [PATCH 15/17] remove redundant lit test checks --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 37 ------------------- 1 file changed, 37 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 501cce655308..cab0f158fff4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -913,13 +913,6 @@ func.func @test_reduce_log_sum_do_not_keepdims_example(%arg0:!torch.vtensor<[3,2 // CHECK-LABEL: func.func @test_reduce_log_sum_exp_default_axes_keepdims_example func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT7:.+]] = torch.constant.int 7 - // CHECK: %[[NONE_0:.+]] = torch.constant.none - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[INT7:.+]] = torch.constant.int 7 // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[FALSE:.+]] = torch.constant.bool false @@ -942,16 +935,6 @@ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.v // CHECK-LABEL: func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT7:.+]] = torch.constant.int 7 - // CHECK: %[[NONE_0:.+]] = torch.constant.none - // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false - // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list // CHECK: %[[INT7:.+]] = torch.constant.int 7 // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[FALSE_0:.+]] = torch.constant.bool false @@ -977,16 +960,6 @@ func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torc // CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_example func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT7:.+]] = torch.constant.int 7 - // CHECK: %[[NONE_0:.+]] = torch.constant.none - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list // CHECK: %[[INT7:.+]] = torch.constant.int 7 // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[FALSE:.+]] = torch.constant.bool false @@ -1012,16 +985,6 @@ func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2, // CHECK-LABEL: func.func @test_reduce_log_sum_exp_keep_dims_int_input_example func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vtensor<[3,2,2],si64>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT7:.+]] = torch.constant.int 7 - // CHECK: %[[NONE_0:.+]] = torch.constant.none - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[CAST:.+]] = torch.aten.to.dtype %arg0, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[EXP:.+]] = torch.aten.exp %[[CAST]] : !torch.vtensor<[3,2,2],f64> -> !torch.vtensor<[3,2,2],f64> - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: %[[SELECT:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[DIMS:.+]] = torch.prim.ListConstruct %[[ITEM]] : (!torch.int) -> !torch.list // CHECK: %[[INT7:.+]] = torch.constant.int 7 // CHECK: %[[NONE_0:.+]] = torch.constant.none // CHECK: %[[FALSE:.+]] = torch.constant.bool false From 35b7306f8284b418118b5998359c38d48a9c5ff7 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Fri, 10 May 2024 11:37:33 -0700 Subject: [PATCH 16/17] Add debug statements --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 02d186712e09..d614c36b314f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -982,6 +982,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // out = Log(reducesum(exp(data))) Value castDType = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); + llvm::errs() << castDType << "castDType\n"; Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); @@ -990,14 +991,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .getOptionalSizes(); auto f64ResultType = rewriter.getType( size, rewriter.getF64Type()); + llvm::errs() << f64ResultType << "f64ResultType\n"; Value dataCast = rewriter.create( binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); + llvm::errs() << dataCast << "AtenToDtypeOp\n"; Value dataExp = rewriter.create( binder.getLoc(), f64ResultType, dataCast); + llvm::errs() << dataExp << "AtenExpOp\n"; auto f64ReduceType = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF64Type()); + llvm::errs() << f64ReduceType << "f64ReduceType\n"; auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ReduceType, /*storeValue=*/data, keepDims, noop_with_empty_axes, true); @@ -1005,14 +1010,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "Failed to perform sum operation on square of operand"); + llvm::errs() << data << "reducedSumImpl\n"; Value finalResult = rewriter.create( binder.getLoc(), f64ResultType, data); + llvm::errs() << finalResult << "AtenLogOp\n"; Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); rewriter.replaceOpWithNewOp( binder.op, resultType, finalResult, resultDtype, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); + llvm::errs() << finalResult << "finalResult\n"; return success(); }); patterns.onOp("ReduceSum", 1, From 0c8db888e809837b0a22d8647b708a83175e7315 Mon Sep 17 00:00:00 2001 From: archana-ramalingam Date: Mon, 13 May 2024 15:04:43 -0700 Subject: [PATCH 17/17] Change result type for AtenLogOp --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 10 +--------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 16 ++++++++-------- 2 files changed, 9 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 33c3a8afe3d6..30ab1bfbd8b7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -982,7 +982,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( // out = Log(reducesum(exp(data))) Value castDType = rewriter.create( binder.getLoc(), rewriter.getI64IntegerAttr(/*Float64Type*/ 7)); - llvm::errs() << castDType << "castDType\n"; Value noneVal = rewriter.create(binder.getLoc()); Value constFalse = rewriter.create(binder.getLoc(), false); @@ -991,18 +990,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( .getOptionalSizes(); auto f64ResultType = rewriter.getType( size, rewriter.getF64Type()); - llvm::errs() << f64ResultType << "f64ResultType\n"; Value dataCast = rewriter.create( binder.getLoc(), f64ResultType, data, castDType, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - llvm::errs() << dataCast << "AtenToDtypeOp\n"; Value dataExp = rewriter.create( binder.getLoc(), f64ResultType, dataCast); - llvm::errs() << dataExp << "AtenExpOp\n"; auto f64ReduceType = rewriter.getType( resultType.getOptionalSizes(), rewriter.getF64Type()); - llvm::errs() << f64ReduceType << "f64ReduceType\n"; auto reducedSumBool = reducedSumImpl( binder, rewriter, dataExp, f64ReduceType, /*storeValue=*/data, keepDims, noop_with_empty_axes, true); @@ -1010,17 +1005,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return rewriter.notifyMatchFailure( binder.op, "Failed to perform sum operation on square of operand"); - llvm::errs() << data << "reducedSumImpl\n"; Value finalResult = rewriter.create( - binder.getLoc(), f64ResultType, data); - llvm::errs() << finalResult << "AtenLogOp\n"; + binder.getLoc(), f64ReduceType, data); Value resultDtype = Torch::getDtypeIntValueForType( rewriter, binder.getLoc(), resultType.getDtype()); rewriter.replaceOpWithNewOp( binder.op, resultType, finalResult, resultDtype, /*non_blocking=*/constFalse, /*copy=*/constFalse, /*memory_format=*/noneVal); - llvm::errs() << finalResult << "finalResult\n"; return success(); }); patterns.onOp("ReduceSum", 1, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 1572d1c45ba4..e52ccd6daf44 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -923,9 +923,9 @@ func.func @test_reduce_log_sum_exp_default_axes_keepdims_example(%arg0: !torch.v // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f64> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[1,1,1],f64> -> !torch.vtensor<[1,1,1],f64> // CHECK: %[[INT6:.+]] = torch.constant.int 6 - // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[1,1,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[1,1,1],f32> %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> return %0 : !torch.vtensor<[1,1,1],f32> @@ -948,9 +948,9 @@ func.func @test_reduce_log_sum_exp_do_not_keepdims_example_expanded(%arg0: !torc // CHECK: %[[FALSE_1:.+]] = torch.constant.bool false // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[FALSE_1]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f64> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2],f64> -> !torch.vtensor<[3,2],f64> // CHECK: %[[INT6:.+]] = torch.constant.int 6 - // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE_0]], %[[FALSE_0]], %[[NONE_0]] : !torch.vtensor<[3,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2],f32> %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> return %0 : !torch.vtensor<[3,2],f32> @@ -973,9 +973,9 @@ func.func @test_reduce_log_sum_exp_keep_dims_example(%arg0: !torch.vtensor<[3,2, // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64> // CHECK: %[[INT6:.+]] = torch.constant.int 6 - // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> return %0 : !torch.vtensor<[3,2,1],f32> @@ -998,9 +998,9 @@ func.func @test_reduce_log_sum_exp_keep_dims_int_input_example(%arg0: !torch.vte // CHECK: %[[TRUE:.+]] = torch.constant.bool true // CHECK: %[[NONE_1:.+]] = torch.constant.none // CHECK: %[[SUM:.+]] = torch.aten.sum.dim_IntList %[[EXP]], %[[DIMS]], %[[TRUE]], %[[NONE_1]] : !torch.vtensor<[3,2,2],f64>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f64> - // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,2],f64> + // CHECK: %[[LOG:.+]] = torch.aten.log %[[SUM]] : !torch.vtensor<[3,2,1],f64> -> !torch.vtensor<[3,2,1],f64> // CHECK: %[[INT6:.+]] = torch.constant.int 6 - // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,2],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> + // CHECK: %[[CASTLIKE:.+]] = torch.aten.to.dtype %[[LOG]], %[[INT6]], %[[FALSE]], %[[FALSE]], %[[NONE_0]] : !torch.vtensor<[3,2,1],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,2,1],f32> // CHECK: return %[[CASTLIKE]] : !torch.vtensor<[3,2,1],f32> %0 = torch.operator "onnx.ReduceLogSumExp"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2,1],f32> return %0 : !torch.vtensor<[3,2,1],f32>