Add attributes support for onnx cumsum op
jinchen62 committed Apr 30, 2024
1 parent 122cf22 commit 408ea19
Showing 3 changed files with 128 additions and 36 deletions.
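The conversion previously rejected onnx.CumSum whenever the exclusive or reverse attribute was set; this change lowers all four attribute combinations. reverse is handled by flipping the input along the axis, taking the inclusive cumulative sum, and flipping back; exclusive is handled by subtracting the input from the inclusive result. As a worked example, for input [1, 2, 3, 4, 5]: the inclusive sum is [1, 3, 6, 10, 15], exclusive gives [0, 1, 3, 6, 10], reverse gives [15, 14, 12, 9, 5], and reverse with exclusive gives [14, 12, 9, 5, 0].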
76 changes: 49 additions & 27 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1361,56 +1361,78 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
       });
   patterns.onOp(
       "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
-        Location loc = binder.getLoc();
         Torch::ValueTensorType resultType;
-        Value operand;
-        Value axisTensor;
+        Value operand, axisTensor;
+        int64_t exclusive, reverse;
         if (binder.tensorOperands(operand, axisTensor) ||
+            binder.s64IntegerAttr(exclusive, "exclusive", 0) ||
+            binder.s64IntegerAttr(reverse, "reverse", 0) ||
             binder.tensorResultType(resultType))
           return failure();
 
-        int64_t exclusive;
-        int64_t reverse;
-        // if bind succeeds and either is set, fail because not implemented
-        if (!binder.s64IntegerAttr(exclusive, "exclusive", 0))
-          if (exclusive != 0)
-            return rewriter.notifyMatchFailure(
-                binder.op, "unsupported onnx.CumSum conversion: exclusive");
-        if (!binder.s64IntegerAttr(reverse, "reverse", 0))
-          if (reverse != 0)
-            return rewriter.notifyMatchFailure(
-                binder.op, "unsupported onnx.CumSum conversion: reverse");
+        Torch::BaseTensorType resultTensorType =
+            cast<Torch::BaseTensorType>(resultType);
+        if (!resultTensorType.hasDtype()) {
+          return rewriter.notifyMatchFailure(
+              binder.op, "expected result type to have a dtype");
+        }
 
         // deal with neg axis: if (axis < 0) axis += rank
         int64_t rank =
             cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
         Value rankVal = rewriter.create<Torch::ConstantIntOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(),
             rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
-        Value zero = rewriter.create<Torch::ConstantIntOp>(
-            loc, rewriter.getI64IntegerAttr(0));
+        Value cstZero = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(0));
+        Value cstOne = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
 
         Value axisScalar = rewriter.create<Torch::AtenItemOp>(
             binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
         Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
-            binder.getLoc(), axisScalar, zero);
+            binder.getLoc(), axisScalar, cstZero);
         isNegative =
             rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
         Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
             binder.getLoc(), isNegative, rankVal);
-        Value dim = rewriter.create<Torch::AtenAddIntOp>(
+        Value axis = rewriter.create<Torch::AtenAddIntOp>(
             binder.getLoc(), axisScalar, finalOffset);
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
 
-        Torch::BaseTensorType resultTensorType =
-            cast<Torch::BaseTensorType>(resultType);
-        if (!resultTensorType.hasDtype()) {
-          return rewriter.notifyMatchFailure(
-              binder.op, "expected result type to have a dtype");
+        if (!reverse) {
+          if (!exclusive) {
+            rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(
+                binder.op, resultType, operand, axis, none);
+            return success();
+          } else {
+            Value cumsum = rewriter.create<Torch::AtenCumsumOp>(
+                binder.getLoc(), resultType, operand, axis, none);
+            rewriter.replaceOpWithNewOp<Torch::AtenSubTensorOp>(
+                binder.op, resultType, cumsum, operand, cstOne);
+            return success();
+          }
         }
-        // resultTensorType.print(llvm::outs());
-        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
-        rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
-                                                         operand, dim, none);
+
+        Value dims = rewriter.create<Torch::PrimListConstructOp>(
+            binder.getLoc(),
+            rewriter.getType<Torch::ListType>(
+                rewriter.getType<Torch::IntType>()),
+            SmallVector<Value>{axis});
+        Value flip = rewriter.create<Torch::AtenFlipOp>(
+            binder.getLoc(), resultType, operand, dims);
+        Value cumsum = rewriter.create<Torch::AtenCumsumOp>(
+            binder.getLoc(), resultType, flip, axis, none);
+        if (!exclusive) {
+          rewriter.replaceOpWithNewOp<Torch::AtenFlipOp>(binder.op, resultType,
+                                                         cumsum, dims);
+          return success();
+        }
+
+        Value flipCumsum = rewriter.create<Torch::AtenFlipOp>(
+            binder.getLoc(), resultType, cumsum, dims);
+        rewriter.replaceOpWithNewOp<Torch::AtenSubTensorOp>(
+            binder.op, resultType, flipCumsum, operand, cstOne);
         return success();
       });
   patterns.onOp(
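For reference, a minimal standalone C++ sketch of the semantics the lowering above implements, assuming a 1-D input; this is illustrative only, not part of the commit, and the helper name cumsumRef is hypothetical:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Reference semantics for onnx.CumSum, mirroring the op order used in the
// conversion: flip -> cumsum -> flip for reverse, then subtract the input
// for exclusive.
std::vector<double> cumsumRef(const std::vector<double> &x, bool exclusive,
                              bool reverse) {
  std::vector<double> v = x;
  if (reverse)
    std::reverse(v.begin(), v.end()); // AtenFlipOp on the operand
  double acc = 0.0;
  for (double &e : v) { // AtenCumsumOp: inclusive scan
    acc += e;
    e = acc;
  }
  if (reverse)
    std::reverse(v.begin(), v.end()); // AtenFlipOp on the scan result
  if (exclusive)
    for (std::size_t i = 0; i < x.size(); ++i)
      v[i] -= x[i]; // AtenSubTensorOp: cumsum - operand
  return v;
}

int main() {
  // The reverse + exclusive case: prints 14 12 9 5 0
  for (double e : cumsumRef({1, 2, 3, 4, 5}, /*exclusive=*/true,
                            /*reverse=*/true))
    std::printf("%g ", e);
  std::printf("\n");
  return 0;
}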
79 changes: 79 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -1324,6 +1324,85 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a
 
 // -----
 
+// CHECK-LABEL: @test_cumsum_1d_exclusive
+func.func @test_cumsum_1d_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[RANK:.*]] = torch.constant.int 1
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
+  // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
+  // CHECK: torch.aten.sub.Tensor %[[CUMSUM]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64>
+  %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
+  return %0 : !torch.vtensor<[5],f64>
+}
+
+// -----
+
+// CHECK-LABEL: @test_cumsum_1d_reverse
+func.func @test_cumsum_1d_reverse(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[RANK:.*]] = torch.constant.int 1
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
+  // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
+  // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
+  // CHECK: torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
+  %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
+  return %0 : !torch.vtensor<[5],f64>
+}
+
+// -----
+
+// CHECK-LABEL: @test_cumsum_1d_reverse_exclusive
+func.func @test_cumsum_1d_reverse_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[RANK:.*]] = torch.constant.int 1
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[C1:.*]] = torch.constant.int 1
+  // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
+  // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
+  // CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
+  // CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
+  // CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
+  // CHECK: torch.aten.sub.Tensor %[[FLIP_0]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64>
+  %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64, torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
+  return %0 : !torch.vtensor<[5],f64>
+}
+
+// -----
+
+// CHECK-LABEL: @test_cumsum_2d
+func.func @test_cumsum_2d(%arg0: !torch.vtensor<[2,3],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: %[[RANK:.*]] = torch.constant.int 2
+  // CHECK: %[[C0:.*]] = torch.constant.int 0
+  // CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
+  // CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
+  // CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
+  // CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
+  // CHECK: %[[NONE:.*]] = torch.constant.none
+  // CHECK: torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[2,3],f64>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f64>
+  %0 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64>
+  return %0 : !torch.vtensor<[2,3],f64>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_exp
 func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} {
   // CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
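The tests above cover all four attribute combinations. For @test_cumsum_1d_reverse_exclusive, the checked expansion is flip, cumsum, flip, then sub.Tensor, which on the [1, 2, 3, 4, 5] example yields [14, 12, 9, 5, 0], matching the sketch above. Assuming the file's existing RUN header, they are exercised by running torch-mlir-opt with -convert-torch-onnx-to-torch on the file and piping the output to FileCheck.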
9 changes: 0 additions & 9 deletions test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir
@@ -37,12 +37,3 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>)
   %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32>
   return %211 : !torch.vtensor<[1,64,1],f32>
 }
-
-// -----
-// Fixed.
-func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>,
-                            %arg1: !torch.vtensor<[],si32>)
-                            -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64>
-  return %212 : !torch.vtensor<[2,3],f64>
-}