Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchToLinalg] Adds Quantization Support for ConvTranspose #3240

Merged
merged 8 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 33 additions & 25 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg,
if (!isUnsignedType)
return;
int64_t minSI = -(1 << (numBits - 1));
Value minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, 32);
Value minSIValue = rewriter.create<arith::ConstantIntOp>(
loc, minSI, zp.getType().cast<mlir::IntegerType>().getWidth());
zp = rewriter.create<arith::AddIOp>(loc, zp, minSIValue);
minSIValue = rewriter.create<arith::ConstantIntOp>(loc, minSI, numBits);
arg = torch_to_linalg::createElementwiseLinalgGeneric(
Expand Down Expand Up @@ -797,6 +798,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
auto resultTy = cast<ValueTensorType>(op.getType());

Value inputZp, weightZp;
bool inputUnsigned, weightUnsigned = false;
zjgarvey marked this conversation as resolved.
Show resolved Hide resolved
if (auto make = op.getInput()
.getDefiningOp<Aten_MakePerTensorQuantizedTensorOp>()) {
input = make.getSelf();
Expand All @@ -806,6 +808,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
inputZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(inputZp.getType()),
inputZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}

if (auto make = op.getWeight()
Expand All @@ -818,6 +822,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
weightZp = typeConverter->materializeTargetConversion(
rewriter, loc, typeConverter->convertType(weightZp.getType()),
weightZp);
auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
}

if (static_cast<bool>(inputZp) != static_cast<bool>(weightZp)) {
Expand Down Expand Up @@ -916,15 +922,35 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
SmallVector<Value> strideIntValues =
getAsConstantIntValues(rewriter, loc, strideInts);

// convert any uint8 quantization to int8 quantization
if (auto integerType = dyn_cast<mlir::IntegerType>(inputDTy)) {
int64_t width = integerType.getWidth();
signShift(rewriter, loc, input, inputZp, inputUnsigned, width);
}
if (auto integerType = dyn_cast<mlir::IntegerType>(weightDTy)) {
int64_t width = integerType.getWidth();
signShift(rewriter, loc, weight, weightZp, weightUnsigned, width);
}
// Pad the input tensor according to padding.
SmallVector<Value> outDims{inBatch, weightBatch};
Value paddedInput;
if (transposed) {
if (!isa<mlir::FloatType>(inputDTy) || !isa<mlir::FloatType>(weightDTy) ||
!isa<mlir::FloatType>(resultDTy))
return rewriter.notifyMatchFailure(
op, "transpose does not support non-fp type yet");
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}
if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);

if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}
if (transposed) {
Value c0 =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
Value c1 =
Expand Down Expand Up @@ -994,7 +1020,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {

// Allocate padded input tensor
Value initTensor =
createZeroInitTensor(rewriter, loc, outerSizes, inputDTy);
createInitTensor(rewriter, loc, outerSizes, inputDTy, pad);

// Insert input into allocated tensor
SmallVector<Value> strideIndexValues{c1, c1};
Expand All @@ -1017,24 +1043,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
strideInts.clear();
strideInts.append(numSpatialDims, 1);
} else {
Value pad = inputZp;
if (!pad) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0));
if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0));
}

if (pad.getType() != inputDTy) {
if (isa<mlir::FloatType>(inputDTy))
pad = rewriter.create<arith::TruncFOp>(op.getLoc(), inputDTy, pad);

if (isa<mlir::IntegerType>(inputDTy))
pad = rewriter.create<arith::TruncIOp>(op.getLoc(), inputDTy, pad);
}

// Pad input
paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor(
op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad);
Expand Down
5 changes: 5 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@
"QuantizedReluInt8_basic",
"QuantizedReluUint8_basic",
"Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
# Dynamo not supporting conv_tbc
"ConvTbcModule_basic",
"FloatImplicitModule_basic",
Expand Down Expand Up @@ -372,6 +373,7 @@
"Conv2dQInt8Module_basic",
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
Expand Down Expand Up @@ -544,6 +546,7 @@
"ContainsIntList_True",
"Conv2dQInt8Module_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"ConvolutionBackwardModule2DPadded_basic",
"ConvolutionBackwardModule2DStrided_basic",
"ConvolutionBackwardModule2D_basic",
Expand Down Expand Up @@ -2097,6 +2100,7 @@
"ElementwiseBitwiseAndScalarInt32Module_basic",
"ElementwiseBitwiseAndScalarInt8Module_basic",
"Conv2dQInt8Module_basic",
"ConvTranspose2DQInt8_basic",
}

ONNX_XFAIL_SET = {
Expand Down Expand Up @@ -2251,6 +2255,7 @@
"Conv2dWithPaddingModule_basic",
"Conv3dModule_basic",
"ConvTbcModule_basic",
"ConvTranspose2DQInt8_basic",
"Conv_Transpose2dModule_basic",
"Convolution2DModule_basic",
"Convolution2DStridedModule_basic",
Expand Down
53 changes: 53 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,3 +1046,56 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils):
weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8)
bias = torch.rand(3)
module.forward(inputVec, weight, bias)


N = 10
Cin = 5
Cout = 7
Hin = 10
Win = 8
Hker = 3
Wker = 2


class ConvTranspose2DQInt8Module(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.int8, True),
([-1, -1, -1, -1], torch.int8, True),
([-1], torch.float, True),
]
)
def forward(self, input, weight, bias):
qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25)
qinput = torch.dequantize(qinput)
qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50)
qweight = torch.dequantize(qweight)
qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32)
qbias = torch.dequantize(qbias)
qz = torch.ops.aten.convolution(
qinput,
qweight,
bias=qbias,
stride=[2, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=True,
output_padding=[0, 0],
groups=1,
)
return qz


@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module())
def ConvTranspose2DQInt8_basic(module, tu: TestUtils):
module.forward(
tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8),
tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8),
torch.rand(Cout),
)
Loading