Skip to content

Commit

Permalink
Add frontend support for saturate for CastLike and QuantizeLinear (#2480)

Browse files Browse the repository at this point in the history

Signed-off-by: philass <[email protected]>
Co-authored-by: Soren Lassen <[email protected]>
  • Loading branch information
philass and sorenlassen authored Sep 6, 2023
1 parent ef535dd commit 49233b0
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/Builder/OpBuildTable.inc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ op_dialect_version_map_["BitwiseOr"] = {18};
op_dialect_version_map_["BitwiseXor"] = {18};
op_dialect_version_map_["BlackmanWindow"] = {17};
op_dialect_version_map_["Cast"] = {19};
op_dialect_version_map_["CastLike"] = {15};
op_dialect_version_map_["CastLike"] = {19};
op_dialect_version_map_["CastMap"] = {1};
op_dialect_version_map_["CategoryMapper"] = {1};
op_dialect_version_map_["Ceil"] = {13};
Expand Down Expand Up @@ -137,7 +137,7 @@ op_dialect_version_map_["Pad"] = {18, 13, 11, 2};
op_dialect_version_map_["Pow"] = {15};
op_dialect_version_map_["QLinearConv"] = {10};
op_dialect_version_map_["QLinearMatMul"] = {10};
op_dialect_version_map_["QuantizeLinear"] = {13};
op_dialect_version_map_["QuantizeLinear"] = {19};
op_dialect_version_map_["RNN"] = {14};
op_dialect_version_map_["RandomNormal"] = {1};
op_dialect_version_map_["RandomNormalLike"] = {1};
Expand Down
28 changes: 17 additions & 11 deletions src/Dialect/ONNX/ONNXOps.td.inc
Original file line number Diff line number Diff line change
Expand Up @@ -898,9 +898,10 @@ def ONNXCastLikeOp:ONNX_Op<"CastLike",
the same data type as the elements of the second input tensor.
See documentation of the Cast operator for further details.
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>]>:$input,
AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>]>:$target_type);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>]>:$output);
let arguments = (ins AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$input,
AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$target_type,
DefaultValuedAttr<SI64Attr, "1">:$saturate);
let results = (outs AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I1]>, TensorOf<[StringType]>, TensorOf<[BF16]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$output);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 2;
Expand Down Expand Up @@ -5685,15 +5686,20 @@ def ONNXQuantizeLinearOp:ONNX_Op<"QuantizeLinear",
let description = [{
The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor.
The scale factor and zero point must have same shape, and can be either a scalar for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantization.
The quantization formula is y = saturate ((x / y_scale) + y_zero_point).
The quantization formula is `y = saturate ((x / y_scale) + y_zero_point)`.
For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8.
For (x / y_scale), it's rounding to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type.
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[I32]>]>:$x,
TensorOf<[F32]>:$y_scale,
AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, NoneType]>:$y_zero_point,
DefaultValuedAttr<SI64Attr, "1">:$axis);
let results = (outs AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>]>:$y);
For (x / y_scale), it's rounding to the nearest even. Refer to https://en.wikipedia.org/wiki/Rounding for details.
'y_zero_point' and 'y' must have same type.
'y_zero_point' is usually not used for quantization to float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz,
but the quantization formula remains the same for consistency and
the type of the attribute 'y_zero_point' still determines the quantization type.
}];
let arguments = (ins AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F16]>, TensorOf<[BF16]>, TensorOf<[I32]>]>:$x,
AnyTypeOf<[TensorOf<[F32]>, TensorOf<[F16]>, TensorOf<[BF16]>, TensorOf<[I32]>]>:$y_scale,
AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>, NoneType]>:$y_zero_point,
DefaultValuedAttr<SI64Attr, "1">:$axis,
DefaultValuedAttr<SI64Attr, "1">:$saturate);
let results = (outs AnyTypeOf<[TensorOf<[I8]>, TensorOf<[UI8]>, TensorOf<[F8E4M3FN]>, TensorOf<[F8E4M3FNUZ]>, TensorOf<[F8E5M2]>, TensorOf<[F8E5M2FNUZ]>]>:$y);
let extraClassDeclaration = [{
static int getNumberOfOperands() {
return 3;
Expand Down
8 changes: 4 additions & 4 deletions test/mlir/onnx/onnx_shape_inference.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1710,7 +1710,7 @@ func.func @test_castlike_1(%arg0 : tensor<2x3x4xf32>, %arg1 : tensor<2xf16>) ->
"onnx.Return"(%1) : (tensor<*xf16>) -> ()

// CHECK-LABEL: test_castlike_1
// CHECK: [[RES:%.+]] = "onnx.CastLike"(%arg0, %arg1) : (tensor<2x3x4xf32>, tensor<2xf16>) -> tensor<2x3x4xf16>
// CHECK: [[RES:%.+]] = "onnx.CastLike"(%arg0, %arg1) {saturate = 1 : si64} : (tensor<2x3x4xf32>, tensor<2xf16>) -> tensor<2x3x4xf16>
// CHECK: onnx.Return [[RES]] : tensor<2x3x4xf16>
}

Expand Down Expand Up @@ -1739,7 +1739,7 @@ func.func @test_quantize_linear_1(%arg0 : tensor<5x2x3x4xf32>, %arg1 : tensor<f3
"onnx.Return"(%1) {} : (tensor<*xi8>) -> ()

// CHECK-LABEL: test_quantize_linear_1
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xi8>
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64, saturate = 1 : si64} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<i8>) -> tensor<5x2x3x4xi8>
// CHECK: onnx.Return [[RES]] : tensor<5x2x3x4xi8>
}

Expand All @@ -1750,7 +1750,7 @@ func.func @test_quantize_linear_2(%arg0 : tensor<5x2x3x4xf32>, %arg1: tensor<f32
"onnx.Return"(%0) {} : (tensor<*xui8>) -> ()

// CHECK-LABEL: test_quantize_linear_2
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<ui8>) -> tensor<5x2x3x4xui8>
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) {axis = 1 : si64, saturate = 1 : si64} : (tensor<5x2x3x4xf32>, tensor<f32>, tensor<ui8>) -> tensor<5x2x3x4xui8>
// CHECK: onnx.Return [[RES]] : tensor<5x2x3x4xui8>
}

Expand All @@ -1762,7 +1762,7 @@ func.func @test_quantize_linear_3(%arg0 : tensor<5x2x3x4xf32>, %arg1: tensor<f32
"onnx.Return"(%0) {} : (tensor<*xui8>) -> ()

// CHECK-LABEL: test_quantize_linear_3
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %0) {axis = 1 : si64} : (tensor<5x2x3x4xf32>, tensor<f32>, none) -> tensor<5x2x3x4xui8>
// CHECK: [[RES:%.+]] = "onnx.QuantizeLinear"(%arg0, %arg1, %0) {axis = 1 : si64, saturate = 1 : si64} : (tensor<5x2x3x4xf32>, tensor<f32>, none) -> tensor<5x2x3x4xui8>
// CHECK: onnx.Return [[RES]] : tensor<5x2x3x4xui8>
}

Expand Down
4 changes: 2 additions & 2 deletions test/mlir/onnx/parse/functiontest_attrwithdefault.onnxtext
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ myfun <a: float=1.0> (x) => (y) {
// CHECK-LABEL: func.func @main_graph
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?xf32>) -> tensor<?xf32> attributes {input_names = ["x"], output_names = ["y"]} {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant {value_float = 2.000000e+00 : f32} : tensor<f32>
// CHECK: [[VAR_1_:%.+]] = "onnx.CastLike"([[VAR_0_]], [[PARAM_0_]]) : (tensor<f32>, tensor<?xf32>) -> tensor<f32>
// CHECK: [[VAR_1_:%.+]] = "onnx.CastLike"([[VAR_0_]], [[PARAM_0_]]) {saturate = 1 : si64} : (tensor<f32>, tensor<?xf32>) -> tensor<f32>
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_1_]]) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK-DAG: [[VAR_3_:%.+]] = onnx.Constant {value_float = 1.000000e+00 : f32} : tensor<f32>
// CHECK: [[VAR_4_:%.+]] = "onnx.CastLike"([[VAR_3_]], [[PARAM_0_]]) : (tensor<f32>, tensor<?xf32>) -> tensor<f32>
// CHECK: [[VAR_4_:%.+]] = "onnx.CastLike"([[VAR_3_]], [[PARAM_0_]]) {saturate = 1 : si64} : (tensor<f32>, tensor<?xf32>) -> tensor<f32>
// CHECK: [[VAR_5_:%.+]] = "onnx.Add"([[PARAM_0_]], [[VAR_4_]]) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: [[VAR_6_:%.+]] = "onnx.Add"([[VAR_2_]], [[VAR_5_]]) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: onnx.Return [[VAR_6_]] : tensor<?xf32>
Expand Down
4 changes: 2 additions & 2 deletions utils/gen_onnx_mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@
'BitwiseXor': [18],
'BlackmanWindow': [17],
'Cast': [19],
'CastLike': [15],
'CastLike': [19],
'CastMap': [1],
'CategoryMapper': [1],
'Ceil': [13],
Expand Down Expand Up @@ -206,7 +206,7 @@
'Pow': [15],
'QLinearConv': [10],
'QLinearMatMul': [10],
'QuantizeLinear': [13],
'QuantizeLinear': [19],
'RNN': [14],
'RandomNormal': [1],
'RandomNormalLike': [1],
Expand Down

0 comments on commit 49233b0

Please sign in to comment.