
Commit

Convert ConstFakeQuantPerAxis to qcast and dcast pair
This also adds a test covering fakeQuantAttrsToType for per-channel fake quant.

PiperOrigin-RevId: 268260032
liufengdb authored and tensorflower-gardener committed Sep 10, 2019
1 parent 083a0ae commit 8066c22
Showing 3 changed files with 89 additions and 27 deletions.
92 changes: 68 additions & 24 deletions lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp
@@ -37,67 +37,111 @@ class ConvertSimulatedQuantPass

 } // end anonymous namespace

-/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
-class ConstFakeQuantRewrite : public RewritePattern {
+/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair.
+template <typename ConcretRewriteClass, typename FakeQuantOp>
+class FakeQuantRewrite : public OpRewritePattern<FakeQuantOp> {
 public:
-  bool *hadFailure;
+  using OpRewritePattern<FakeQuantOp>::OpRewritePattern;

-  ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure)
-      : RewritePattern(ConstFakeQuant::getOperationName(), 1, context),
-        hadFailure(hadFailure) {}
+  FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+      : OpRewritePattern<FakeQuantOp>(ctx), hadFailure(hadFailure) {}

-  PatternMatchResult matchAndRewrite(Operation *op,
+  PatternMatchResult matchAndRewrite(FakeQuantOp op,
                                      PatternRewriter &rewriter) const override {
     // TODO: If this pattern comes up more frequently, consider adding core
     // support for failable rewrites.
     if (failableRewrite(op, rewriter)) {
       *hadFailure = true;
-      return matchFailure();
+      return Pattern::matchFailure();
     }

-    return matchSuccess();
+    return Pattern::matchSuccess();
   }

-  bool failableRewrite(Operation *op, PatternRewriter &rewriter) const {
-    auto fqOp = cast<ConstFakeQuant>(op);
+private:
+  bool *hadFailure;

-    auto converter =
-        ExpressedToQuantizedConverter::forInputType(fqOp.getType());
+  bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const {
+    auto converter = ExpressedToQuantizedConverter::forInputType(op.getType());
     if (!converter) {
-      return (op->emitError("unsupported quantized type conversion"), true);
+      return (op.emitError("unsupported quantized type conversion"), true);
     }

-    UniformQuantizedType uniformElementType = fakeQuantAttrsToType(
-        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
-        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
-        fqOp.narrow_range(), converter.expressedType, fqOp.is_signed());
+    QuantizedType elementType =
+        static_cast<const ConcretRewriteClass *>(this)
+            ->convertFakeQuantAttrsToType(op, converter.expressedType);

-    if (!uniformElementType) {
+    if (!elementType) {
       // Note that the fakeQuantAttrsToType will have emitted the error.
       return true;
     }

-    Type quantizedType = converter.convert(uniformElementType);
+    Type quantizedType = converter.convert(elementType);
     assert(quantizedType &&
            "Converter accepted a type that it did not convert");

     // TODO: Map to a qbarrier with an attribute like [Forced] to signal that
     // this is a forced/hard-coded constraint.
-    auto qbarrier = rewriter.create<QuantizeCastOp>(op->getLoc(), quantizedType,
-                                                    fqOp.inputs());
+    auto qbarrier = rewriter.create<QuantizeCastOp>(op.getLoc(), quantizedType,
+                                                    op.inputs());
     rewriter.replaceOpWithNewOp<DequantizeCastOp>(op, converter.inputType,
                                                   qbarrier.getResult());

     return false;
   }
 };

+class ConstFakeQuantRewrite
+    : public FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant> {
+public:
+  using BaseRewrite = FakeQuantRewrite<ConstFakeQuantRewrite, ConstFakeQuant>;
+
+  ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure)
+      : BaseRewrite(ctx, hadFailure) {}
+
+  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp,
+                                            Type expressedType) const {
+    return fakeQuantAttrsToType(
+        fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+        fqOp.min().convertToFloat(), fqOp.max().convertToFloat(),
+        fqOp.narrow_range(), expressedType, fqOp.is_signed());
+  }
+};
+
+class ConstFakeQuantPerAxisRewrite
+    : public FakeQuantRewrite<ConstFakeQuantPerAxisRewrite,
+                              ConstFakeQuantPerAxis> {
+public:
+  using BaseRewrite =
+      FakeQuantRewrite<ConstFakeQuantPerAxisRewrite, ConstFakeQuantPerAxis>;
+
+  ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure)
+      : BaseRewrite(ctx, hadFailure) {}
+
+  QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp,
+                                            Type expressedType) const {
+    SmallVector<double, 4> min, max;
+    min.reserve(fqOp.min().size());
+    max.reserve(fqOp.max().size());
+    for (auto m : fqOp.min())
+      min.push_back(m.cast<FloatAttr>().getValueAsDouble());
+    for (auto m : fqOp.max())
+      max.push_back(m.cast<FloatAttr>().getValueAsDouble());
+
+    return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(),
+                                fqOp.axis().getSExtValue(), min, max,
+                                fqOp.narrow_range(), expressedType,
+                                fqOp.is_signed());
+  }
+};
+
 void ConvertSimulatedQuantPass::runOnFunction() {
   bool hadFailure = false;
   OwningRewritePatternList patterns;
   auto func = getFunction();
-  auto *context = &getContext();
-  patterns.insert<ConstFakeQuantRewrite>(context, &hadFailure);
+  auto ctx = func.getContext();
+  patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
+      ctx, &hadFailure);
   applyPatternsGreedily(func, patterns);
   if (hadFailure)
     signalPassFailure();
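The refactor above is a CRTP split: FakeQuantRewrite<ConcretRewriteClass, FakeQuantOp> owns the shared qbarrier/dbarrier rewrite and statically dispatches to the concrete subclass for convertFakeQuantAttrsToType. Below is a minimal, MLIR-free C++ sketch of just that dispatch shape; every name in it (RewriteBase, PerTensorOp, PerAxisOp, and so on) is invented for illustration and is not part of the code in this commit.

// Standalone CRTP sketch: the base template drives the rewrite and calls
// back into the derived class, resolved at compile time, no virtual call.
#include <iostream>
#include <string>
#include <vector>

template <typename ConcreteRewrite, typename OpT>
class RewriteBase {
public:
  void rewrite(const OpT &op) const {
    // Downcast `this` to the concrete class; the compiler picks the right
    // convertAttrsToType() statically (same trick as in FakeQuantRewrite).
    std::string elementType =
        static_cast<const ConcreteRewrite *>(this)->convertAttrsToType(op);
    std::cout << "lower to qcast/dcast with element type " << elementType << "\n";
  }
};

// Hypothetical per-tensor op: a single (min, max) range.
struct PerTensorOp { double min, max; };
class PerTensorRewrite : public RewriteBase<PerTensorRewrite, PerTensorOp> {
public:
  std::string convertAttrsToType(const PerTensorOp &op) const {
    return "uniform(" + std::to_string(op.min) + ", " + std::to_string(op.max) + ")";
  }
};

// Hypothetical per-axis op: one (min, max) range per channel along `axis`.
struct PerAxisOp { std::vector<double> min, max; int axis; };
class PerAxisRewrite : public RewriteBase<PerAxisRewrite, PerAxisOp> {
public:
  std::string convertAttrsToType(const PerAxisOp &op) const {
    return "uniform_per_axis(axis=" + std::to_string(op.axis) +
           ", channels=" + std::to_string(op.min.size()) + ")";
  }
};

int main() {
  PerTensorRewrite().rewrite(PerTensorOp{-1.0, 0.9921875});
  PerAxisRewrite().rewrite(PerAxisOp{{-1.0, 0.0, 0.0}, {0.9921875, 0.0, 1.0}, 2});
  return 0;
}

The payoff visible in the diff is that supporting the per-axis op only required a second small subclass plus registering it alongside the existing one in the pattern list.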
5 changes: 2 additions & 3 deletions lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp
@@ -136,7 +136,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits,
                                           loc);
 }

-// TODO(fengliuai): test this method once the quantizeAttr method is fixed.
 UniformQuantizedPerAxisType
 fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,
                      ArrayRef<double> rmins, ArrayRef<double> rmaxs,
@@ -180,8 +179,8 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension,

   unsigned flags = isSigned ? QuantizationFlags::Signed : 0;
   return UniformQuantizedPerAxisType::getChecked(
-      flags, storageType, expressedType, scales, zeroPoints, qmin, qmax,
-      quantizedDimension, loc);
+      flags, storageType, expressedType, scales, zeroPoints, quantizedDimension,
+      qmin, qmax, loc);
 }

 } // namespace quant
19 changes: 19 additions & 0 deletions test/Dialect/QuantOps/convert-fakequant.mlir
@@ -180,3 +180,22 @@ func @fakeQuantArgs_UnrankedTensor(tensor<f32>) -> tensor<f32> {
   } : (tensor<f32>) -> tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+// Verifies a signed i8 per-axis quantized type (axis 2).
+// CHECK-LABEL: fakeQuantPerAxis
+func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+^bb0(%arg0: tensor<8x4x3xf32>):
+
+  // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>)
+  // CHECK-SAME: -> tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>
+  // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]])
+  // CHECK-SAME: (tensor<8x4x3x!quant.uniform<i8:f32:2, {7.812500e-03,1.000000e+00:-128,0.0039215686274509803:-128}>>)
+
+  %0 = "quant.const_fake_quant_per_axis"(%arg0) {
+    min = [-1.0 : f32, 0.0 : f32, 0.0 : f32],
+    max = [0.9921875 : f32, 0.0 : f32, 1.0 : f32],
+    num_bits = 8, narrow_range = false, is_signed = true, axis = 2
+  } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32>
+  return %0 : tensor<8x4x3xf32>
+}
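As a sanity check on the expected type in the CHECK lines above, the per-channel constants are consistent with the usual asymmetric fake-quant mapping for signed 8-bit storage without narrow_range, stated here only to explain where the test values come from:

\[
  \mathrm{scale}_c = \frac{\max_c - \min_c}{q_{\max} - q_{\min}}, \qquad
  \mathrm{zeroPoint}_c = q_{\min} - \mathrm{round}\!\left(\frac{\min_c}{\mathrm{scale}_c}\right),
  \qquad q_{\min} = -128,\ q_{\max} = 127.
\]

Channel 0 (min = -1.0, max = 0.9921875): scale = 1.9921875 / 255 = 0.0078125 = 7.812500e-03, zeroPoint = -128 - (-128) = 0. Channel 2 (min = 0.0, max = 1.0): scale = 1 / 255 = 0.0039215686274509803, zeroPoint = -128. Channel 1 has min = max = 0.0, the degenerate range, which the utility maps to scale 1.0 with zero point -128, matching the 1.000000e+00:-128 entry.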
