Skip to content

Commit

Permalink
Fix computing intermediate zero in DynamicQuantizeLinear (#2535)
Browse files Browse the repository at this point in the history
Signed-off-by: Tung D. Le <[email protected]>
  • Loading branch information
tungld authored Sep 27, 2023
1 parent a3930ad commit 6ad3bbe
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ struct ONNXDynamicQuantizeLinearOpLowering
create.krnl.store(scale, YScale);

// Compute y_zero_point.
Value interZeroPoint = create.math.div(create.math.sub(qMin, xMin), scale);
Value interZeroPoint = create.math.sub(qMin, create.math.div(xMin, scale));
// Saturate zero point.
Value saturateZeroPoint =
create.onnx.clip(interZeroPoint, qMin, qMax, /*scalarType=*/true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
%y, %y_scale, %y_zero_point = "onnx.DynamicQuantizeLinear"(%arg0) : (tensor<?x2xf32>) -> (tensor<?x2xui8>, tensor<f32>, tensor<ui8>)
return %y, %y_scale, %y_zero_point: tensor<?x2xui8>, tensor<f32>, tensor<ui8>

// mlir2FileCheck.py
// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @test_dynamic_quantize_linear
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<?x2xf32>) -> (memref<?x2xui8>, memref<f32>, memref<ui8>) {
Expand Down Expand Up @@ -68,8 +67,8 @@ func.func @test_dynamic_quantize_linear(%arg0: tensor<?x2xf32>) -> (tensor<?x2xu
// CHECK: [[VAR_8_:%.+]] = arith.subf [[VAR_5_]], [[VAR_7_]] : f32
// CHECK: [[VAR_9_:%.+]] = arith.divf [[VAR_8_]], [[CST_2_dot_550000_]] : f32
// CHECK: krnl.store [[VAR_9_]], [[RES_1_]][] : memref<f32>
// CHECK: [[VAR_10_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_7_]] : f32
// CHECK: [[VAR_11_:%.+]] = arith.divf [[VAR_10_]], [[VAR_9_]] : f32
// CHECK: [[VAR_10_:%.+]] = arith.divf [[VAR_7_]], [[VAR_9_]] : f32
// CHECK: [[VAR_11_:%.+]] = arith.subf [[CST_0_dot_000000_]], [[VAR_10_]] : f32
// CHECK: [[VAR_12_:%.+]] = arith.cmpf olt, [[VAR_11_]], [[CST_0_dot_000000_]] : f32
// CHECK: [[VAR_13_:%.+]] = arith.select [[VAR_12_]], [[CST_0_dot_000000_]], [[VAR_11_]] : f32
// CHECK: [[VAR_14_:%.+]] = arith.cmpf olt, [[VAR_13_]], [[CST_2_dot_550000_]] : f32
Expand Down

0 comments on commit 6ad3bbe

Please sign in to comment.