From 2c18dd47f25f3396fce58b2ceb1bcf6fb9669541 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Thu, 25 Apr 2024 18:57:41 -0500 Subject: [PATCH 1/6] Add Support for Quantized Conv Transpose --- lib/Conversion/TorchToLinalg/Linear.cpp | 63 ++++++----- .../torch_mlir_e2e_test/test_suite/conv.py | 44 ++++++++ .../test_suite/quantized_models.py | 102 ++++++++++++++++++ 3 files changed, 181 insertions(+), 28 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index b7db0496f516..31f65fe01ac4 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -42,8 +42,8 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, Value &zp, bool isUnsignedType, int64_t numBits) { if (!isUnsignedType) return; - int64_t minSI = -(1 << (numBits - 1)); - Value minSIValue = rewriter.create(loc, minSI, 32); + int64_t minSI = -(1 << (numBits - 1)); + Value minSIValue = rewriter.create(loc, minSI, zp.getType().cast().getWidth()); zp = rewriter.create(loc, zp, minSIValue); minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( @@ -797,6 +797,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto resultTy = op.getType().cast(); Value inputZp, weightZp; + bool inputUnsigned, weightUnsigned = false; if (auto make = op.getInput() .getDefiningOp()) { input = make.getSelf(); @@ -806,6 +807,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { inputZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(inputZp.getType()), inputZp); + auto torchDtype = cast(make.getType()).getDtype(); + inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } if (auto make = op.getWeight() @@ -818,6 +821,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightZp = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(weightZp.getType()), weightZp); + auto torchDtype = cast(make.getType()).getDtype(); + weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype); } if (static_cast(inputZp) != static_cast(weightZp)) { @@ -916,15 +921,35 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector strideIntValues = getAsConstantIntValues(rewriter, loc, strideInts); + // convert any uint8 quantization to int8 quantization + if (auto integerType = dyn_cast(inputDTy)) { + int64_t width = integerType.getWidth(); + signShift(rewriter, loc, input, inputZp, inputUnsigned, width); + } + if (auto integerType = dyn_cast(weightDTy)) { + int64_t width = integerType.getWidth(); + signShift(rewriter, loc, weight, weightZp, weightUnsigned, width); + } // Pad the input tensor according to padding. SmallVector outDims{inBatch, weightBatch}; Value paddedInput; - if (transposed) { - if (!isa(inputDTy) || !isa(weightDTy) || - !isa(resultDTy)) - return rewriter.notifyMatchFailure( - op, "transpose does not support non-fp type yet"); + Value pad = inputZp; + if (!pad) { + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); + } + if (pad.getType() != inputDTy) { + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + } + if (transposed) { Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value c1 = @@ -975,13 +1000,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { innerSize = rewriter.create( loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); innerSize = rewriter.create(loc, innerSize, c1); - + // offset = (weightDims[i] - 1) * dilation[i] - padding[i] Value offset = rewriter.create(loc, weightDims[i], c1); offset = rewriter.create( loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i])); offset = rewriter.create( loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i])); - + // outer size = offset * 2 + inner size + outputPadding[i] Value outerSize = rewriter.create(loc, offset, c2); outerSize = rewriter.create(loc, outerSize, innerSize); outerSize = rewriter.create( @@ -994,7 +1019,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Allocate padded input tensor Value initTensor = - createZeroInitTensor(rewriter, loc, outerSizes, inputDTy); + createInitTensor(rewriter, loc, outerSizes, inputDTy, pad); // Insert input into allocated tensor SmallVector strideIndexValues{c1, c1}; @@ -1017,24 +1042,6 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { strideInts.clear(); strideInts.append(numSpatialDims, 1); } else { - Value pad = inputZp; - if (!pad) { - if (isa(inputDTy)) - pad = rewriter.create( - op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); - if (isa(inputDTy)) - pad = rewriter.create( - op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); - } - - if (pad.getType() != inputDTy) { - if (isa(inputDTy)) - pad = rewriter.create(op.getLoc(), inputDTy, pad); - - if (isa(inputDTy)) - pad = rewriter.create(op.getLoc(), inputDTy, pad); - } - // Pad input paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 5872df170c48..52fe26077e31 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -892,3 +892,47 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils): weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) bias = torch.rand(3) module.forward(inputVec, weight, bias) + +N = 10 +Cin = 5 +Cout = 7 +Hin = 10 +Win = 8 +Hker = 3 +Wker = 2 +class ConvTranspose2DQInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ]) + def forward(self, input, weight, bias): + qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25) + qinput = torch.dequantize(qinput) + qweight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 50) + qweight = torch.dequantize(qweight) + qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) + qbias = torch.dequantize(qbias) + qz = torch.ops.aten.convolution(qinput, + qweight, + bias=qbias, + stride=[2,1], + padding=[1,1], + dilation=[1,1], + transposed=True, + output_padding=[0,0], + groups=1) + return qz + +@register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module()) +def ConvTranspose2DQInt8_basic(module, tu: TestUtils): + module.forward(tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), + tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), + torch.rand(Cout), + ) \ No newline at end of file diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py index 47e8adffdfd8..5d82c5e0a438 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py @@ -159,3 +159,105 @@ def get_quantized_mlp(): @register_test_case(module_factory=get_quantized_mlp) def QuantizedMLP_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) + +N = 1 +Cin = 2 +Cout = 3 +Hin = 1 +Win = 1 +Hker = 1 +Wker = 1 + +def get_conv_model_input(): + return torch.rand((N, Cin, Hin, Win)) + +class QuantizedConvTranspose2DModule(nn.Module): + def __init__(self): + super().__init__() + torch.random.manual_seed(1) + self.layers = nn.Sequential( + nn.ConvTranspose2d( + in_channels=Cin, + out_channels=Cout, + kernel_size=(Hker, Wker), + stride=(1, 1), + padding=(0, 0), + groups=1, + bias=True, + output_padding=(0,0), + dilation=1, + ), + ) + self.quantize = torch.quantization.QuantStub() + self.dequantize = torch.quantization.DeQuantStub() + + @export + @annotate_args([ + None, + ([N, Cin, Hin, Win], torch.float32, True), + ]) + def forward(self, x): + x = self.quantize(x) + x = self.layers(x) + x = self.dequantize(x) + return x + +def get_quantized_conv_transpose(): + model = QuantizedConvTranspose2DModule() + model.eval() + model.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(model, inplace=True) + torch.manual_seed(1) + for _ in range(32): + model(get_conv_model_input()) + torch.quantization.convert(model, inplace=True) + return model + +@register_test_case(module_factory=get_quantized_conv_transpose) +def QuantizedConvTranspose2DModule_basic(module, tu: TestUtils): + module.forward(0.5*torch.ones(N,Cin,Hin,Win,dtype=torch.float32)) + +class QuantizedConv2DModule(nn.Module): + def __init__(self): + super().__init__() + torch.random.manual_seed(1) + self.layers = nn.Sequential( + nn.Conv2d( + in_channels=Cin, + out_channels=Cout, + kernel_size=(Hker, Wker), + stride=(1, 1), + padding=(0, 0), + groups=1, + bias=True, + dilation=1, + ), + ) + self.quantize = torch.quantization.QuantStub() + self.dequantize = torch.quantization.DeQuantStub() + + @export + @annotate_args([ + None, + ([N, Cin, Hin, Win], torch.float32, True), + ]) + def forward(self, x): + x = self.quantize(x) + x = self.layers(x) + x = self.dequantize(x) + return x + +def get_quantized_conv(): + model = QuantizedConv2DModule() + model.eval() + model.qconfig = torch.quantization.default_qconfig + torch.quantization.prepare(model, inplace=True) + torch.manual_seed(1) + for _ in range(32): + model(get_conv_model_input()) + torch.quantization.convert(model, inplace=True) + return model + +@register_test_case(module_factory=get_quantized_conv) +def QuantizedConv2DModule_basic(module, tu: TestUtils): + module.forward(get_conv_model_input()) \ No newline at end of file From 93121d375348b87cde21f8b96dc56826fb0a34ef Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 26 Apr 2024 11:35:52 -0500 Subject: [PATCH 2/6] edit tests & xfails --- lib/Conversion/TorchToLinalg/Linear.cpp | 4 +- projects/pt1/e2e_testing/xfail_sets.py | 5 + .../test_suite/quantized_models.py | 102 ------------------ 3 files changed, 7 insertions(+), 104 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 31f65fe01ac4..0a91d94c331a 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1000,13 +1000,13 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { innerSize = rewriter.create( loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); innerSize = rewriter.create(loc, innerSize, c1); - // offset = (weightDims[i] - 1) * dilation[i] - padding[i] + Value offset = rewriter.create(loc, weightDims[i], c1); offset = rewriter.create( loc, offset, castIntToIndex(rewriter, loc, dilationIntValues[i])); offset = rewriter.create( loc, offset, castIntToIndex(rewriter, loc, paddingIntValues[i])); - // outer size = offset * 2 + inner size + outputPadding[i] + Value outerSize = rewriter.create(loc, offset, c2); outerSize = rewriter.create(loc, outerSize, innerSize); outerSize = rewriter.create( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 55a005e681dd..764e1543ff4f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -335,6 +335,7 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", + "ConvTranspose2DQInt8Module_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", @@ -437,6 +438,7 @@ 'Conv2dQInt8Module_basic', 'Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier', 'ConvTbcModule_basic', + "ConvTranspose2DQInt8Module_basic", 'ConvolutionBackwardModule2DPadded_basic', 'ConvolutionBackwardModule2DStrided_basic', 'ConvolutionBackwardModule2D_basic', @@ -609,6 +611,7 @@ "ContainsIntList_True", "Conv2dQInt8Module_basic", "ConvTbcModule_basic", + "ConvTranspose2DQInt8Module_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", @@ -2163,6 +2166,7 @@ "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", "Conv2dQInt8Module_basic", + "ConvTranspose2DQInt8Module_basic", } ONNX_XFAIL_SET = { @@ -2322,6 +2326,7 @@ "Conv2dWithPaddingModule_basic", "Conv3dModule_basic", "ConvTbcModule_basic", + "ConvTranspose2DQInt8Module_basic", "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py index 5d82c5e0a438..47e8adffdfd8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/quantized_models.py @@ -159,105 +159,3 @@ def get_quantized_mlp(): @register_test_case(module_factory=get_quantized_mlp) def QuantizedMLP_basic(module, tu: TestUtils): module.forward(get_quant_model_input()) - -N = 1 -Cin = 2 -Cout = 3 -Hin = 1 -Win = 1 -Hker = 1 -Wker = 1 - -def get_conv_model_input(): - return torch.rand((N, Cin, Hin, Win)) - -class QuantizedConvTranspose2DModule(nn.Module): - def __init__(self): - super().__init__() - torch.random.manual_seed(1) - self.layers = nn.Sequential( - nn.ConvTranspose2d( - in_channels=Cin, - out_channels=Cout, - kernel_size=(Hker, Wker), - stride=(1, 1), - padding=(0, 0), - groups=1, - bias=True, - output_padding=(0,0), - dilation=1, - ), - ) - self.quantize = torch.quantization.QuantStub() - self.dequantize = torch.quantization.DeQuantStub() - - @export - @annotate_args([ - None, - ([N, Cin, Hin, Win], torch.float32, True), - ]) - def forward(self, x): - x = self.quantize(x) - x = self.layers(x) - x = self.dequantize(x) - return x - -def get_quantized_conv_transpose(): - model = QuantizedConvTranspose2DModule() - model.eval() - model.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(model, inplace=True) - torch.manual_seed(1) - for _ in range(32): - model(get_conv_model_input()) - torch.quantization.convert(model, inplace=True) - return model - -@register_test_case(module_factory=get_quantized_conv_transpose) -def QuantizedConvTranspose2DModule_basic(module, tu: TestUtils): - module.forward(0.5*torch.ones(N,Cin,Hin,Win,dtype=torch.float32)) - -class QuantizedConv2DModule(nn.Module): - def __init__(self): - super().__init__() - torch.random.manual_seed(1) - self.layers = nn.Sequential( - nn.Conv2d( - in_channels=Cin, - out_channels=Cout, - kernel_size=(Hker, Wker), - stride=(1, 1), - padding=(0, 0), - groups=1, - bias=True, - dilation=1, - ), - ) - self.quantize = torch.quantization.QuantStub() - self.dequantize = torch.quantization.DeQuantStub() - - @export - @annotate_args([ - None, - ([N, Cin, Hin, Win], torch.float32, True), - ]) - def forward(self, x): - x = self.quantize(x) - x = self.layers(x) - x = self.dequantize(x) - return x - -def get_quantized_conv(): - model = QuantizedConv2DModule() - model.eval() - model.qconfig = torch.quantization.default_qconfig - torch.quantization.prepare(model, inplace=True) - torch.manual_seed(1) - for _ in range(32): - model(get_conv_model_input()) - torch.quantization.convert(model, inplace=True) - return model - -@register_test_case(module_factory=get_quantized_conv) -def QuantizedConv2DModule_basic(module, tu: TestUtils): - module.forward(get_conv_model_input()) \ No newline at end of file From e7b3e94ac0b5768bf08c5e97745e77269f47c38d Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 26 Apr 2024 11:44:40 -0500 Subject: [PATCH 3/6] clang-format --- lib/Conversion/TorchToLinalg/Linear.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 0a91d94c331a..d870d552b5b7 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -42,8 +42,9 @@ static void signShift(PatternRewriter &rewriter, Location loc, Value &arg, Value &zp, bool isUnsignedType, int64_t numBits) { if (!isUnsignedType) return; - int64_t minSI = -(1 << (numBits - 1)); - Value minSIValue = rewriter.create(loc, minSI, zp.getType().cast().getWidth()); + int64_t minSI = -(1 << (numBits - 1)); + Value minSIValue = rewriter.create( + loc, minSI, zp.getType().cast().getWidth()); zp = rewriter.create(loc, zp, minSIValue); minSIValue = rewriter.create(loc, minSI, numBits); arg = torch_to_linalg::createElementwiseLinalgGeneric( From e93cbe630cb125cd343f9f83fbce0ce71d85082a Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Fri, 26 Apr 2024 11:54:53 -0500 Subject: [PATCH 4/6] fix name in xfails set --- projects/pt1/e2e_testing/xfail_sets.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 764e1543ff4f..fab9731df9c0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -335,7 +335,7 @@ "QuantizedReluInt8_basic", "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", - "ConvTranspose2DQInt8Module_basic", + "ConvTranspose2DQInt8_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", @@ -438,7 +438,7 @@ 'Conv2dQInt8Module_basic', 'Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier', 'ConvTbcModule_basic', - "ConvTranspose2DQInt8Module_basic", + "ConvTranspose2DQInt8_basic", 'ConvolutionBackwardModule2DPadded_basic', 'ConvolutionBackwardModule2DStrided_basic', 'ConvolutionBackwardModule2D_basic', @@ -611,7 +611,7 @@ "ContainsIntList_True", "Conv2dQInt8Module_basic", "ConvTbcModule_basic", - "ConvTranspose2DQInt8Module_basic", + "ConvTranspose2DQInt8_basic", "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", @@ -2166,7 +2166,7 @@ "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", "Conv2dQInt8Module_basic", - "ConvTranspose2DQInt8Module_basic", + "ConvTranspose2DQInt8_basic", } ONNX_XFAIL_SET = { @@ -2326,7 +2326,7 @@ "Conv2dWithPaddingModule_basic", "Conv3dModule_basic", "ConvTbcModule_basic", - "ConvTranspose2DQInt8Module_basic", + "ConvTranspose2DQInt8_basic", "Conv_Transpose2dModule_basic", "Convolution2DModule_basic", "Convolution2DStridedModule_basic", From eaa9dd703c3b815096ac41c7b8de06146a415f2e Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 29 Apr 2024 10:19:01 -0500 Subject: [PATCH 5/6] black reformatting --- projects/pt1/e2e_testing/xfail_sets.py | 1 - .../torch_mlir_e2e_test/test_suite/conv.py | 47 +++++++++++-------- 2 files changed, 28 insertions(+), 20 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index eb2923eda323..b5d2f4ed580f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -273,7 +273,6 @@ "QuantizedReluUint8_basic", "Conv2dQInt8Module_basic", "ConvTranspose2DQInt8_basic", - # Dynamo not supporting conv_tbc "ConvTbcModule_basic", "FloatImplicitModule_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 28d17bc7bd74..e99525c32d88 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1047,6 +1047,7 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils): bias = torch.rand(3) module.forward(inputVec, weight, bias) + N = 10 Cin = 5 Cout = 7 @@ -1054,18 +1055,22 @@ def Conv2dQInt8Module_basic(module, tu: TestUtils): Win = 8 Hker = 3 Wker = 2 + + class ConvTranspose2DQInt8Module(torch.nn.Module): def __init__(self): super().__init__() @export - @annotate_args([ - None, - ([-1, -1, -1, -1], torch.int8, True), - ([-1, -1, -1, -1], torch.int8, True), - ([-1], torch.float, True), - ]) + @annotate_args( + [ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ] + ) def forward(self, input, weight, bias): qinput = torch._make_per_tensor_quantized_tensor(input, 0.01, -25) qinput = torch.dequantize(qinput) @@ -1073,20 +1078,24 @@ def forward(self, input, weight, bias): qweight = torch.dequantize(qweight) qbias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) qbias = torch.dequantize(qbias) - qz = torch.ops.aten.convolution(qinput, - qweight, - bias=qbias, - stride=[2,1], - padding=[1,1], - dilation=[1,1], - transposed=True, - output_padding=[0,0], - groups=1) + qz = torch.ops.aten.convolution( + qinput, + qweight, + bias=qbias, + stride=[2, 1], + padding=[1, 1], + dilation=[1, 1], + transposed=True, + output_padding=[0, 0], + groups=1, + ) return qz + @register_test_case(module_factory=lambda: ConvTranspose2DQInt8Module()) def ConvTranspose2DQInt8_basic(module, tu: TestUtils): - module.forward(tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), - tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), - torch.rand(Cout), - ) \ No newline at end of file + module.forward( + tu.randint(N, Cin, Hin, Win, low=-128, high=127).to(torch.int8), + tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), + torch.rand(Cout), + ) From 071dd9799efa82f4dfe0b9345589c2a780f2f829 Mon Sep 17 00:00:00 2001 From: zjgarvey Date: Mon, 29 Apr 2024 17:30:18 -0500 Subject: [PATCH 6/6] address comment --- lib/Conversion/TorchToLinalg/Linear.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 88b506badf27..c49646e2f1c0 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -798,7 +798,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { auto resultTy = cast(op.getType()); Value inputZp, weightZp; - bool inputUnsigned, weightUnsigned = false; + bool inputUnsigned = false; + bool weightUnsigned = false; if (auto make = op.getInput() .getDefiningOp()) { input = make.getSelf();