From 1fe22aafe54a174dd424c635f238b489f8d82f08 Mon Sep 17 00:00:00 2001 From: jinchen62 Date: Mon, 29 Apr 2024 17:22:46 -0700 Subject: [PATCH] Add attributes support for onnx cumsum op --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 65 ++++++++------- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 79 +++++++++++++++++++ .../unsupported_fb_opt_ops.mlir | 9 --- 3 files changed, 117 insertions(+), 36 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index f7fac538068a..889a5fe88704 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1463,25 +1463,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Location loc = binder.getLoc(); Torch::ValueTensorType resultType; - Value operand; - Value axisTensor; + Value operand, axisTensor; + int64_t exclusive, reverse; if (binder.tensorOperands(operand, axisTensor) || + binder.s64IntegerAttr(exclusive, "exclusive", 0) || + binder.s64IntegerAttr(reverse, "reverse", 0) || binder.tensorResultType(resultType)) return failure(); - int64_t exclusive; - int64_t reverse; - // if bind succeeds and either is set, fail because not implemented - if (!binder.s64IntegerAttr(exclusive, "exclusive", 0)) - if (exclusive != 0) - return rewriter.notifyMatchFailure( - binder.op, "unsupported onnx.CumSum conversion: exclusive"); - if (!binder.s64IntegerAttr(reverse, "reverse", 0)) - if (reverse != 0) - return rewriter.notifyMatchFailure( - binder.op, "unsupported onnx.CumSum conversion: reverse"); + Torch::BaseTensorType resultTensorType = + cast(resultType); + if (!resultTensorType.hasDtype()) { + return rewriter.notifyMatchFailure( + binder.op, "expected result type to have a dtype"); + } // deal with neg axis: if (axis < 0) axis += rank int64_t rank = @@ -1489,30 +1485,45 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value rankVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank)); - Value zero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); Value axisScalar = rewriter.create( binder.getLoc(), rewriter.getType(), axisTensor); Value isNegative = rewriter.create( - binder.getLoc(), axisScalar, zero); + binder.getLoc(), axisScalar, cstZero); isNegative = rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( binder.getLoc(), isNegative, rankVal); - Value dim = rewriter.create( + Value axis = rewriter.create( binder.getLoc(), axisScalar, finalOffset); + Value none = rewriter.create(binder.getLoc()); - Torch::BaseTensorType resultTensorType = - cast(resultType); - if (!resultTensorType.hasDtype()) { - return rewriter.notifyMatchFailure( - binder.op, "expected result type to have a dtype"); + Value res; + if (reverse) { + Value dims = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + SmallVector{axis}); + Value flip = rewriter.create( + binder.getLoc(), resultType, operand, dims); + Value cumsum = rewriter.create( + binder.getLoc(), resultType, flip, axis, none); + res = rewriter.create(binder.getLoc(), resultType, + cumsum, dims); + } else { + res = rewriter.create( + binder.getLoc(), resultType, operand, axis, none); } - // resultTensorType.print(llvm::outs()); - Value none = rewriter.create(loc); - rewriter.replaceOpWithNewOp(binder.op, resultType, - operand, dim, none); + + if (exclusive) + res = rewriter.create( + binder.getLoc(), resultType, res, operand, cstOne); + rewriter.replaceOp(binder.op, res); return success(); }); patterns.onOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index e8266c04ffad..a87ec4f8f43f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1324,6 +1324,85 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a // ----- +// CHECK-LABEL: @test_cumsum_1d_exclusive +func.func @test_cumsum_1d_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 1 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64> + // CHECK: torch.aten.sub.Tensor %[[CUMSUM]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> + return %0 : !torch.vtensor<[5],f64> +} + +// ----- + +// CHECK-LABEL: @test_cumsum_1d_reverse +func.func @test_cumsum_1d_reverse(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 1 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64> + // CHECK: torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> + return %0 : !torch.vtensor<[5],f64> +} + +// ----- + +// CHECK-LABEL: @test_cumsum_1d_reverse_exclusive +func.func @test_cumsum_1d_reverse_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 1 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64> + // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list -> !torch.vtensor<[5],f64> + // CHECK: torch.aten.sub.Tensor %[[FLIP_0]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64, torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> + return %0 : !torch.vtensor<[5],f64> +} + +// ----- + +// CHECK-LABEL: @test_cumsum_2d +func.func @test_cumsum_2d(%arg0: !torch.vtensor<[2,3],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RANK:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int + // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[NONE:.*]] = torch.constant.none + // torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[2,3],f64>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f64> + %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> + return %0 : !torch.vtensor<[2,3],f64> +} + +// ----- + // CHECK-LABEL: func.func @test_exp func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} { // CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir index b4f9dfbb30f2..3fc02201748e 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -37,12 +37,3 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> return %211 : !torch.vtensor<[1,64,1],f32> } - -// ----- -// Fixed. -func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, - %arg1: !torch.vtensor<[],si32>) - -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> - return %212 : !torch.vtensor<[2,3],f64> -}