diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index 62e6680f489b..cf31c8f9735a 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -145,6 +145,8 @@ LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA,
 // control the behavior. Such support would be done in coordination with
 // the fx_importer and APIs, which could add hints to the IR (based on
 // Torch flags, user options, etc).
+// Note: The special case of int8 intentionally deviates from the reference
+// and uses int32 instead of int64 accumulation.
 Type getDefaultAccType(PatternRewriter &rewriter, Type inputType);
 
 LogicalResult getPermutedType(BaseTensorType inType,
diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp
index 318c2bec361f..c72db61c42fc 100644
--- a/lib/Conversion/TorchToLinalg/Linear.cpp
+++ b/lib/Conversion/TorchToLinalg/Linear.cpp
@@ -149,7 +149,8 @@ class ConvertAtenMmOp : public OpConversionPattern<AtenMmOp> {
     TensorType resultType =
         cast<TensorType>(getTypeConverter()->convertType(op.getType()));
     Type elementType = resultType.getElementType();
-    auto accumulatorDType = getDefaultAccType(rewriter, elementType);
+    auto accumulatorDType =
+        getDefaultAccType(rewriter, lhsType.getElementType());
     if (accumulatorDType != resultType.getElementType()) {
       elementType = accumulatorDType;
     }
@@ -803,6 +804,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
       inputZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(inputZp.getType()),
           inputZp);
+      inputZp =
+          rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), inputZp);
       auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
       inputUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
@@ -817,6 +820,8 @@
       weightZp = typeConverter->materializeTargetConversion(
           rewriter, loc, typeConverter->convertType(weightZp.getType()),
           weightZp);
+      weightZp = rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(),
+                                                  weightZp);
       auto torchDtype = cast<ValueTensorType>(make.getType()).getDtype();
       weightUnsigned = torch_to_linalg::isUnsignedTorchType(torchDtype);
     }
@@ -1049,7 +1054,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
           castIndexToInt(weightDims[i]), strideIntValues[i]));
     }
 
-    Type accumulatorDType = getDefaultAccType(rewriter, resultDTy);
+    Type accumulatorDType = getDefaultAccType(rewriter, inputDTy);
     Value initTensor = rewriter.create<tensor::EmptyOp>(
         loc, getAsOpFoldResult(outDims), accumulatorDType);
 
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 81a2de87b054..eb8b37502efc 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -625,15 +625,14 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getF32Type();
   if (inputType.isFloat8E4M3FNUZ())
     return rewriter.getF32Type();
-  if (inputType.isSignedInteger(8))
+  if (inputType.isInteger(8))
+    // this is an intentional deviation from CUDA (which accumulates i8 to i64)
+    return rewriter.getI32Type();
+  if (inputType.isInteger(16))
     return rewriter.getI64Type();
-  if (inputType.isUnsignedInteger(8))
+  if (inputType.isInteger(32))
     return rewriter.getI64Type();
-  if (inputType.isSignedInteger(16))
-    return rewriter.getI64Type();
-  if (inputType.isSignedInteger(32))
-    return rewriter.getI64Type();
-  if (inputType.isSignedInteger(64))
+  if (inputType.isInteger(64))
     return rewriter.getI64Type();
   return inputType;
 }
diff --git a/test/Conversion/TorchToLinalg/convolution.mlir b/test/Conversion/TorchToLinalg/convolution.mlir
index 1fead662183e..f99648684a23 100644
--- a/test/Conversion/TorchToLinalg/convolution.mlir
+++ b/test/Conversion/TorchToLinalg/convolution.mlir
@@ -16,3 +16,41 @@ func.func @torch.aten.convolution$nobias(%arg0: !torch.vtensor<[1,24,16,128,128]
   %4 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %false, %3, %int1 : !torch.vtensor<[1,24,16,128,128],f16>, !torch.vtensor<[54,24,1,1,1],f16>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,54,16,128,128],f16>
   return %4 : !torch.vtensor<[1,54,16,128,128],f16>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @q_conv_test
+// CHECK: %[[c3:.*]] = arith.constant 3 : i32
+// CHECK: %[[c7:.*]] = arith.constant 7 : i32
+// CHECK: %[[input:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
+// CHECK: %[[weight:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[?,?,?,?],si8> -> tensor<?x?x?x?xi8>
+// CHECK: %[[TransInput:.*]] = linalg.transpose ins(%[[input]] : tensor<?x?x?x?xi8>)
+// CHECK-SAME: permutation = [0, 2, 3, 1]
+// CHECK: %[[TransWeight:.*]] = linalg.transpose ins(%[[weight]] : tensor<?x?x?x?xi8>)
+// CHECK-SAME: permutation = [2, 3, 1, 0]
+// CHECK: %[[conv:.*]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>}
+// CHECK-SAME: ins(%[[TransInput]], %[[TransWeight]], %[[c7]], %[[c3]] : tensor<?x?x?x?xi8>, tensor<?x?x?x?xi8>, i32, i32)
+// CHECK-SAME: outs(%[[convout:.*]] : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32>
+func.func @q_conv_test(%arg0: !torch.vtensor<[?,?,?,?],si8>, %arg1: !torch.vtensor<[?,?,?,?],si8>, %arg2: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %float1.000000e-04 = torch.constant.float 1.000000e-04
+  %int3 = torch.constant.int 3
+  %int7 = torch.constant.int 7
+  %float1.000000e-02 = torch.constant.float 1.000000e-02
+  %int14 = torch.constant.int 14
+  %0 = torch.aten.quantize_per_tensor %arg2, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
+  %1 = torch.aten.dequantize.self %0 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],f32>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list<int>
+  %4 = torch.prim.ListConstruct : () -> !torch.list<int>
+  %5 = torch.aten._make_per_tensor_quantized_tensor %arg0, %float1.000000e-02, %int7 : !torch.vtensor<[?,?,?,?],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint8>
+  %6 = torch.aten._make_per_tensor_quantized_tensor %arg1, %float1.000000e-02, %int3 : !torch.vtensor<[?,?,?,?],si8>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint8>
+  %7 = torch.aten.quantize_per_tensor %1, %float1.000000e-04, %int0, %int14 : !torch.vtensor<[?],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[?],!torch.qint32>
+  %8 = torch.aten.int_repr %7 : !torch.vtensor<[?],!torch.qint32> -> !torch.vtensor<[?],si32>
+  %9 = torch.aten.convolution %5, %6, %8, %2, %3, %2, %false, %4, %int1 : !torch.vtensor<[?,?,?,?],!torch.qint8>, !torch.vtensor<[?,?,?,?],!torch.qint8>, !torch.vtensor<[?],si32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[?,?,?,?],si32>
+  %10 = torch.aten._make_per_tensor_quantized_tensor %9, %float1.000000e-04, %int0 : !torch.vtensor<[?,?,?,?],si32>, !torch.float, !torch.int -> !torch.vtensor<[?,?,?,?],!torch.qint32>
+  %11 = torch.aten.dequantize.tensor %10 : !torch.vtensor<[?,?,?,?],!torch.qint32> -> !torch.vtensor<[?,?,?,?],f32>
+  return %11 : !torch.vtensor<[?,?,?,?],f32>
+}
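
Note for reviewers (not part of the patch): a minimal sketch of the linalg IR that ConvertAtenMmOp now produces for signed-i8 operands, assuming static shapes; the function name @i8_mm_sketch and the concrete dimensions are hypothetical. The point is that the init tensor driving the accumulation is now i32 rather than i64:

func.func @i8_mm_sketch(%lhs: tensor<4x8xi8>, %rhs: tensor<8x16xi8>) -> tensor<4x16xi32> {
  // i32 accumulator (previously i64), matching the new getDefaultAccType
  // special case for 8-bit integer inputs.
  %c0_i32 = arith.constant 0 : i32
  %init = tensor.empty() : tensor<4x16xi32>
  %zero = linalg.fill ins(%c0_i32 : i32) outs(%init : tensor<4x16xi32>) -> tensor<4x16xi32>
  // linalg.matmul sign-extends the i8 operands into the i32 accumulator
  // (cast_signed is the default mixed-precision behavior of named ops).
  %mm = linalg.matmul ins(%lhs, %rhs : tensor<4x8xi8>, tensor<8x16xi8>)
                      outs(%zero : tensor<4x16xi32>) -> tensor<4x16xi32>
  return %mm : tensor<4x16xi32>
}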
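Similarly, a sketch of what the two arith::TruncIOp insertions in Linear.cpp amount to at the IR level (function name hypothetical): the converted torch.int zero points arrive as i64 scalars, and they are narrowed so their type lines up with the i32 accumulator passed to linalg.conv_2d_nhwc_hwcf_q in the test above:

func.func @zp_trunc_sketch(%inputZp: i64, %weightZp: i64) -> (i32, i32) {
  // arith.trunci corresponds to the rewriter.create<arith::TruncIOp> calls
  // added above; without it the i64 zero points would mismatch the i32
  // accumulator element type of the quantized conv.
  %izp = arith.trunci %inputZp : i64 to i32
  %wzp = arith.trunci %weightZp : i64 to i32
  return %izp, %wzp : i32, i32
}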