-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][LLVM] Fix conversion of non-standard MLIR float types #122634
[mlir][LLVM] Fix conversion of non-standard MLIR float types #122634
Conversation
@llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) ChangesCertain non-standard float types were directly passed through in the LLVM type converter, resulting in invalid IR or failed assertions:
The type converter should not define a type conversion rule for such types. Conversion patterns will no apply to ops with such operand types. Full diff: https://github.com/llvm/llvm-project/pull/122634.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 72799e42cf3fd1..64bdb248dff430 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
}
Type LLVMTypeConverter::convertFloatType(FloatType type) const {
+ // Valid LLVM float types are used directly.
+ if (LLVM::isCompatibleType(type))
+ return type;
+
+ // F4, F6, F8 types are converted to integer types with the same bit width.
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
type.isFloat8E8M0FNU())
return IntegerType::get(&getContext(), type.getWidth());
- return type;
+
+ // Other floating-point types: A custom type conversion rule must be
+ // specified by the user.
+ return Type();
}
// Convert a `ComplexType` to an LLVM type. The result is a complex number
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a9dcc0a16b3dbd..1dabacfd8a47cc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
return %1 : vector<4xf32>
}
+// -----
+
// CHECK-LABEL: @ops
func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
@@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
%20 = arith.shrsi %arg2, %arg3 : i32
// CHECK: = llvm.lshr %arg2, %arg3 : i32
%21 = arith.shrui %arg2, %arg3 : i32
+// CHECK: arith.constant 2.000000e+00 : tf32
+ // There is no type conversion rule for tf32.
+ %22 = arith.constant 2.0 : tf32
return %0, %10 : f32, i32
}
+// -----
+
// Checking conversion of index types to integers using i1, assuming no target
// system would have a 1-bit address space. Otherwise, we would have had to
// make this test dependent on the pointer size on the target system.
@@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @vector_index_cast
func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// -----
+
func.func @index_castui(%arg0: index, %arg1: i1) {
// CHECK: = llvm.trunc %0 : i{{.*}} to i1
%0 = arith.index_castui %arg0: index to i1
@@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @vector_index_castui
func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// -----
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
@@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
return
}
+// -----
+
// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @sitofp_vector
func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of unsigned integer types to floating point.
// CHECK-LABEL: @uitofp
func.func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext(%arg0 : f16, %arg1 : f32) {
@@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
@@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
return
}
+// -----
+
// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptosi
func.func @fptosi(%arg0 : f32, %arg1 : f64) {
@@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptosi_vector
func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptoui
func.func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptoui_vector
func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @uitofp_vector
func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
@@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
return
}
+// -----
+
// CHECK-LABEL: experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
@@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
return
}
+// -----
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
return
}
+// -----
+
// CHECK-LABEL: @integer_cast_0d_vector
func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
@@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
return
}
+// -----
+
// CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
func.func @fcmp(f32, f32) -> () {
^bb0(%arg0: f32, %arg1: f32):
diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
index 8396e5ad8ade15..22ac6eae73f534 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
@@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
return %arg1 : index
}
+// There is no type conversion rule for tf32, so vector<1xtf32> and, therefore,
+// the func op cannot be converted.
+// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
+// CHECK: llvm.return
+func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
+ return
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesCertain non-standard float types were directly passed through in the LLVM type converter, resulting in invalid IR or failed assertions:
The type converter should not define a type conversion rule for such types. Conversion patterns will no apply to ops with such operand types. Full diff: https://github.com/llvm/llvm-project/pull/122634.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 72799e42cf3fd1..64bdb248dff430 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
}
Type LLVMTypeConverter::convertFloatType(FloatType type) const {
+ // Valid LLVM float types are used directly.
+ if (LLVM::isCompatibleType(type))
+ return type;
+
+ // F4, F6, F8 types are converted to integer types with the same bit width.
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
type.isFloat8E8M0FNU())
return IntegerType::get(&getContext(), type.getWidth());
- return type;
+
+ // Other floating-point types: A custom type conversion rule must be
+ // specified by the user.
+ return Type();
}
// Convert a `ComplexType` to an LLVM type. The result is a complex number
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index a9dcc0a16b3dbd..1dabacfd8a47cc 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
return %1 : vector<4xf32>
}
+// -----
+
// CHECK-LABEL: @ops
func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
@@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
%20 = arith.shrsi %arg2, %arg3 : i32
// CHECK: = llvm.lshr %arg2, %arg3 : i32
%21 = arith.shrui %arg2, %arg3 : i32
+// CHECK: arith.constant 2.000000e+00 : tf32
+ // There is no type conversion rule for tf32.
+ %22 = arith.constant 2.0 : tf32
return %0, %10 : f32, i32
}
+// -----
+
// Checking conversion of index types to integers using i1, assuming no target
// system would have a 1-bit address space. Otherwise, we would have had to
// make this test dependent on the pointer size on the target system.
@@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @vector_index_cast
func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// -----
+
func.func @index_castui(%arg0: index, %arg1: i1) {
// CHECK: = llvm.trunc %0 : i{{.*}} to i1
%0 = arith.index_castui %arg0: index to i1
@@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
return
}
+// -----
+
// CHECK-LABEL: @vector_index_castui
func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
@@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}
+// -----
+
// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
@@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
return
}
+// -----
+
// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @sitofp_vector
func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of unsigned integer types to floating point.
// CHECK-LABEL: @uitofp
func.func @uitofp(%arg0 : i32, %arg1 : i64) {
@@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext(%arg0 : f16, %arg1 : f32) {
@@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
@@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
return
}
+// -----
+
// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptosi
func.func @fptosi(%arg0 : f32, %arg1 : f64) {
@@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptosi_vector
func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptoui
func.func @fptoui(%arg0 : f32, %arg1 : f64) {
@@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptoui_vector
func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
@@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @uitofp_vector
func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
@@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
return
}
+// -----
+
// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
@@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
return
}
+// -----
+
// CHECK-LABEL: experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
@@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
return
}
+// -----
+
// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
@@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
return
}
+// -----
+
// CHECK-LABEL: @integer_cast_0d_vector
func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
@@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
return
}
+// -----
+
// CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
func.func @fcmp(f32, f32) -> () {
^bb0(%arg0: f32, %arg1: f32):
diff --git a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
index 8396e5ad8ade15..22ac6eae73f534 100644
--- a/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
@@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
return %arg1 : index
}
+// There is no type conversion rule for tf32, so vector<1xtf32> and, therefore,
+// the func op cannot be converted.
+// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
+// CHECK: llvm.return
+func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
+ return
+}
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Certain non-standard float types were directly passed through in the LLVM type converter, resulting in invalid IR or failed assertions:
The LLVM type converter should not define invalid type conversion rules for such types. If there is no type conversion rule, conversion patterns will not apply to ops with such operand types.