diff --git a/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp b/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp index 4f6eb8cb9855..1000b1fabbf7 100644 --- a/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp +++ b/lib/Dialect/QuantOps/Transforms/ConvertSimQuant.cpp @@ -37,54 +37,53 @@ class ConvertSimulatedQuantPass } // end anonymous namespace -/// Rewrites ConstFakeQuant into a qbarrier/dbarrier pair. -class ConstFakeQuantRewrite : public RewritePattern { +/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. +template +class FakeQuantRewrite : public OpRewritePattern { public: - bool *hadFailure; + using OpRewritePattern::OpRewritePattern; - ConstFakeQuantRewrite(MLIRContext *context, bool *hadFailure) - : RewritePattern(ConstFakeQuant::getOperationName(), 1, context), - hadFailure(hadFailure) {} + FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) + : OpRewritePattern(ctx), hadFailure(hadFailure) {} - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(FakeQuantOp op, PatternRewriter &rewriter) const override { // TODO: If this pattern comes up more frequently, consider adding core // support for failable rewrites. 
if (failableRewrite(op, rewriter)) { *hadFailure = true; - return matchFailure(); + return Pattern::matchFailure(); } - return matchSuccess(); + return Pattern::matchSuccess(); } - bool failableRewrite(Operation *op, PatternRewriter &rewriter) const { - auto fqOp = cast(op); +private: + bool *hadFailure; - auto converter = - ExpressedToQuantizedConverter::forInputType(fqOp.getType()); + bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { + auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); if (!converter) { - return (op->emitError("unsupported quantized type conversion"), true); + return (op.emitError("unsupported quantized type conversion"), true); } - UniformQuantizedType uniformElementType = fakeQuantAttrsToType( - fqOp.getLoc(), fqOp.num_bits().getSExtValue(), - fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), - fqOp.narrow_range(), converter.expressedType, fqOp.is_signed()); + QuantizedType elementType = + static_cast(this) + ->convertFakeQuantAttrsToType(op, converter.expressedType); - if (!uniformElementType) { + if (!elementType) { // Note that the fakeQuantAttrsToType will have emitted the error. return true; } - Type quantizedType = converter.convert(uniformElementType); + Type quantizedType = converter.convert(elementType); assert(quantizedType && "Converter accepted a type that it did not convert"); // TODO: Map to a qbarrier with an attribute like [Forced] to signal that // this is a forced/hard-coded constraint. 
- auto qbarrier = rewriter.create(op->getLoc(), quantizedType, - fqOp.inputs()); + auto qbarrier = rewriter.create(op.getLoc(), quantizedType, + op.inputs()); rewriter.replaceOpWithNewOp(op, converter.inputType, qbarrier.getResult()); @@ -92,12 +91,57 @@ class ConstFakeQuantRewrite : public RewritePattern { } }; +class ConstFakeQuantRewrite + : public FakeQuantRewrite { +public: + using BaseRewrite = FakeQuantRewrite; + + ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) + : BaseRewrite(ctx, hadFailure) {} + + QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, + Type expressedType) const { + return fakeQuantAttrsToType( + fqOp.getLoc(), fqOp.num_bits().getSExtValue(), + fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), + fqOp.narrow_range(), expressedType, fqOp.is_signed()); + } +}; + +class ConstFakeQuantPerAxisRewrite + : public FakeQuantRewrite { +public: + using BaseRewrite = + FakeQuantRewrite; + + ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) + : BaseRewrite(ctx, hadFailure) {} + + QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, + Type expressedType) const { + SmallVector min, max; + min.reserve(fqOp.min().size()); + max.reserve(fqOp.max().size()); + for (auto m : fqOp.min()) + min.push_back(m.cast().getValueAsDouble()); + for (auto m : fqOp.max()) + max.push_back(m.cast().getValueAsDouble()); + + return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits().getSExtValue(), + fqOp.axis().getSExtValue(), min, max, + fqOp.narrow_range(), expressedType, + fqOp.is_signed()); + } +}; + void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; auto func = getFunction(); - auto *context = &getContext(); - patterns.insert(context, &hadFailure); + auto ctx = func.getContext(); + patterns.insert( + ctx, &hadFailure); applyPatternsGreedily(func, patterns); if (hadFailure) signalPassFailure(); diff --git a/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp 
b/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp index 02f803ac8396..5d4561be81b2 100644 --- a/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp +++ b/lib/Dialect/QuantOps/Utils/FakeQuantSupport.cpp @@ -136,7 +136,6 @@ UniformQuantizedType fakeQuantAttrsToType(Location loc, unsigned numBits, loc); } -// TODO(fengliuai): test this method once the quantizeAttr method is fixed. UniformQuantizedPerAxisType fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, ArrayRef rmins, ArrayRef rmaxs, @@ -180,8 +179,8 @@ fakeQuantAttrsToType(Location loc, unsigned numBits, int32_t quantizedDimension, unsigned flags = isSigned ? QuantizationFlags::Signed : 0; return UniformQuantizedPerAxisType::getChecked( - flags, storageType, expressedType, scales, zeroPoints, qmin, qmax, - quantizedDimension, loc); + flags, storageType, expressedType, scales, zeroPoints, quantizedDimension, + qmin, qmax, loc); } } // namespace quant diff --git a/test/Dialect/QuantOps/convert-fakequant.mlir b/test/Dialect/QuantOps/convert-fakequant.mlir index 15de088f39ce..316702cc5288 100644 --- a/test/Dialect/QuantOps/convert-fakequant.mlir +++ b/test/Dialect/QuantOps/convert-fakequant.mlir @@ -180,3 +180,22 @@ func @fakeQuantArgs_UnrankedTensor(tensor) -> tensor { } : (tensor) -> tensor return %0 : tensor } + +// ----- +// Verifies a qint8 per axis +// CHECK-LABEL: fakeQuantPerAxis +func @fakeQuantPerAxis(tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { +^bb0(%arg0: tensor<8x4x3xf32>): + + // CHECK: %[[q:.*]] = "quant.qcast"(%arg0) : (tensor<8x4x3xf32>) + // CHECK-SAME: -> tensor<8x4x3x!quant.uniform> + // CHECK: %[[d:.*]] = "quant.dcast"(%[[q]]) + // CHECK-SAME: (tensor<8x4x3x!quant.uniform>) + + %0 = "quant.const_fake_quant_per_axis"(%arg0) { + min = [-1.0 : f32, 0.0 : f32, 0.0 : f32], + max = [0.9921875 : f32, 0.0: f32, 1.0 : f32], + num_bits = 8, narrow_range = false, is_signed = true, axis = 2 + } : (tensor<8x4x3xf32>) -> tensor<8x4x3xf32> + return %0 : tensor<8x4x3xf32> +}