Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][LLVM] Fix conversion of non-standard MLIR float types #122634

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,21 @@ Type LLVMTypeConverter::convertIntegerType(IntegerType type) const {
}

Type LLVMTypeConverter::convertFloatType(FloatType type) const {
// Valid LLVM float types are used directly.
if (LLVM::isCompatibleType(type))
return type;

// F4, F6, F8 types are converted to integer types with the same bit width.
if (type.isFloat8E5M2() || type.isFloat8E4M3() || type.isFloat8E4M3FN() ||
type.isFloat8E5M2FNUZ() || type.isFloat8E4M3FNUZ() ||
type.isFloat8E4M3B11FNUZ() || type.isFloat8E3M4() ||
type.isFloat4E2M1FN() || type.isFloat6E2M3FN() || type.isFloat6E3M2FN() ||
type.isFloat8E8M0FNU())
return IntegerType::get(&getContext(), type.getWidth());
return type;

// Other floating-point types: A custom type conversion rule must be
// specified by the user.
return Type();
}

// Convert a `ComplexType` to an LLVM type. The result is a complex number
Expand Down
45 changes: 45 additions & 0 deletions mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ func.func @vector_ops(%arg0: vector<4xf32>, %arg1: vector<4xi1>, %arg2: vector<4
return %1 : vector<4xf32>
}

// -----

// CHECK-LABEL: @ops
func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
^bb0(%arg0: f32, %arg1: f32, %arg2: i32, %arg3: i32, %arg4: f64):
Expand Down Expand Up @@ -84,9 +86,14 @@ func.func @ops(f32, f32, i32, i32, f64) -> (f32, i32) {
%20 = arith.shrsi %arg2, %arg3 : i32
// CHECK: = llvm.lshr %arg2, %arg3 : i32
%21 = arith.shrui %arg2, %arg3 : i32
// CHECK: arith.constant 2.000000e+00 : tf32
// There is no type conversion rule for tf32.
%22 = arith.constant 2.0 : tf32
return %0, %10 : f32, i32
}

// -----

// Checking conversion of index types to integers using i1, assuming no target
// system would have a 1-bit address space. Otherwise, we would have had to
// make this test dependent on the pointer size on the target system.
Expand All @@ -99,6 +106,8 @@ func.func @index_cast(%arg0: index, %arg1: i1) {
return
}

// -----

// CHECK-LABEL: @vector_index_cast
func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
Expand All @@ -108,6 +117,8 @@ func.func @vector_index_cast(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}

// -----

func.func @index_castui(%arg0: index, %arg1: i1) {
// CHECK: = llvm.trunc %0 : i{{.*}} to i1
%0 = arith.index_castui %arg0: index to i1
Expand All @@ -116,6 +127,8 @@ func.func @index_castui(%arg0: index, %arg1: i1) {
return
}

// -----

// CHECK-LABEL: @vector_index_castui
func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
// CHECK: = llvm.trunc %{{.*}} : vector<2xi{{.*}}> to vector<2xi1>
Expand All @@ -125,6 +138,8 @@ func.func @vector_index_castui(%arg0: vector<2xindex>, %arg1: vector<2xi1>) {
return
}

// -----

// Checking conversion of signed integer types to floating point.
// CHECK-LABEL: @sitofp
func.func @sitofp(%arg0 : i32, %arg1 : i64) {
Expand All @@ -139,6 +154,8 @@ func.func @sitofp(%arg0 : i32, %arg1 : i64) {
return
}

// -----

// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @sitofp_vector
func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
Expand All @@ -157,6 +174,8 @@ func.func @sitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}

// -----

// Checking conversion of unsigned integer types to floating point.
// CHECK-LABEL: @uitofp
func.func @uitofp(%arg0 : i32, %arg1 : i64) {
Expand All @@ -171,6 +190,8 @@ func.func @uitofp(%arg0 : i32, %arg1 : i64) {
return
}

// -----

// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext(%arg0 : f16, %arg1 : f32) {
Expand All @@ -183,6 +204,8 @@ func.func @fpext(%arg0 : f16, %arg1 : f32) {
return
}

// -----

// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fpext
func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
Expand All @@ -195,6 +218,8 @@ func.func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
return
}

// -----

// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptosi
func.func @fptosi(%arg0 : f32, %arg1 : f64) {
Expand All @@ -209,6 +234,8 @@ func.func @fptosi(%arg0 : f32, %arg1 : f64) {
return
}

// -----

// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptosi_vector
func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
Expand All @@ -227,6 +254,8 @@ func.func @fptosi_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}

// -----

// Checking conversion of floating point to integer types.
// CHECK-LABEL: @fptoui
func.func @fptoui(%arg0 : f32, %arg1 : f64) {
Expand All @@ -241,6 +270,8 @@ func.func @fptoui(%arg0 : f32, %arg1 : f64) {
return
}

// -----

// Checking conversion of floating point vectors to integer vector types.
// CHECK-LABEL: @fptoui_vector
func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : vector<2xf64>) {
Expand All @@ -259,6 +290,8 @@ func.func @fptoui_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>, %arg2 : v
return
}

// -----

// Checking conversion of integer vectors to floating point vector types.
// CHECK-LABEL: @uitofp_vector
func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : vector<2xi64>) {
Expand All @@ -277,6 +310,8 @@ func.func @uitofp_vector(%arg0 : vector<2xi16>, %arg1 : vector<2xi32>, %arg2 : v
return
}

// -----

// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
Expand All @@ -289,6 +324,8 @@ func.func @fptrunc(%arg0 : f32, %arg1 : f64) {
return
}

// -----

// Checking conversion of integer types to floating point.
// CHECK-LABEL: @fptrunc
func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
Expand All @@ -301,6 +338,8 @@ func.func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
return
}

// -----

// CHECK-LABEL: experimental_constrained_fptrunc
func.func @experimental_constrained_fptrunc(%arg0 : f64) {
// CHECK-NEXT: = llvm.intr.experimental.constrained.fptrunc {{.*}} tonearest ignore : f64 to f32
Expand All @@ -316,6 +355,8 @@ func.func @experimental_constrained_fptrunc(%arg0 : f64) {
return
}

// -----

// Check sign and zero extension and truncation of integers.
// CHECK-LABEL: @integer_extension_and_truncation
func.func @integer_extension_and_truncation(%arg0 : i3) {
Expand All @@ -328,6 +369,8 @@ func.func @integer_extension_and_truncation(%arg0 : i3) {
return
}

// -----

// CHECK-LABEL: @integer_cast_0d_vector
func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
// CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast
Expand All @@ -340,6 +383,8 @@ func.func @integer_cast_0d_vector(%arg0 : vector<i3>) {
return
}

// -----

// CHECK-LABEL: func @fcmp(%arg0: f32, %arg1: f32) {
func.func @fcmp(f32, f32) -> () {
^bb0(%arg0: f32, %arg1: f32):
Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Conversion/FuncToLLVM/func-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,14 @@ func.func @index_arg(%arg0: index) -> index {
return %arg1 : index
}

// There is no type conversion rule for tf32, so vector<1xtf32> and, therefore,
// the func op cannot be converted.
// CHECK: func.func @non_convertible_arg_type({{.*}}: vector<1xtf32>)
// CHECK: llvm.return
func.func @non_convertible_arg_type(%arg: vector<1xtf32>) {
return
}

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
%func = transform.structured.match ops{["func.func"]} in %toplevel_module
Expand Down
Loading