Skip to content

Commit

Permalink
Adding vm.cast.f32.si64 and vm.cast.f32.ui64 ops (iree-org#18642)
Browse files Browse the repository at this point in the history
The corresponding arith dialect ops are legal so we need these
implemented.
Fixes: iree-org#18501
  • Loading branch information
nirvedhmeshram authored Oct 3, 2024
1 parent a6043e2 commit 0e8a573
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,11 @@ struct FPToSIOpConversion : public OpConversionPattern<arith::FPToSIOp> {
adaptor.getIn());
return success();
}
if (resultType.isSignlessInteger(64) || resultType.isSignedInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::CastF32SI64Op>(srcOp, resultType,
adaptor.getIn());
return success();
}
}
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
Expand All @@ -589,6 +594,11 @@ struct FPToUIOpConversion : public OpConversionPattern<arith::FPToUIOp> {
adaptor.getIn());
return success();
}
if (dstType.isSignlessInteger(64) || dstType.isUnsignedInteger(64)) {
rewriter.replaceOpWithNewOp<IREE::VM::CastF32UI64Op>(srcOp, resultType,
adaptor.getIn());
return success();
}
}
return rewriter.notifyMatchFailure(srcOp, "unsupported type");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,18 @@ module @fptosi_fp32_i32 {

// -----

// CHECK-LABEL: @fptosi_fp32_i64
module @fptosi_fp32_i64 {
// CHECK: vm.func private @fn(%[[ARG0:.+]]: f32)
func.func @fn(%arg0: f32) -> i64 {
// CHECK: vm.cast.f32.si64 %[[ARG0]] : f32 -> i64
%0 = arith.fptosi %arg0 : f32 to i64
return %0 : i64
}
}

// -----

// expected-error@+1 {{conversion to vm.module failed}}
module @fptoui_fp32_i8 {
func.func @fn(%arg0: f32) -> i8 {
Expand All @@ -347,6 +359,18 @@ module @fptoui_fp32_i32 {

// -----

// CHECK-LABEL: @fptoui_fp32_i64
module @fptoui_fp32_i64 {
// CHECK: vm.func private @fn(%[[ARG0:.+]]: f32)
func.func @fn(%arg0: f32) -> i64 {
// CHECK: vm.cast.f32.ui64 %[[ARG0]] : f32 -> i64
%0 = arith.fptoui %arg0 : f32 to i64
return %0 : i64
}
}

// -----

// CHECK-LABEL: @bitcast_i32_f32
module @bitcast_i32_f32 {
// CHECK: vm.func private @fn(%[[ARG0:.+]]: i32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4449,7 +4449,9 @@ void populateVMToEmitCPatterns(ConversionTarget &conversionTarget,
ADD_GENERIC_PATTERN(IREE::VM::BitcastF32I32Op, "vm_bitcast_f32i32");
ADD_GENERIC_PATTERN(IREE::VM::BitcastI32F32Op, "vm_bitcast_i32f32");
ADD_GENERIC_PATTERN(IREE::VM::CastF32SI32Op, "vm_cast_f32si32");
ADD_GENERIC_PATTERN(IREE::VM::CastF32SI64Op, "vm_cast_f32si64");
ADD_GENERIC_PATTERN(IREE::VM::CastF32UI32Op, "vm_cast_f32ui32");
ADD_GENERIC_PATTERN(IREE::VM::CastF32UI64Op, "vm_cast_f32ui64");
ADD_GENERIC_PATTERN(IREE::VM::CastSI32F32Op, "vm_cast_si32f32");
ADD_GENERIC_PATTERN(IREE::VM::CastUI32F32Op, "vm_cast_ui32f32");
ADD_GENERIC_PATTERN(IREE::VM::CeilF32Op, "vm_ceil_f32");
Expand Down
23 changes: 23 additions & 0 deletions compiler/src/iree/compiler/Dialect/VM/IR/VMOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,17 @@ OpFoldResult CastF32SI32Op::fold(FoldAdaptor operands) {
});
}

OpFoldResult CastF32SI64Op::fold(FoldAdaptor operands) {
return constFoldCastOp<FloatAttr, IntegerAttr>(
IntegerType::get(getContext(), 64), operands.getOperand(),
[&](const APFloat &a) {
bool isExact = false;
llvm::APSInt b(/*BitWidth=*/64, /*isUnsigned=*/false);
a.convertToInteger(b, APFloat::rmNearestTiesToAway, &isExact);
return b;
});
}

OpFoldResult CastF32UI32Op::fold(FoldAdaptor operands) {
return constFoldCastOp<FloatAttr, IntegerAttr>(
IntegerType::get(getContext(), 32), operands.getOperand(),
Expand All @@ -1736,6 +1747,18 @@ OpFoldResult CastF32UI32Op::fold(FoldAdaptor operands) {
});
}

OpFoldResult CastF32UI64Op::fold(FoldAdaptor operands) {
return constFoldCastOp<FloatAttr, IntegerAttr>(
IntegerType::get(getContext(), 64), operands.getOperand(),
[&](const APFloat &a) {
bool isExact = false;
llvm::APSInt b(/*BitWidth=*/64, /*isUnsigned=*/false);
a.convertToInteger(b, APFloat::rmNearestTiesToAway, &isExact);
b.setIsUnsigned(true);
return b;
});
}

OpFoldResult CastF64SI64Op::fold(FoldAdaptor operands) {
return constFoldCastOp<FloatAttr, IntegerAttr>(
IntegerType::get(getContext(), 64), operands.getOperand(),
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Dialect/VM/IR/VMOpcodesF32.td
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def VM_OPC_MaxF32 : VM_OPC<0x38, "MaxF32">;
def VM_OPC_CastSI32F32 : VM_OPC<0x14, "CastSI32F32">;
def VM_OPC_CastUI32F32 : VM_OPC<0x15, "CastUI32F32">;
def VM_OPC_CastF32SI32 : VM_OPC<0x16, "CastF32SI32">;
def VM_OPC_CastF32SI64 : VM_OPC<0x3A, "CastF32SI64">;
def VM_OPC_CastF32UI32 : VM_OPC<0x17, "CastF32UI32">;
def VM_OPC_CastF32UI64 : VM_OPC<0x3B, "CastF32UI64">;
def VM_OPC_BitcastI32F32 : VM_OPC<0x18, "BitcastI32F32">;
def VM_OPC_BitcastF32I32 : VM_OPC<0x19, "BitcastF32I32">;

Expand Down Expand Up @@ -120,7 +122,9 @@ def VM_ExtF32OpcodeAttr :
VM_OPC_CastSI32F32,
VM_OPC_CastUI32F32,
VM_OPC_CastF32SI32,
VM_OPC_CastF32SI64,
VM_OPC_CastF32UI32,
VM_OPC_CastF32UI64,
VM_OPC_BitcastI32F32,
VM_OPC_BitcastF32I32,

Expand Down
18 changes: 16 additions & 2 deletions compiler/src/iree/compiler/Dialect/VM/IR/VMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3178,14 +3178,28 @@ def VM_CastUI64F64Op :
def VM_CastF32SI32Op :
VM_ConversionOp<F32, I32, "cast.f32.si32", VM_OPC_CastF32SI32,
[VM_ExtF32]> {
let summary = [{cast from a float-point value to a signed integer}];
let summary = [{cast from a float-point value to a signed 32-bit integer}];
let hasFolder = 1;
}

def VM_CastF32SI64Op :
VM_ConversionOp<F32, I64, "cast.f32.si64", VM_OPC_CastF32SI64,
[VM_ExtF32]> {
let summary = [{cast from a float-point value to a signed 64-bit integer}];
let hasFolder = 1;
}

def VM_CastF32UI32Op :
VM_ConversionOp<F32, I32, "cast.f32.ui32", VM_OPC_CastF32UI32,
[VM_ExtF32]> {
let summary = [{cast from an float-point value to an unsigned integer}];
let summary = [{cast from an float-point value to an unsigned 32-bit integer}];
let hasFolder = 1;
}

def VM_CastF32UI64Op :
VM_ConversionOp<F32, I64, "cast.f32.ui64", VM_OPC_CastF32UI64,
[VM_ExtF32]> {
let summary = [{cast from an float-point value to an unsigned 64-bit integer}];
let hasFolder = 1;
}

Expand Down
24 changes: 22 additions & 2 deletions runtime/src/iree/vm/bytecode/disassembler.c
Original file line number Diff line number Diff line change
Expand Up @@ -2126,7 +2126,17 @@ iree_status_t iree_vm_bytecode_disassemble_op(
uint16_t result_reg = VM_ParseResultRegI32("result");
EMIT_I32_REG_NAME(result_reg);
IREE_RETURN_IF_ERROR(
iree_string_builder_append_cstring(b, " = vm.cast.f32.sif32 "));
iree_string_builder_append_cstring(b, " = vm.cast.f32.si32 "));
EMIT_F32_REG_NAME(operand_reg);
EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]);
break;
}
DISASM_OP(EXT_F32, CastF32SI64) {
uint16_t operand_reg = VM_ParseOperandRegF32("operand");
uint16_t result_reg = VM_ParseResultRegI64("result");
EMIT_I64_REG_NAME(result_reg);
IREE_RETURN_IF_ERROR(
iree_string_builder_append_cstring(b, " = vm.cast.f32.si64 "));
EMIT_F32_REG_NAME(operand_reg);
EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]);
break;
Expand All @@ -2136,7 +2146,17 @@ iree_status_t iree_vm_bytecode_disassemble_op(
uint16_t result_reg = VM_ParseResultRegI32("result");
EMIT_I32_REG_NAME(result_reg);
IREE_RETURN_IF_ERROR(
iree_string_builder_append_cstring(b, " = vm.cast.f32.uif32 "));
iree_string_builder_append_cstring(b, " = vm.cast.f32.ui32 "));
EMIT_F32_REG_NAME(operand_reg);
EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]);
break;
}
DISASM_OP(EXT_F32, CastF32UI64) {
uint16_t operand_reg = VM_ParseOperandRegF32("operand");
uint16_t result_reg = VM_ParseResultRegI64("result");
EMIT_I64_REG_NAME(result_reg);
IREE_RETURN_IF_ERROR(
iree_string_builder_append_cstring(b, " = vm.cast.f32.ui64 "));
EMIT_F32_REG_NAME(operand_reg);
EMIT_OPTIONAL_VALUE_F32(regs->i32[operand_reg]);
break;
Expand Down
10 changes: 10 additions & 0 deletions runtime/src/iree/vm/bytecode/dispatch.c
Original file line number Diff line number Diff line change
Expand Up @@ -2056,11 +2056,21 @@ static iree_status_t iree_vm_bytecode_dispatch(
int32_t* result = VM_DecResultRegI32("result");
*result = vm_cast_f32si32(operand);
});
DISPATCH_OP(EXT_F32, CastF32SI64, {
float operand = VM_DecOperandRegF32("operand");
int64_t* result = VM_DecResultRegI64("result");
*result = vm_cast_f32si64(operand);
});
DISPATCH_OP(EXT_F32, CastF32UI32, {
float operand = VM_DecOperandRegF32("operand");
int32_t* result = VM_DecResultRegI32("result");
*result = vm_cast_f32ui32(operand);
});
DISPATCH_OP(EXT_F32, CastF32UI64, {
float operand = VM_DecOperandRegF32("operand");
int64_t* result = VM_DecResultRegI64("result");
*result = vm_cast_f32ui64(operand);
});
DISPATCH_OP(EXT_F32, BitcastI32F32, {
int32_t operand = (int32_t)VM_DecOperandRegI32("operand");
float* result = VM_DecResultRegF32("result");
Expand Down
8 changes: 4 additions & 4 deletions runtime/src/iree/vm/bytecode/utils/generated/op_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@ typedef enum {
IREE_VM_OP_EXT_F32_MinF32 = 0x37,
IREE_VM_OP_EXT_F32_MaxF32 = 0x38,
IREE_VM_OP_EXT_F32_RoundF32Even = 0x39,
IREE_VM_OP_EXT_F32_RSV_0x3A,
IREE_VM_OP_EXT_F32_RSV_0x3B,
IREE_VM_OP_EXT_F32_CastF32SI64 = 0x3A,
IREE_VM_OP_EXT_F32_CastF32UI64 = 0x3B,
IREE_VM_OP_EXT_F32_RSV_0x3C,
IREE_VM_OP_EXT_F32_RSV_0x3D,
IREE_VM_OP_EXT_F32_RSV_0x3E,
Expand Down Expand Up @@ -841,8 +841,8 @@ typedef enum {
OPC(0x37, MinF32) \
OPC(0x38, MaxF32) \
OPC(0x39, RoundF32Even) \
RSV(0x3A) \
RSV(0x3B) \
OPC(0x3A, CastF32SI64) \
OPC(0x3B, CastF32UI64) \
RSV(0x3C) \
RSV(0x3D) \
RSV(0x3E) \
Expand Down
8 changes: 8 additions & 0 deletions runtime/src/iree/vm/bytecode/verifier.c
Original file line number Diff line number Diff line change
Expand Up @@ -1831,10 +1831,18 @@ static iree_status_t iree_vm_bytecode_function_verify_bytecode_op(
VM_VerifyOperandRegF32(operand);
VM_VerifyResultRegI32(result);
});
VERIFY_OP(EXT_F32, CastF32SI64, {
VM_VerifyOperandRegF32(operand);
VM_VerifyResultRegI64(result);
});
VERIFY_OP(EXT_F32, CastF32UI32, {
VM_VerifyOperandRegF32(operand);
VM_VerifyResultRegI32(result);
});
VERIFY_OP(EXT_F32, CastF32UI64, {
VM_VerifyOperandRegF32(operand);
VM_VerifyResultRegI64(result);
});
VERIFY_OP(EXT_F32, BitcastI32F32, {
VM_VerifyOperandRegI32(operand);
VM_VerifyResultRegF32(result);
Expand Down
9 changes: 9 additions & 0 deletions runtime/src/iree/vm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,9 +605,18 @@ static inline float vm_cast_ui32f32(int32_t operand) {
static inline int32_t vm_cast_f32si32(float operand) {
return (int32_t)lroundf(operand);
}
static inline int64_t vm_cast_f32si64(float operand) {
return (int64_t)llroundf(operand);
}
static inline int32_t vm_cast_f32ui32(float operand) {
return (uint32_t)llroundf(operand);
}
static inline int64_t vm_cast_f32ui64(float operand) {
// `llroundf` used in other casts above only has a range from INT64_MIN
// to INT64_MAX however here we need a range of 0 to UINT64_MAX, hence we do
// rounding in `float` with `roundf` and then cast to `uint64_t`.
return (uint64_t)roundf(operand);
}
static inline float vm_bitcast_i32f32(int32_t operand) {
float result;
memcpy(&result, &operand, sizeof(result));
Expand Down
70 changes: 70 additions & 0 deletions runtime/src/iree/vm/test/conversion_ops_f32.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,54 @@ vm.module @conversion_ops_f32 {
vm.return
}

vm.export @test_cast_f32_si64_int_max
vm.func @test_cast_f32_si64_int_max() {
// This is the maximum value that is representable precisely as both i64
// and f32. An exponent of 62 with all mantissa bits set.
%c1 = vm.const.f32 0x5effffff
%c1dno = util.optimization_barrier %c1 : f32
%v = vm.cast.f32.si64 %c1dno : f32 -> i64
%c2 = vm.const.i64 0x7FFFFF8000000000
vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64
vm.return
}

vm.export @test_cast_f32_si64_int_min
vm.func @test_cast_f32_si64_int_min() {
%c1 = vm.const.f32 -9223372036854775808.0
%c1dno = util.optimization_barrier %c1 : f32
%v = vm.cast.f32.si64 %c1dno : f32 -> i64
// Directly providing the true INT64_MIN of -9223372036854775808
// gives an error so we do -(INT64_MAX) - 1
// See: https://stackoverflow.com/a/65008288
%c2 = vm.const.i64 -9223372036854775807
%c2dno = util.optimization_barrier %c2 : i64
%c3 = vm.const.i64 1
%c4 = vm.sub.i64 %c2dno, %c3 : i64
vm.check.eq %v, %c4, "cast floating-point value to a signed integer" : i64
vm.return
}

vm.export @test_cast_f32_si64_away_from_zero_pos
vm.func @test_cast_f32_si64_away_from_zero_pos() {
%c1 = vm.const.f32 2.5
%c1dno = util.optimization_barrier %c1 : f32
%v = vm.cast.f32.si64 %c1dno : f32 -> i64
%c2 = vm.const.i64 3
vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64
vm.return
}

vm.export @test_cast_f32_si64_away_from_zero_neg
vm.func @test_cast_f32_si64_away_from_zero_neg() {
%c1 = vm.const.f32 -2.5
%c1dno = util.optimization_barrier %c1 : f32
%v = vm.cast.f32.si64 %c1dno : f32 -> i64
%c2 = vm.const.i64 -3
vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64
vm.return
}

vm.export @test_cast_f32_ui32_int_big
vm.func @test_cast_f32_ui32_int_big() {
// This is the maximum value that is representable precisely as both ui32
Expand All @@ -118,4 +166,26 @@ vm.module @conversion_ops_f32 {
vm.return
}

vm.export @test_cast_f32_ui64_int_big
vm.func @test_cast_f32_ui64_int_big() {
// This is the maximum value that is representable precisely as both ui64
// and f32. An exponent of 63 with all mantissa bits set.
%c1 = vm.const.f32 0x5F7FFFFF
%c1dno = util.optimization_barrier %c1 : f32
%v = vm.cast.f32.ui64 %c1dno : f32 -> i64
%c2 = vm.const.i64 0xFFFFFF0000000000
vm.check.eq %v, %c2, "cast floating-point value to an unsigned integer" : i64
vm.return
}

vm.export @test_cast_f32_ui64_away_from_zero
vm.func @test_cast_f32_ui64_away_from_zero() {
%c1 = vm.const.f32 2.5
%c1dno = util.optimization_barrier %c1 : f32
%v = vm.cast.f32.ui64 %c1dno : f32 -> i64
%c2 = vm.const.i64 3
vm.check.eq %v, %c2, "cast floating-point value to a signed integer" : i64
vm.return
}

}

0 comments on commit 0e8a573

Please sign in to comment.