[TorchToLinalg] Fix Quantized Convolution Accumulator Type (#3459)
1. Truncates zero points to i32.
2. Changes the default accumulator type for i8 inputs from i64 to i32 (see the worked bound below).
3. Infers the accumulator dtype from the input dtype rather than from the result dtype.
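For item 2, a rough worst-case bound (illustrative only, not part of the commit) shows why i32 suffices: an i8 x i8 product has magnitude at most 128 * 128 = 2^14, so an i32 accumulator can absorb at least 2^31 / 2^14 = 2^17 = 131072 such products before overflow is possible, far more than the reduction size (input channels times kernel height times kernel width) of typical quantized convolutions.

// Illustrative only (not from this commit): the headroom that makes i32 a
// safe default accumulator for i8 x i8 products.
#include <cstdint>

constexpr int64_t kMaxI8Product = 128LL * 128LL;  // |i8 * i8| <= 2^14
constexpr int64_t kSafeReductionSize =
    (int64_t{1} << 31) / kMaxI8Product;           // 2^17 = 131072
static_assert(kSafeReductionSize == 131072,
              "an i32 accumulator holds 131072 worst-case i8 products");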
zjgarvey authored Jun 20, 2024
1 parent c7d52f6 · commit 694210f
Showing 4 changed files with 53 additions and 9 deletions.
include/torch-mlir/Dialect/Torch/Utils/Utils.h: 2 additions, 0 deletions
@@ -145,6 +145,8 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
 // control the behavior. Such support would be done in coordination with
 // the fx_importer and APIs, which could add hints to the IR (based on
 // Torch flags, user options, etc).
+// Note: The special case of int8 intentionally deviates from the reference, and
+// uses int32 instead of int64 accumulation.
 Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);
 
 LogicalResult getPermutedType(BaseTensorType inType,
lib/Conversion/TorchToLinalg/Linear.cpp: 7 additions, 2 deletions
@@ -149,7 +149,8 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
     TensorType resultType =
         cast<TensorType>(getTypeConverter()->convertType(op.getType()));
     Type elementType = resultType.getElementType();
-    auto accumulatorDType = getDefaultAccType(rewriter, elementType);
+    auto accumulatorDType =
+        getDefaultAccType(rewriter, lhsType.getElementType());
     if (accumulatorDType != resultType.getElementType()) {
       elementType = accumulatorDType;
     }
@@ -803,6 +804,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       inputZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(inputZp.getType()),
           inputZp);
+      inputZp =
+          rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), inputZp);
       auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
       inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
@@ -817,6 +820,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       weightZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(weightZp.getType()),
           weightZp);
+      weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
+                                                  weightZp);
       auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
       weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
@@ -1049,7 +1054,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           castIndexToInt(weightDims[i]), strideIntValues[i]));
     }
 
-    Type accumulatorDType = getDefaultAccType(rewriter, resultDTy);
+    Type accumulatorDType = getDefaultAccType(rewriter, inputDTy);
     Value initTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(outDims), accumulatorDType);

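Both zero-point changes above follow one pattern: the zero point is materialized through the type converter (a !torch.int scalar typically lands as i64) and then narrowed to the i32 scalar the quantized linalg op consumes. A minimal sketch of that pattern as a standalone helper; the helper name and the same-width guard are assumptions here, and the patch itself truncates unconditionally (arith.trunci requires a strictly narrower result type):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical helper (not in the patch): narrow a zero point to i32 before
// handing it to the quantized convolution, skipping the cast when the value
// is already i32 (arith.trunci rejects same-width casts).
static Value truncateZeroPointToI32(PatternRewriter &rewriter, Location loc,
                                    Value zp) {
  Type i32Ty = rewriter.getI32Type();
  if (zp.getType() == i32Ty)
    return zp;
  return rewriter.create<arith::TruncIOp>(loc, i32Ty, zp);
}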
lib/Dialect/Torch/Utils/Utils.cpp: 6 additions, 7 deletions
@@ -625,15 +625,14 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getF32Type();
   if (inputType.isFloat8E4M3FNUZ())
     return rewriter.getF32Type();
-  if (inputType.isSignedInteger(8))
+  if (inputType.isInteger(8))
+    // this is an intentional deviation from CUDA (which accumulates i8 to i64)
+    return rewriter.getI32Type();
+  if (inputType.isInteger(16))
     return rewriter.getI64Type();
-  if (inputType.isUnsignedInteger(8))
+  if (inputType.isInteger(32))
     return rewriter.getI64Type();
-  if (inputType.isSignedInteger(16))
-    return rewriter.getI64Type();
-  if (inputType.isSignedInteger(32))
-    return rewriter.getI64Type();
-  if (inputType.isSignedInteger(64))
+  if (inputType.isInteger(64))
     return rewriter.getI64Type();
   return inputType;
 }
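One behavioral detail worth noting: MLIR's Type::isInteger(width) matches an integer type of the given width regardless of signedness, so the collapsed branches above also move unsigned 8-bit inputs from i64 to i32 accumulation, in line with the signed case. A small sketch of the new behavior, assuming the usual mlir::torch::Torch namespacing for the helper declared in the header change above:

#include <cassert>

#include "mlir/IR/PatternMatch.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;

// Sketch: signed, unsigned, and signless 8-bit types all take the i32 branch.
void exerciseDefaultAccType(PatternRewriter &rewriter) {
  Type si8 = rewriter.getIntegerType(8, /*isSigned=*/true);
  Type i8 = rewriter.getIntegerType(8);
  assert(torch::Torch::getDefaultAccType(rewriter, si8).isInteger(32));
  assert(torch::Torch::getDefaultAccType(rewriter, i8).isInteger(32));
}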
test/Conversion/TorchToLinalg/convolution.mlir: 38 additions, 0 deletions
@@ -16,3 +16,41 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128]
   %4 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %false, %3, %int1 : !torch.vtensor<[1,24,16,128,128],f16>, !torch.vtensor<[54,24,1,1,1],f16>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,54,16,128,128],f16>
   return %4 : !torch.vtensor<[1,54,16,128,128],f16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @q_conv_test
+// CHECK: %[[c3:.*]] = arith.constant 3 : i32
+// CHECK: %[[c7:.*]] = arith.constant 7 : i32
+// CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
+// CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
+// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor<?x?x?x?xi8>)
+// CHECK-SAME: permutation = [0, 2, 3, 1]
+// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor<?x?x?x?xi8>)
+// CHECK-SAME: permutation = [2, 3, 1, 0]
+// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+// CHECK-SAME: outs(%[[convout:.*]] : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %float1.000000e-04 = torch.constant.float 1.000000e-04
+  %int3 = torch.constant.int 3
+  %int7 = torch.constant.int 7
+  %float1.000000e-02 = torch.constant.float 1.000000e-02
+  %int14 = torch.constant.int 14
+  %0 = torch.aten.quantize_per_tensor %arg2, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],f32>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float1.000000e-02, %int7 : !torch.vtensor<[?,?,?,?],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint8>
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg1, %float1.000000e-02, %int3 : !torch.vtensor<[?,?,?,?],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint8>
+  %7 = torch.aten.quantize_per_tensor %1, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
+  %8 = torch.aten.int_repr %7 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],si32>
+  %9 = torch.aten.convolution %5, %6, %8, %2, %3, %2, %false, %4, %int1 : !torch.vtensor<[?,?,?,?],!torch.qint8>, !torch.vtensor<[?,?,?,?],!torch.qint8>, !torch.vtensor<[?],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],si32>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float1.000000e-04, %int0 : !torch.vtensor<[?,?,?,?],si32>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint32>
+  %11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,?,?,?],!torch.qint32> -> !torch.vtensor<[?,?,?,?],f32>
+  return %11 : !torch.vtensor<[?,?,?,?],f32>
+}
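Taken together, the CHECK lines above pin down both fixes: the zero points surface as i32 constants (%[[c7]] for the input, %[[c3]] for the weight) among the ins of linalg.conv_2d_nhwc_hwcf_q, and the convolution accumulates into tensor<?x?x?x?xi32> rather than an i64 tensor.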
