From 797757bcbd6dfe35398866e5ba77c6ba407a213f Mon Sep 17 00:00:00 2001
From: Feng Liu
Date: Tue, 10 Sep 2019 13:26:14 -0700
Subject: [PATCH] Remove the constraint that min / max should straddle zero

Since we nudge the zero point to make sure the nudged zero point stays in
the range [qmin, qmax], the constraint that rmin / rmax should straddle
zero isn't necessary.

This also matches the documentation of TensorFlow's FakeQuantWithMinMaxArgs
op, where min and max don't need to straddle zero:
https://www.tensorflow.org/api_docs/python/tf/quantization/fake_quant_with_min_max_args

PiperOrigin-RevId: 268296285
---
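Note: the sketch below is a minimal standalone illustration of the nudging
described above, assuming signed 8-bit storage (qmin = -128, qmax = 127).
It is not the library code; the helper name, file name, and output are
illustrative only.

// nudging_sketch.cpp -- hypothetical standalone example.
#include <cmath>
#include <cstdint>
#include <cstdio>

// One way to implement the nudging scheme: the scale is taken from the
// unshifted range width, and the zero point is clamped into [qmin, qmax]
// so that the real value 0.0 is always exactly representable.
static void nudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
                                    double rmax, double &scale,
                                    int64_t &zeroPoint) {
  scale = (rmax - rmin) / static_cast<double>(qmax - qmin);
  // Ideal zero point: the integer that the real value 0.0 would map to.
  const double zeroPointFromMin = static_cast<double>(qmin) - rmin / scale;
  if (zeroPointFromMin < static_cast<double>(qmin))
    zeroPoint = qmin;   // all-positive range: nudge down to qmin
  else if (zeroPointFromMin > static_cast<double>(qmax))
    zeroPoint = qmax;   // all-negative range: nudge up to qmax
  else
    zeroPoint = static_cast<int64_t>(std::round(zeroPointFromMin));
}

int main() {
  double scale;
  int64_t zp;
  // All-positive range [0.5, 1.5]: scale = 1/255, the ideal zero point
  // (-255.5) is below qmin, so it is nudged to -128. The representable range
  // becomes [0.0, 1.0], so original values above 1.0 clamp.
  nudgedScaleAndZeroPoint(-128, 127, 0.5, 1.5, scale, zp);
  std::printf("scale=%.17g zeroPoint=%lld\n", scale, static_cast<long long>(zp));
  // All-negative range [-1.5, -0.5]: the ideal zero point (254.5) is above
  // qmax, so it is nudged to 127. The representable range becomes
  // [-1.0, 0.0], so original values below -1.0 clamp.
  nudgedScaleAndZeroPoint(-128, 127, -1.5, -0.5, scale, zp);
  std::printf("scale=%.17g zeroPoint=%lld\n", scale, static_cast<long long>(zp));
  return 0;
}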
 .../QuantOps/Utils/FakeQuantSupport.cpp       | 25 +++++------
 .../QuantOps/convert-fakequant-invalid.mlir   | 22 ----------
 test/Dialect/QuantOps/convert-fakequant.mlir  | 42 ++++++++++++++++---
 3 files changed, 50 insertions(+), 39 deletions(-)

diff --git a/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp b/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
index 5d4561be81b2..2e1bd958b795 100644
--- a/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
+++ b/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -54,8 +54,17 @@ bool getDefaultStorageParams(unsigned numBits, bool narrowRange, bool isSigned,
   return false;
 }
 
-void getScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin, double rmax,
-                          double &scale, int64_t &nudgedZeroPoint) {
+// This is a specific implementation of nudging:
+// If 0.0 < rmin < rmax or rmin < rmax < 0.0, the range will be shifted to
+// include 0.0, but the range width (rmax - rmin) isn't changed. The zero point
+// is derived from the shifted range, and the scale isn't changed. As a
+// consequence, some values that are supposed to be in the original [rmin, rmax]
+// range will be outside the shifted range and be clamped during quantization.
+// TODO(fengliuai): we should nudge the scale as well, but that requires the
+// fake quant op used in the training to use the nudged scale as well.
+void getNudgedScaleAndZeroPoint(int64_t qmin, int64_t qmax, double rmin,
+                                double rmax, double &scale,
+                                int64_t &nudgedZeroPoint) {
   // Determine the scale.
   const double qminDouble = qmin;
   const double qmaxDouble = qmax;
@@ -100,14 +109,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
                                           double rmin, double rmax,
                                           bool narrowRange, Type expressedType,
                                           bool isSigned) {
-  // Range must straddle zero.
-  // TODO(b/140641593): remove this constraint.
-  if (rmin > 0.0 || rmax < 0.0) {
-    return (emitError(loc, "FakeQuant range must straddle zero: [")
-            << rmin << "," << rmax << "]",
-            nullptr);
-  }
-
   MLIRContext *ctx = expressedType.getContext();
   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
   Type storageType;
@@ -129,7 +130,7 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
 
   double scale;
   int64_t nudgedZeroPoint;
-  getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+  getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
 
   return UniformQuantizedType::getChecked(flags, storageType, expressedType,
                                           scale, nudgedZeroPoint, qmin, qmax,
@@ -172,7 +173,7 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
 
     double scale;
     int64_t nudgedZeroPoint;
-    getScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
+    getNudgedScaleAndZeroPoint(qmin, qmax, rmin, rmax, scale, nudgedZeroPoint);
     scales.push_back(scale);
     zeroPoints.push_back(nudgedZeroPoint);
   }
diff --git a/test/Dialect/QuantOps/convert-fakequant-invalid.mlir b/test/Dialect/QuantOps/convert-fakequant-invalid.mlir
index b55538050ea8..d6b6a524e593 100644
--- a/test/Dialect/QuantOps/convert-fakequant-invalid.mlir
+++ b/test/Dialect/QuantOps/convert-fakequant-invalid.mlir
@@ -1,27 +1,5 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics -quant-convert-simulated-quantization
 
-// -----
-// Verify that a mismatched range errors.
-func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.500000e+00]}}
-  %0 = "quant.const_fake_quant"(%arg0) {
-    min = 1.1 : f32, max = 1.5 : f32, num_bits = 8
-  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
-  return %0 : tensor<8x4x3xf32>
-}
-
-// -----
-// Verify that a valid range errors.
-func @fakeQuantArgs(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
-^bb0(%arg0: tensor<8x4x3xf32>):
-  // expected-error@+1 {{FakeQuant range must straddle zero: [1.100000e+00,1.000000e+00}}
-  %0 = "quant.const_fake_quant"(%arg0) {
-    min = 1.1 : f32, max = 1.0 : f32, num_bits = 8
-  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
-  return %0 : tensor<8x4x3xf32>
-}
-
 // -----
 // Unsupported quantizable type (i1 is currently not a supported element type).
 func @fakeQuantArgs(tensor<8x4x3xi1>) -> tensor<8x4x3xi1> {
diff --git a/test/Dialect/QuantOps/convert-fakequant.mlir b/test/Dialect/QuantOps/convert-fakequant.mlir
index 316702cc5288..f5709e6a8e10 100644
--- a/test/Dialect/QuantOps/convert-fakequant.mlir
+++ b/test/Dialect/QuantOps/convert-fakequant.mlir
@@ -47,7 +47,7 @@ func @fakeQuantArgs_Quint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a quint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK_LABEL: fakeQuantArgs_Quint8_NarrowRange
+// CHECK-LABEL: fakeQuantArgs_Quint8_NarrowRange
 func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -62,7 +62,7 @@ func @fakeQuantArgs_Quint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a quint8 symmetric range of -1..127/128.
-// CHECK_LABEL: fakeQuantArgs_Quint8_SymmetricRange
+// CHECK-LABEL: fakeQuantArgs_Quint8_SymmetricRange
 func @fakeQuantArgs_Quint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -122,7 +122,7 @@ func @fakeQuantArgs_Qint8_0_1(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a qint8 asymmetric 0..1 range (with narrow_range = true).
-// CHECK_LABEL: fakeQuantArgs_Qint8_NarrowRange
+// CHECK-LABEL: fakeQuantArgs_Qint8_NarrowRange
 func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -137,7 +137,7 @@ func @fakeQuantArgs_Qint8_NarrowRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 
 // -----
 // Verifies a qint8 symmetric range of -1..127/128.
-// CHECK_LABEL: fakeQuantArgs_Qint8_SymmetricRange
+// CHECK-LABEL: fakeQuantArgs_Qint8_SymmetricRange
 func @fakeQuantArgs_Qint8_SymmetricRange(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):
   // CHECK: %0 = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
@@ -181,9 +181,41 @@ func @fakeQuantArgs_UnrankedTensor(tensor) -> tensor {
   return %0 : tensor
 }
 
+// -----
+// CHECK-LABEL: fakeQuantArgs_all_positive
+func @fakeQuantArgs_all_positive(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = 0.5 : f32, max = 1.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
+// -----
+// CHECK-LABEL: fakeQuantArgs_all_negative
+func @fakeQuantArgs_all_negative(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[qc:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform>
+  // CHECK-NEXT: "quant.dcast"(%[[qc]]) : (tensor<8x4x3x!quant.uniform>)
+  // CHECK-SAME: -> tensor<8x4x3xf32>
+
+  %0 = "quant.const_fake_quant"(%arg0) {
+    min = -1.5 : f32, max = -0.5 : f32, num_bits = 8, narrow_range = false, is_signed = true
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
+
 // -----
 // Verifies a qint8 per axis
-// CHECK_LABEL: fakeQuantPerAxis
+// CHECK-LABEL: fakeQuantPerAxis
 func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
 ^bb0(%arg0: tensor<8x4x3xf32>):