Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Commit

Permalink
Remove the constraint that min / max should straddle zero
Browse files Browse the repository at this point in the history
Since we apply nudging to the zero point to make sure the nudged zero points
can be in the range of [qmin, qmax], the constraint that rmin / rmax should
straddle zero isn't necessary.

This also matches the documentation of tensorflow's FakeQuantWithMinMaxArgs op,
where min and max don't need to straddle zero:
https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args

PiperOrigin-RevId: 268296285
  • Loading branch information
liufengdb authored and tensorflower-gardener committed Sep 10, 2019
1 parent 8066c22 commit 797757b
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 39 deletions.
25 changes: 13 additions & 12 deletions lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
return false;
}

void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
double &scale, int64_t &nudgedZeroPoint) {
// This is a specific implementation of nudging:
// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted
// to include 0.0, but the range width size (rmax-rmin) isn't changed. The zero
// point is derived from the shifted range, and the scale isn't changed. As
// a consequence some values, which are supposed to be in the original [rmin, rmax]
// range will be outside the shifted range and be clamped during quantization.
// TODO(fengliuai): we should nudge the scale as well, but that requires the
// fake quant op used in the training to use the nudged scale as well.
void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
double rmax, double &scale,
int64_t &nudgedZeroPoint) {
// Determine the scale.
const double qminDouble = qmin;
const double qmaxDouble = qmax;
Expand Down Expand Up @@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
double rmin, double rmax,
bool narrowRange, Type expressedType,
bool isSigned) {
// Range must straddle zero.
// TODO(b/140641593): remove this constraint.
if (rmin > 0.0 || rmax < 0.0) {
return (emitError(loc, "FakeQuant range must straddle zero: [")
<< rmin << "," << rmax << "]",
nullptr);
}

MLIRContext *ctx = expressedType.getContext();
unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
Type storageType;
Expand All @@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,

double scale;
int64_t nudgedZeroPoint;
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);

return UniformQuantizedType::getChecked(flags, storageType, expressedType,
scale, nudgedZeroPoint, qmin, qmax,
Expand Down Expand Up @@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,

double scale;
int64_t nudgedZeroPoint;
getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
scales.push_back(scale);
zeroPoints.push_back(nudgedZeroPoint);
}
Expand Down
22 changes: 0 additions & 22 deletions test/Dialect/QuantOps/convert-fakequant-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,27 +1,5 @@
// RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization

// -----
// Verify that a mismatched range errors.
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// Both min and max are positive, so [min, max] does not contain 0.0 and the
// conversion pass rejects the op with the diagnostic below.
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.500000e+00]}}
%0 = "quant.const_fake_quant"(%arg0) {
min = 1.1 : f32, max = 1.5 : f32, num_bits = 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}

// -----
// Verify that an inverted range (min > max) errors.
func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// min (1.1) is greater than max (1.0) — an inverted range; it also fails the
// straddle-zero check, producing the diagnostic below.
// expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.000000e+00}}
%0 = "quant.const_fake_quant"(%arg0) {
min = 1.1 : f32, max = 1.0 : f32, num_bits = 8
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}

// -----
// Unsupported quantizable type (i1 is currently not a supported element type).
func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
Expand Down
42 changes: 37 additions & 5 deletions test/Dialect/QuantOps/convert-fakequant.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand All @@ -62,7 +62,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// Verifies a quint8 symmetric range of -1..127/128.
// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand Down Expand Up @@ -122,7 +122,7 @@ func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand All @@ -137,7 +137,7 @@ func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {

// -----
// Verifies a qint8 symmetric range of -1..127/128.
// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):
// CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
Expand Down Expand Up @@ -181,9 +181,41 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
return %0 : tensor<f32>
}

// -----
// CHECK-LABEL: fakeQuantArgs_all_positive
func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

// min and max are both positive (0.5..1.5), so the real range does not
// contain 0.0. Per the expected type below, the nudged zero point lands at
// qmin (-128 for signed 8-bit) while the scale (1/255) is kept.
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>
// CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:-128>>)
// CHECK-SAME: -> tensor<8x4x3xf32>

%0 = "quant.const_fake_quant"(%arg0) {
min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}

// -----
// CHECK-LABEL: fakeQuantArgs_all_negative
func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

// min and max are both negative (-1.5..-0.5), so the real range does not
// contain 0.0. Per the expected type below, the nudged zero point lands at
// qmax (127 for signed 8-bit) while the scale (1/255) is kept.
// CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
// CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>
// CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform<i8:f32, 0.0039215686274509803:127>>)
// CHECK-SAME: -> tensor<8x4x3xf32>

%0 = "quant.const_fake_quant"(%arg0) {
min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
} : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
return %0 : tensor<8x4x3xf32>
}

// -----
// Verifies a qint8 per axis
// CHECK_LABEL: fakeQuantPerAxis
// CHECK-LABEL: fakeQuantPerAxis
func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
^bb0(%arg0: tensor<8x4x3xf32>):

Expand Down

0 comments on commit 797757b

Please sign in to comment.