From 0e8a5737dfe49a48a4e9c15ba7a7d24dd2fd7623 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Thu, 3 Oct 2024 11:28:44 -0500 Subject: [PATCH] Adding vm.cast.f32.si64 and vm.cast.f32.ui64 ops (#18642) The corresponding arith dialect ops are legal so we need these implemented. Fixes: https://github.com/iree-org/iree/issues/18501 --- .../VM/Conversion/ArithToVM/Patterns.cpp | 10 +++ .../ArithToVM/test/conversion_ops.mlir | 24 +++++++ .../Conversion/VMToEmitC/ConvertVMToEmitC.cpp | 2 + .../compiler/Dialect/VM/IR/VMOpFolders.cpp | 23 ++++++ .../compiler/Dialect/VM/IR/VMOpcodesF32.td | 4 ++ .../src/iree/compiler/Dialect/VM/IR/VMOps.td | 18 ++++- runtime/src/iree/vm/bytecode/disassembler.c | 24 ++++++- runtime/src/iree/vm/bytecode/dispatch.c | 10 +++ .../vm/bytecode/utils/generated/op_table.h | 8 +-- runtime/src/iree/vm/bytecode/verifier.c | 8 +++ runtime/src/iree/vm/ops.h | 9 +++ .../src/iree/vm/test/conversion_ops_f32.mlir | 70 +++++++++++++++++++ 12 files changed, 202 insertions(+), 8 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp index dc654a0ab987..8e5f96f65f81 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/Patterns.cpp @@ -570,6 +570,11 @@ struct FPToSIOpConversion : public OpConversionPattern { adaptor.getIn()); return success(); } + if (resultType.isSignlessInteger(64) || resultType.isSignedInteger(64)) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + adaptor.getIn()); + return success(); + } } return rewriter.notifyMatchFailure(srcOp, "unsupported type"); } @@ -589,6 +594,11 @@ struct FPToUIOpConversion : public OpConversionPattern { adaptor.getIn()); return success(); } + if (dstType.isSignlessInteger(64) || dstType.isUnsignedInteger(64)) { + rewriter.replaceOpWithNewOp(srcOp, resultType, + adaptor.getIn()); + return success(); + } } return rewriter.notifyMatchFailure(srcOp, "unsupported type"); } diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir index eecc1557c0e3..be4ec1f83b87 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ArithToVM/test/conversion_ops.mlir @@ -324,6 +324,18 @@ module @fptosi_fp32_i32 { // ----- +// CHECK-LABEL: @fptosi_fp32_i64 +module @fptosi_fp32_i64 { + // CHECK: vm.func private @fn(%[[ARG0:.+]]: f32) + func.func @fn(%arg0: f32) -> i64 { + // CHECK: vm.cast.f32.si64 %[[ARG0]] : f32 -> i64 + %0 = arith.fptosi %arg0 : f32 to i64 + return %0 : i64 + } +} + +// ----- + // expected-error@+1 {{conversion to vm.module failed}} module @fptoui_fp32_i8 { func.func @fn(%arg0: f32) -> i8 { @@ -347,6 +359,18 @@ module @fptoui_fp32_i32 { // ----- +// CHECK-LABEL: @fptoui_fp32_i64 +module @fptoui_fp32_i64 { + // CHECK: vm.func private @fn(%[[ARG0:.+]]: f32) + func.func @fn(%arg0: f32) -> i64 { + // CHECK: vm.cast.f32.ui64 %[[ARG0]] : f32 -> i64 + %0 = arith.fptoui %arg0 : f32 to i64 + return %0 : i64 + } +} + +// ----- + // CHECK-LABEL: @bitcast_i32_f32 module @bitcast_i32_f32 { // CHECK: vm.func private @fn(%[[ARG0:.+]]: i32) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp index 1967a4c1894d..0215ad93f4fb 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/VMToEmitC/ConvertVMToEmitC.cpp @@ -4449,7 +4449,9 @@ void populateVMToEmitCPatterns(ConversionTarget &conversionTarget, ADD_GENERIC_PATTERN(IREE::VM::BitcastF32I32Op, "vm_bitcast_f32i32"); ADD_GENERIC_PATTERN(IREE::VM::BitcastI32F32Op, "vm_bitcast_i32f32"); ADD_GENERIC_PATTERN(IREE::VM::CastF32SI32Op, "vm_cast_f32si32"); + ADD_GENERIC_PATTERN(IREE::VM::CastF32SI64Op, "vm_cast_f32si64"); ADD_GENERIC_PATTERN(IREE::VM::CastF32UI32Op, "vm_cast_f32ui32"); + ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64"); ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32"); ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32"); ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32"); diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp index 9beb75fce9c1..bac70cdca819 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp @@ -1724,6 +1724,17 @@ OpFoldResult CastF32SI32Op::fold(FoldAdaptor operands) { }); } +OpFoldResult CastF32SI64Op::fold(FoldAdaptor operands) { + return constFoldCastOp( + IntegerType::get(getContext(), 64), operands.getOperand(), + [&](const APFloat &a) { + bool isExact = false; + llvm::APSInt b(/*BitWidth=*/64, /*isUnsigned=*/false); + a.convertToInteger(b, APFloat::rmNearestTiesToAway, &isExact); + return b; + }); +} + OpFoldResult CastF32UI32Op::fold(FoldAdaptor operands) { return constFoldCastOp( IntegerType::get(getContext(), 32), operands.getOperand(), @@ -1736,6 +1747,18 @@ OpFoldResult CastF32UI32Op::fold(FoldAdaptor operands) { }); } +OpFoldResult CastF32UI64Op::fold(FoldAdaptor operands) { + return constFoldCastOp( + IntegerType::get(getContext(), 64), operands.getOperand(), + [&](const APFloat &a) { + bool isExact = false; + llvm::APSInt b(/*BitWidth=*/64, /*isUnsigned=*/false); + a.convertToInteger(b, APFloat::rmNearestTiesToAway, &isExact); + b.setIsUnsigned(true); + return b; + }); +} + OpFoldResult CastF64SI64Op::fold(FoldAdaptor operands) { return constFoldCastOp( IntegerType::get(getContext(), 64), operands.getOperand(), diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td index c5ef64f64ec0..af9295f165f4 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td @@ -47,7 +47,9 @@ def VM_OPC_MaxF32 : VM_OPC<0x38, "MaxF32">; def VM_OPC_CastSI32F32 : VM_OPC<0x14, "CastSI32F32">; def VM_OPC_CastUI32F32 : VM_OPC<0x15, "CastUI32F32">; def VM_OPC_CastF32SI32 : VM_OPC<0x16, "CastF32SI32">; +def VM_OPC_CastF32SI64 : VM_OPC<0x3A, "CastF32SI64">; def VM_OPC_CastF32UI32 : VM_OPC<0x17, "CastF32UI32">; +def VM_OPC_CastF32UI64 : VM_OPC<0x3B, "CastF32UI64">; def VM_OPC_BitcastI32F32 : VM_OPC<0x18, "BitcastI32F32">; def VM_OPC_BitcastF32I32 : VM_OPC<0x19, "BitcastF32I32">; @@ -120,7 +122,9 @@ def VM_ExtF32OpcodeAttr : VM_OPC_CastSI32F32, VM_OPC_CastUI32F32, VM_OPC_CastF32SI32, + VM_OPC_CastF32SI64, VM_OPC_CastF32UI32, + VM_OPC_CastF32UI64, VM_OPC_BitcastI32F32, VM_OPC_BitcastF32I32, diff --git a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td index 63448d8304f4..6e9899d53a85 100644 --- a/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td +++ b/compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td @@ -3178,14 +3178,28 @@ def VM_CastUI64F64Op : def VM_CastF32SI32Op : VM_ConversionOp { - let summary = [{cast from a float-point value to a signed integer}]; + let summary = [{cast from a float-point value to a signed 32-bit integer}]; + let hasFolder = 1; +} + +def VM_CastF32SI64Op : + VM_ConversionOp { + let summary = [{cast from a float-point value to a signed 64-bit integer}]; let hasFolder = 1; } def VM_CastF32UI32Op : VM_ConversionOp { - let summary = [{cast from an float-point value to an unsigned integer}]; + let summary = [{cast from an float-point value to an unsigned 32-bit integer}]; + let hasFolder = 1; +} + +def VM_CastF32UI64Op : + VM_ConversionOp { + let summary = [{cast from an float-point value to an unsigned 64-bit integer}]; let hasFolder = 1; } diff --git a/runtime/src/iree/vm/bytecode/disassembler.c b/runtime/src/iree/vm/bytecode/disassembler.c index 950bb2c80d38..02d93e9070e6 100644 --- a/runtime/src/iree/vm/bytecode/disassembler.c +++ b/runtime/src/iree/vm/bytecode/disassembler.c @@ -2126,7 +2126,17 @@ iree_status_t iree_vm_bytecode_disassemble_op( uint16_t result_reg = VM_ParseResultRegI32("result"); EMIT_I32_REG_NAME(result_reg); IREE_RETURN_IF_ERROR( - iree_string_builder_append_cstring(b, " = vm.cast.f32.sif32 ")); + iree_string_builder_append_cstring(b, " = vm.cast.f32.si32 ")); + EMIT_F32_REG_NAME(operand_reg); + EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]); + break; + } + DISASM_OP(EXT_F32, CastF32SI64) { + uint16_t operand_reg = VM_ParseOperandRegF32("operand"); + uint16_t result_reg = VM_ParseResultRegI64("result"); + EMIT_I64_REG_NAME(result_reg); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_cstring(b, " = vm.cast.f32.si64 ")); EMIT_F32_REG_NAME(operand_reg); EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]); break; @@ -2136,7 +2146,17 @@ iree_status_t iree_vm_bytecode_disassemble_op( uint16_t result_reg = VM_ParseResultRegI32("result"); EMIT_I32_REG_NAME(result_reg); IREE_RETURN_IF_ERROR( - iree_string_builder_append_cstring(b, " = vm.cast.f32.uif32 ")); + iree_string_builder_append_cstring(b, " = vm.cast.f32.ui32 ")); + EMIT_F32_REG_NAME(operand_reg); + EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]); + break; + } + DISASM_OP(EXT_F32, CastF32UI64) { + uint16_t operand_reg = VM_ParseOperandRegF32("operand"); + uint16_t result_reg = VM_ParseResultRegI64("result"); + EMIT_I64_REG_NAME(result_reg); + IREE_RETURN_IF_ERROR( + iree_string_builder_append_cstring(b, " = vm.cast.f32.ui64 ")); EMIT_F32_REG_NAME(operand_reg); EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]); break; diff --git a/runtime/src/iree/vm/bytecode/dispatch.c b/runtime/src/iree/vm/bytecode/dispatch.c index d3aea379c3bd..40ae195b660d 100644 --- a/runtime/src/iree/vm/bytecode/dispatch.c +++ b/runtime/src/iree/vm/bytecode/dispatch.c @@ -2056,11 +2056,21 @@ static iree_status_t iree_vm_bytecode_dispatch( int32_t* result = VM_DecResultRegI32("result"); *result = vm_cast_f32si32(operand); }); + DISPATCH_OP(EXT_F32, CastF32SI64, { + float operand = VM_DecOperandRegF32("operand"); + int64_t* result = VM_DecResultRegI64("result"); + *result = vm_cast_f32si64(operand); + }); DISPATCH_OP(EXT_F32, CastF32UI32, { float operand = VM_DecOperandRegF32("operand"); int32_t* result = VM_DecResultRegI32("result"); *result = vm_cast_f32ui32(operand); }); + DISPATCH_OP(EXT_F32, CastF32UI64, { + float operand = VM_DecOperandRegF32("operand"); + int64_t* result = VM_DecResultRegI64("result"); + *result = vm_cast_f32ui64(operand); + }); DISPATCH_OP(EXT_F32, BitcastI32F32, { int32_t operand = (int32_t)VM_DecOperandRegI32("operand"); float* result = VM_DecResultRegF32("result"); diff --git a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h index 1e9e7b47a988..2a5a76c0d7ab 100644 --- a/runtime/src/iree/vm/bytecode/utils/generated/op_table.h +++ b/runtime/src/iree/vm/bytecode/utils/generated/op_table.h @@ -582,8 +582,8 @@ typedef enum { IREE_VM_OP_EXT_F32_MinF32 = 0x37, IREE_VM_OP_EXT_F32_MaxF32 = 0x38, IREE_VM_OP_EXT_F32_RoundF32Even = 0x39, - IREE_VM_OP_EXT_F32_RSV_0x3A, - IREE_VM_OP_EXT_F32_RSV_0x3B, + IREE_VM_OP_EXT_F32_CastF32SI64 = 0x3A, + IREE_VM_OP_EXT_F32_CastF32UI64 = 0x3B, IREE_VM_OP_EXT_F32_RSV_0x3C, IREE_VM_OP_EXT_F32_RSV_0x3D, IREE_VM_OP_EXT_F32_RSV_0x3E, @@ -841,8 +841,8 @@ typedef enum { OPC(0x37, MinF32) \ OPC(0x38, MaxF32) \ OPC(0x39, RoundF32Even) \ - RSV(0x3A) \ - RSV(0x3B) \ + OPC(0x3A, CastF32SI64) \ + OPC(0x3B, CastF32UI64) \ RSV(0x3C) \ RSV(0x3D) \ RSV(0x3E) \ diff --git a/runtime/src/iree/vm/bytecode/verifier.c b/runtime/src/iree/vm/bytecode/verifier.c index 2c1c62679cd6..c5b9d635f220 100644 --- a/runtime/src/iree/vm/bytecode/verifier.c +++ b/runtime/src/iree/vm/bytecode/verifier.c @@ -1831,10 +1831,18 @@ static iree_status_t iree_vm_bytecode_function_verify_bytecode_op( VM_VerifyOperandRegF32(operand); VM_VerifyResultRegI32(result); }); + VERIFY_OP(EXT_F32, CastF32SI64, { + VM_VerifyOperandRegF32(operand); + VM_VerifyResultRegI64(result); + }); VERIFY_OP(EXT_F32, CastF32UI32, { VM_VerifyOperandRegF32(operand); VM_VerifyResultRegI32(result); }); + VERIFY_OP(EXT_F32, CastF32UI64, { + VM_VerifyOperandRegF32(operand); + VM_VerifyResultRegI64(result); + }); VERIFY_OP(EXT_F32, BitcastI32F32, { VM_VerifyOperandRegI32(operand); VM_VerifyResultRegF32(result); diff --git a/runtime/src/iree/vm/ops.h b/runtime/src/iree/vm/ops.h index d9d8bdc8bead..b9ffd70122da 100644 --- a/runtime/src/iree/vm/ops.h +++ b/runtime/src/iree/vm/ops.h @@ -605,9 +605,18 @@ static inline float vm_cast_ui32f32(int32_t operand) { static inline int32_t vm_cast_f32si32(float operand) { return (int32_t)lroundf(operand); } +static inline int64_t vm_cast_f32si64(float operand) { + return (int64_t)llroundf(operand); +} static inline int32_t vm_cast_f32ui32(float operand) { return (uint32_t)llroundf(operand); } +static inline int64_t vm_cast_f32ui64(float operand) { + // `llroundf` used in other casts above only has a range from INT64_MIN + // to INT64_MAX however here we need a range of 0 to UINT64_MAX, hence we do + // rounding in `float` with `roundf` and then cast to `uint64_t`. + return (uint64_t)roundf(operand); +} static inline float vm_bitcast_i32f32(int32_t operand) { float result; memcpy(&result, &operand, sizeof(result)); diff --git a/runtime/src/iree/vm/test/conversion_ops_f32.mlir b/runtime/src/iree/vm/test/conversion_ops_f32.mlir index 834f4b8b1d43..bb893f77ddbf 100644 --- a/runtime/src/iree/vm/test/conversion_ops_f32.mlir +++ b/runtime/src/iree/vm/test/conversion_ops_f32.mlir @@ -96,6 +96,54 @@ vm.module @conversion_ops_f32 { vm.return } + vm.export @test_cast_f32_si64_int_max + vm.func @test_cast_f32_si64_int_max() { + // This is the maximum value that is representable precisely as both i64 + // and f32. An exponent of 62 with all mantissa bits set. + %c1 = vm.const.f32 0x5effffff + %c1dno = util.optimization_barrier %c1 : f32 + %v = vm.cast.f32.si64 %c1dno : f32 -> i64 + %c2 = vm.const.i64 0x7FFFFF8000000000 + vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 + vm.return + } + + vm.export @test_cast_f32_si64_int_min + vm.func @test_cast_f32_si64_int_min() { + %c1 = vm.const.f32 -9223372036854775808.0 + %c1dno = util.optimization_barrier %c1 : f32 + %v = vm.cast.f32.si64 %c1dno : f32 -> i64 + // Directly providing the true INT64_MIN of -9223372036854775808 + // gives an error so we do -(INT64_MAX) - 1 + // See: https://stackoverflow.com/a/65008288 + %c2 = vm.const.i64 -9223372036854775807 + %c2dno = util.optimization_barrier %c2 : i64 + %c3 = vm.const.i64 1 + %c4 = vm.sub.i64 %c2dno, %c3 : i64 + vm.check.eq %v, %c4, "cast floating-point value to a signed integer" : i64 + vm.return + } + + vm.export @test_cast_f32_si64_away_from_zero_pos + vm.func @test_cast_f32_si64_away_from_zero_pos() { + %c1 = vm.const.f32 2.5 + %c1dno = util.optimization_barrier %c1 : f32 + %v = vm.cast.f32.si64 %c1dno : f32 -> i64 + %c2 = vm.const.i64 3 + vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 + vm.return + } + + vm.export @test_cast_f32_si64_away_from_zero_neg + vm.func @test_cast_f32_si64_away_from_zero_neg() { + %c1 = vm.const.f32 -2.5 + %c1dno = util.optimization_barrier %c1 : f32 + %v = vm.cast.f32.si64 %c1dno : f32 -> i64 + %c2 = vm.const.i64 -3 + vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 + vm.return + } + vm.export @test_cast_f32_ui32_int_big vm.func @test_cast_f32_ui32_int_big() { // This is the maximum value that is representable precisely as both ui32 @@ -118,4 +166,26 @@ vm.module @conversion_ops_f32 { vm.return } + vm.export @test_cast_f32_ui64_int_big + vm.func @test_cast_f32_ui64_int_big() { + // This is the maximum value that is representable precisely as both ui64 + // and f32. An exponent of 63 with all mantissa bits set. + %c1 = vm.const.f32 0x5F7FFFFF + %c1dno = util.optimization_barrier %c1 : f32 + %v = vm.cast.f32.ui64 %c1dno : f32 -> i64 + %c2 = vm.const.i64 0xFFFFFF0000000000 + vm.check.eq %v, %c2, "cast floating-point value to an unsigned integer" : i64 + vm.return + } + + vm.export @test_cast_f32_ui64_away_from_zero + vm.func @test_cast_f32_ui64_away_from_zero() { + %c1 = vm.const.f32 2.5 + %c1dno = util.optimization_barrier %c1 : f32 + %v = vm.cast.f32.ui64 %c1dno : f32 -> i64 + %c2 = vm.const.i64 3 + vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64 + vm.return + } + }