Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add attributes support for onnx cumsum op #3241

Merged
merged 1 commit into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 38 additions & 27 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,56 +1463,67 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
});
patterns.onOp(
"CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Location loc = binder.getLoc();
Torch::ValueTensorType resultType;
Value operand;
Value axisTensor;
Value operand, axisTensor;
int64_t exclusive, reverse;
if (binder.tensorOperands(operand, axisTensor) ||
binder.s64IntegerAttr(exclusive, "exclusive", 0) ||
binder.s64IntegerAttr(reverse, "reverse", 0) ||
binder.tensorResultType(resultType))
return failure();

int64_t exclusive;
int64_t reverse;
// if bind succeeds and either is set, fail because not implemented
if (!binder.s64IntegerAttr(exclusive, "exclusive", 0))
if (exclusive != 0)
return rewriter.notifyMatchFailure(
binder.op, "unsupported onnx.CumSum conversion: exclusive");
if (!binder.s64IntegerAttr(reverse, "reverse", 0))
if (reverse != 0)
return rewriter.notifyMatchFailure(
binder.op, "unsupported onnx.CumSum conversion: reverse");
Torch::BaseTensorType resultTensorType =
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
}

// deal with neg axis: if (axis < 0) axis += rank
int64_t rank =
cast<Torch::ValueTensorType>(operand.getType()).getSizes().size();
Value rankVal = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank));
Value zero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value cstOne = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));

Value axisScalar = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), axisTensor);
Value isNegative = rewriter.create<Torch::AtenLtIntOp>(
binder.getLoc(), axisScalar, zero);
binder.getLoc(), axisScalar, cstZero);
isNegative =
rewriter.create<Torch::AtenIntBoolOp>(binder.getLoc(), isNegative);
Value finalOffset = rewriter.create<Torch::AtenMulIntOp>(
binder.getLoc(), isNegative, rankVal);
Value dim = rewriter.create<Torch::AtenAddIntOp>(
Value axis = rewriter.create<Torch::AtenAddIntOp>(
binder.getLoc(), axisScalar, finalOffset);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());

Torch::BaseTensorType resultTensorType =
cast<Torch::BaseTensorType>(resultType);
if (!resultTensorType.hasDtype()) {
return rewriter.notifyMatchFailure(
binder.op, "expected result type to have a dtype");
Value res;
if (reverse) {
Value dims = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
rewriter.getType<Torch::ListType>(
rewriter.getType<Torch::IntType>()),
SmallVector<Value>{axis});
Value flip = rewriter.create<Torch::AtenFlipOp>(
binder.getLoc(), resultType, operand, dims);
Value cumsum = rewriter.create<Torch::AtenCumsumOp>(
binder.getLoc(), resultType, flip, axis, none);
res = rewriter.create<Torch::AtenFlipOp>(binder.getLoc(), resultType,
cumsum, dims);
} else {
res = rewriter.create<Torch::AtenCumsumOp>(
binder.getLoc(), resultType, operand, axis, none);
}
// resultTensorType.print(llvm::outs());
Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
rewriter.replaceOpWithNewOp<Torch::AtenCumsumOp>(binder.op, resultType,
operand, dim, none);

if (exclusive)
res = rewriter.create<Torch::AtenSubTensorOp>(
binder.getLoc(), resultType, res, operand, cstOne);
rewriter.replaceOp(binder.op, res);
return success();
});
patterns.onOp(
Expand Down
79 changes: 79 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,85 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a

// -----

// CHECK-LABEL: @test_cumsum_1d_exclusive
func.func @test_cumsum_1d_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
// CHECK: torch.aten.sub.Tensor %[[CUMSUM]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
return %0 : !torch.vtensor<[5],f64>
}

// -----

// CHECK-LABEL: @test_cumsum_1d_reverse
func.func @test_cumsum_1d_reverse(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
// CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
// CHECK: torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
return %0 : !torch.vtensor<[5],f64>
}

// -----

// CHECK-LABEL: @test_cumsum_1d_reverse_exclusive
func.func @test_cumsum_1d_reverse_exclusive(%arg0: !torch.vtensor<[5],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 1
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[C1:.*]] = torch.constant.int 1
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
// CHECK: %[[CUMSUM:.*]] = torch.aten.cumsum %[[FLIP]], %[[ADD]], %[[NONE]] : !torch.vtensor<[5],f64>, !torch.int, !torch.none -> !torch.vtensor<[5],f64>
// CHECK: %[[FLIP_0:.*]] = torch.aten.flip %[[CUMSUM]], %[[DIMS]] : !torch.vtensor<[5],f64>, !torch.list<int> -> !torch.vtensor<[5],f64>
// CHECK: torch.aten.sub.Tensor %[[FLIP_0]], %arg0, %[[C1]] : !torch.vtensor<[5],f64>, !torch.vtensor<[5],f64>, !torch.int -> !torch.vtensor<[5],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) {torch.onnx.exclusive = 1 : si64, torch.onnx.reverse = 1 : si64} : (!torch.vtensor<[5],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[5],f64>
return %0 : !torch.vtensor<[5],f64>
}

// -----

// CHECK-LABEL: @test_cumsum_2d
func.func @test_cumsum_2d(%arg0: !torch.vtensor<[2,3],f64>, %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[RANK:.*]] = torch.constant.int 2
// CHECK: %[[C0:.*]] = torch.constant.int 0
// CHECK: %[[AXIS:.*]] = torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int
// CHECK: %[[BOOL:.*]] = torch.aten.lt.int %[[AXIS]], %[[C0]] : !torch.int, !torch.int -> !torch.bool
// CHECK: %[[INT:.*]] = torch.aten.Int.bool %[[BOOL]] : !torch.bool -> !torch.int
// CHECK: %[[OTHER:.*]] = torch.aten.mul.int %[[INT]], %[[RANK]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[AXIS]], %[[OTHER]] : !torch.int, !torch.int -> !torch.int
// CHECK: %[[NONE:.*]] = torch.constant.none
// torch.aten.cumsum %arg0, %[[ADD]], %[[NONE]] : !torch.vtensor<[2,3],f64>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f64>
%0 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64>
return %0 : !torch.vtensor<[2,3],f64>
}

// -----

// CHECK-LABEL: func.func @test_exp
func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} {
// CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
Expand Down
9 changes: 0 additions & 9 deletions test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,3 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>)
%211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32>
return %211 : !torch.vtensor<[1,64,1],f32>
}

// -----
// Fixed.
func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>,
%arg1: !torch.vtensor<[],si32>)
-> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64>
return %212 : !torch.vtensor<[2,3],f64>
}
Loading