Support signed int divisor in onnx.Mod op #2471

Merged
58 changes: 54 additions & 4 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1119,7 +1119,7 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
CheckIfCustomScalarOpIsSupported<ONNXModOp>(elementType);
Value dividend = scalarOperands[0];
Value divisor = scalarOperands[1];
MultiDialectBuilder<MathBuilder> create(rewriter, loc);
MultiDialectBuilder<MathBuilder, KrnlBuilder> create(rewriter, loc);

// TODO: here we assume fmod=1; what should happen if that is not the case?
if (create.math.isFloatWithVector(elementType)) {
@@ -1136,9 +1136,59 @@ Value emitScalarOpFor<ONNXModOp>(ConversionPatternRewriter &rewriter,
#endif
}
if (create.math.isIntegerWithVector(elementType)) {
// TODO: implement
llvm_unreachable("not support integers at this moment since MLIR integers "
"are signless.");
// "math.rem" returns "minus" for minus dividend and "plus or zero" for plus
// dividend. We call the math.rem's return value "mathRemaider". However
// onnx.ModOp should return "minus" for minus divisor and "plus or zero" for
// plus divisor. we call the value that onnx.Mod op should return "onnxMod".
// The following table shows mathRemainder, onnxMod and their diference
// (=onnxMod-mathRemainder) for some inputs.
//
// dividend | 7 | 7 | -7 | -7 | 6 | 6 | -6 | -6 |
// divisor | 3 | -3 | 3 | -3 | 3 | -3 | 3 | -3 |
// ------------------------+-----+----+----+----+----+----+----+----+
// mathRemainder | 1 | 1 | -1 | -1 | 0 | 0 | 0 | 0 |
// onnxMod | 1 | -2 | 2 | -1 | 0 | 0 | 0 | 0 |
// onnxMod - mathRemainder | 0 | -3 | 3 | 0 | 0 | 0 | 0 | 0 |
//
// The following pseudocode derives onnxMod from mathRemainder
// ("^" is exclusive or):
//
//   int dividend, divisor;
//   int mathRemainder = dividend % divisor;
//   int adjustedRemainder = mathRemainder + divisor;
//
//   if ((mathRemainder != 0) && ((dividend < 0) ^ (divisor < 0)))
//     return adjustedRemainder;
//   return mathRemainder;

Value mathRemainder = create.math.rem(dividend, divisor);
Value adjustedRemainder = create.math.add(mathRemainder, divisor);
Value zero = create.math.constant(elementType, 0);
Value falseVal = create.math.constant(rewriter.getI1Type(), 0);
Value isMathRemainderNonZero =
create.math.eq(create.math.eq(mathRemainder, zero), falseVal);
Value isDividendMinus = create.math.slt(dividend, zero);
Value isDivisorMinus = create.math.slt(divisor, zero);
Value exclusiveOrOfIsDividendMinusAndIsDivisorMinus = create.math.eq(
create.math.eq(isDividendMinus, isDivisorMinus), falseVal);
Value needAdjust = create.math.andi(
isMathRemainderNonZero, exclusiveOrOfIsDividendMinusAndIsDivisorMinus);
Value answer =
create.math.select(needAdjust, adjustedRemainder, mathRemainder);

#ifdef DEBUG_ONNX_MOD
create.krnl.printf("XXXX emitScalarOpFor<ONNXModOp>: diviend=", dividend,
dividend.getType());
create.krnl.printf(", divisor=", divisor, divisor.getType());
create.krnl.printf(
", mathReminder=", mathRemainder, mathRemainder.getType());
create.krnl.printf(
", adjustedReminder=", adjustedRemainder, adjustedRemainder.getType());
create.krnl.printf(", Answer=", answer, answer.getType());
create.krnl.printf("\n");
#endif

return answer;
}
llvm_unreachable("unsupported element type");
}
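
For reference, here is a minimal standalone sketch of the same adjustment in plain C++ (the function name and test harness are illustrative, not part of the PR). C++'s `%` truncates toward zero, matching `arith.remsi`/`math.rem`, so the program reproduces the table in the comment above:

#include <cassert>
#include <cstdio>

// Truncated remainder (C++ '%', like arith.remsi) adjusted to follow the
// divisor's sign, i.e., onnx.Mod semantics with fmod=0. Illustrative only.
int onnxMod(int dividend, int divisor) {
  int mathRemainder = dividend % divisor; // sign follows the dividend
  // Adjust only when the remainder is nonzero and the operand signs differ.
  if (mathRemainder != 0 && ((dividend < 0) != (divisor < 0)))
    return mathRemainder + divisor;
  return mathRemainder;
}

int main() {
  // The mixed-sign cases from the table in the comment above.
  assert(onnxMod(7, 3) == 1);
  assert(onnxMod(7, -3) == -2);
  assert(onnxMod(-7, 3) == 2);
  assert(onnxMod(-7, -3) == -1);
  assert(onnxMod(6, -3) == 0 && onnxMod(-6, 3) == 0);
  std::printf("all cases match\n");
  return 0;
}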
4 changes: 2 additions & 2 deletions test/backend/inference_backend.py
@@ -658,8 +658,8 @@ def get_test_models():
# "test_mod_broadcast_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_int64_fmod_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_mixed_sign_int16_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_mixed_sign_int32_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_mixed_sign_int64_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_mod_mixed_sign_int32_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
"test_mod_mixed_sign_int64_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_mixed_sign_int8_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_uint16_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
# "test_mod_uint32_cpu": {STATIC_SHAPE:{}, DYNAMIC_SHAPE:{-1:{-1}}, CONSTANT_INPUT:{-1}},
46 changes: 46 additions & 0 deletions test/mlir/conversion/onnx_to_krnl/Math/Elementwise.mlir
@@ -115,6 +115,52 @@ func.func private @test_div(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>

// -----

func.func private @test_signed_int_mod(%arg0 : tensor<10x10xi64>, %arg1 : tensor<10x10xi64>) -> tensor<*xi64> {
%0 = "onnx.Mod"(%arg0, %arg1) : (tensor<10x10xi64>, tensor<10x10xi64>) -> tensor<*xi64>
"func.return"(%0) : (tensor<*xi64>) -> ()
// mlir2FileCheck.py
// CHECK-LABEL: func.func private @test_signed_int_mod
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<10x10xi64>, [[PARAM_1_:%.+]]: memref<10x10xi64>) -> memref<10x10xi64> {
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_10_1_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_10_2_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_10_3_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_1_1_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<10x10xi64>
// CHECK-DAG: [[LOOP_0_:%.+]]:2 = krnl.define_loops 2
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_10_4_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_10_5_:%.+]] = arith.constant 10 : index
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 10, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 10){
// CHECK-DAG: [[VAR_1_:%.+]]:2 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1) : (!krnl.loop, !krnl.loop) -> (index, index)
// CHECK-DAG: [[CST_10_6_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_10_7_:%.+]] = arith.constant 10 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[LOAD_PARAM_0_MEM_:%.+]] = krnl.load [[PARAM_0_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<10x10xi64>
// CHECK-DAG: [[CST_10_8_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[CST_10_9_:%.+]] = arith.constant 10 : index
// CHECK-DAG: [[LOAD_PARAM_1_MEM_:%.+]] = krnl.load [[PARAM_1_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<10x10xi64>
// CHECK: [[VAR_4_:%.+]] = arith.remsi [[LOAD_PARAM_0_MEM_]], [[LOAD_PARAM_1_MEM_]] : i64
// CHECK-DAG: [[VAR_5_:%.+]] = arith.addi [[VAR_4_]], [[LOAD_PARAM_1_MEM_]] : i64
// CHECK-DAG: [[CST_0_1_:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[VAR_false_:%.+]] = arith.constant false
// CHECK: [[VAR_6_:%.+]] = arith.cmpi eq, [[VAR_4_]], [[CST_0_1_]] : i64
// CHECK-DAG: [[VAR_7_:%.+]] = arith.cmpi eq, [[VAR_6_]], [[VAR_false_]] : i1
// CHECK-DAG: [[VAR_8_:%.+]] = arith.cmpi slt, [[LOAD_PARAM_0_MEM_]], [[CST_0_1_]] : i64
// CHECK-DAG: [[VAR_9_:%.+]] = arith.cmpi slt, [[LOAD_PARAM_1_MEM_]], [[CST_0_1_]] : i64
// CHECK: [[VAR_10_:%.+]] = arith.cmpi eq, [[VAR_8_]], [[VAR_9_]] : i1
// CHECK: [[VAR_11_:%.+]] = arith.cmpi eq, [[VAR_10_]], [[VAR_false_]] : i1
// CHECK: [[VAR_12_:%.+]] = arith.andi [[VAR_7_]], [[VAR_11_]] : i1
// CHECK: [[VAR_13_:%.+]] = arith.select [[VAR_12_]], [[VAR_5_]], [[VAR_4_]] : i64
// CHECK: krnl.store [[VAR_13_]], [[RES_]]{{.}}[[VAR_1_]]#0, [[VAR_1_]]#1] : memref<10x10xi64>
// CHECK: }
// CHECK: return [[RES_]] : memref<10x10xi64>
// CHECK: }
}
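
Note how the checked IR encodes the boolean logic: there is no logical "not" op in this lowering, so `arith.cmpi eq, %x, %false` serves as NOT on i1, and NOT(a == b) serves as XOR. A hedged C++ sketch of that exact composition (names illustrative, not project code), brute-force checked against the direct formulation:

#include <cassert>

// Mirror the lowered op sequence: eq(x, false) acts as NOT on i1, and
// eq(eq(a, b), false) acts as XOR. Illustrative sketch only.
bool loweredNeedAdjust(long long dividend, long long divisor) {
  long long rem = dividend % divisor;          // arith.remsi
  bool remIsZero = (rem == 0);                 // arith.cmpi eq
  bool remNonZero = (remIsZero == false);      // NOT via eq-with-false
  bool dividendNeg = (dividend < 0);           // arith.cmpi slt
  bool divisorNeg = (divisor < 0);             // arith.cmpi slt
  bool sameSign = (dividendNeg == divisorNeg); // arith.cmpi eq on i1
  bool signsDiffer = (sameSign == false);      // XOR via NOT of eq
  return remNonZero && signsDiffer;            // arith.andi
}

int main() {
  for (long long a = -20; a <= 20; ++a)
    for (long long b = -20; b <= 20; ++b) {
      if (b == 0) continue;
      bool expected = (a % b != 0) && ((a < 0) != (b < 0));
      assert(loweredNeedAdjust(a, b) == expected);
    }
  return 0;
}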

// -----

func.func private @test_sub(%arg0 : tensor<10x10xf32>, %arg1 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Sub"(%arg0, %arg1) : (tensor<10x10xf32>, tensor<10x10xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()