Skip to content

Commit

Permalink
[FFI][Runtime] Use TVMValue::v_int64 to represent boolean values (#17240
Browse files Browse the repository at this point in the history
)

* [FFI][Runtime] Use TVMValue::v_int64 to represent boolean values

This is a follow-up to #16183, which
added handling of boolean values in the TVM FFI.  The initial
implementation added both a new type code (`kTVMArgBool`) and a new
`TVMValue::v_bool` variant.  This commit removes the
`TVMValue::v_bool` variant, since the `kTVMArgBool` type code is
sufficient to handle boolean arguments.

Removing the `TVMValue::v_bool` variant also makes all `TVMValue`
variants be 64-bit (assuming a 64-bit CPU).  This can simplify
debugging in some cases, since it prevents partial values from
inactive variants from being present in memory.

* Update MakePackedAPI, less special handling required for boolean
  • Loading branch information
Lunderberg authored Aug 22, 2024
1 parent 20289e8 commit 0f037a6
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 41 deletions.
1 change: 0 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ typedef DLTensor* TVMArrayHandle;
*/
typedef union {
int64_t v_int64;
bool v_bool;
double v_float64;
void* v_handle;
const char* v_str;
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ class TVMPODValue_ {
// conversions. This is publicly exposed, as it can be useful in
// specializations of PackedFuncValueConverter.
if (type_code_ == kTVMArgBool) {
return value_.v_bool;
return static_cast<bool>(value_.v_int64);
} else {
return std::nullopt;
}
Expand Down Expand Up @@ -1041,7 +1041,7 @@ class TVMRetValue : public TVMPODValue_CRTP_<TVMRetValue> {
TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kTVMArgBool);
value_.v_bool = value;
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(std::string value) {
Expand Down Expand Up @@ -1831,7 +1831,7 @@ class TVMArgsSetter {
type_codes_[i] = kDLInt;
}
TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const {
values_[i].v_bool = value;
values_[i].v_int64 = value;
type_codes_[i] = kTVMArgBool;
}
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
Expand Down Expand Up @@ -2142,7 +2142,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
std::is_base_of_v<ContainerType, Bool::ContainerType>) {
if (std::is_base_of_v<Bool::ContainerType, ContainerType> ||
ptr->IsInstance<Bool::ContainerType>()) {
values_[i].v_bool = static_cast<Bool::ContainerType*>(ptr)->value;
values_[i].v_int64 = static_cast<Bool::ContainerType*>(ptr)->value;
type_codes_[i] = kTVMArgBool;
return;
}
Expand Down Expand Up @@ -2327,7 +2327,7 @@ inline TObjectRef TVMPODValue_CRTP_<Derived>::AsObjectRef() const {

if constexpr (std::is_base_of_v<TObjectRef, Bool>) {
if (type_code_ == kTVMArgBool) {
return Bool(value_.v_bool);
return Bool(value_.v_int64);
}
}

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/_ffi/_cython/packed_func.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, bool):
# A python `bool` is a subclass of `int`, so this check
# must occur before `Integral`.
value[0].v_bool = arg
value[0].v_int64 = arg
tcode[0] = kTVMArgBool
elif isinstance(arg, Integral):
value[0].v_int64 = arg
Expand Down Expand Up @@ -215,7 +215,7 @@ cdef inline object make_ret(TVMValue value, int tcode):
elif tcode == kTVMNullptr:
return None
elif tcode == kTVMArgBool:
return value.v_bool
return bool(value.v_int64)
elif tcode == kInt:
return value.v_int64
elif tcode == kFloat:
Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-sys/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ macro_rules! TVMPODValue {
DLDataTypeCode_kDLInt => Int($value.v_int64),
DLDataTypeCode_kDLUInt => UInt($value.v_int64),
DLDataTypeCode_kDLFloat => Float($value.v_float64),
TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool),
TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0),
TVMArgTypeCode_kTVMNullptr => Null,
TVMArgTypeCode_kTVMDataType => DataType($value.v_type),
TVMArgTypeCode_kDLDevice => Device($value.v_device),
Expand All @@ -119,7 +119,7 @@ macro_rules! TVMPODValue {
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool),
Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool),
Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr),
DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType),
Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice),
Expand Down
4 changes: 1 addition & 3 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,8 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
return kTvmErrorFunctionCallWrongArgType;
}

if (type_codes[2] == kDLInt) {
if (type_codes[2] == kDLInt || type_codes[2] == kTVMArgBool) {
query_imports = args[2].v_int64 != 0;
} else if (type_codes[2] == kTVMArgBool) {
query_imports = args[2].v_bool;
} else {
TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer");
return kTvmErrorFunctionCallWrongArgType;
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ struct RPCReference {
break;
}
case kTVMArgBool: {
channel->template Write<bool>(value.v_bool);
channel->template Write<int64_t>(value.v_int64);
break;
}
case kTVMDataType: {
Expand Down Expand Up @@ -437,7 +437,7 @@ struct RPCReference {
break;
}
case kTVMArgBool: {
channel->template Read<bool>(&(value.v_bool));
channel->template Read<int64_t>(&(value.v_int64));
break;
}
case kTVMDataType: {
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr);

if (op->dtype == DataType::Bool()) {
struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value);
struct_value = CreateCast(DataType::Int(64), op->dtype, struct_value);
}

return struct_value;
Expand Down
3 changes: 1 addition & 2 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ inline DataType APIType(DataType t) {
ICHECK(!t.is_void()) << "Cannot pass void type through packed API.";
if (t.is_handle()) return t;
ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
if (t.is_bool()) return DataType::Bool();
if (t.is_uint() || t.is_int()) return DataType::Int(64);
if (t.is_bool() || t.is_uint() || t.is_int()) return DataType::Int(64);
ICHECK(t.is_float());
return DataType::Float(64);
}
Expand Down
20 changes: 7 additions & 13 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ class ReturnRewriter : public StmtMutator {

// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_int() || dtype.is_uint()) {
if (dtype.is_bool()) {
info.tcode = kTVMArgBool;
info.expr = Cast(DataType::Int(64), val);

} else if (dtype.is_int() || dtype.is_uint()) {
info.tcode = kTVMArgInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
Expand Down Expand Up @@ -340,25 +344,15 @@ PrimFunc MakePackedAPI(PrimFunc func) {
seq_init.emplace_back(
AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop));

arg_value = Call(t, builtin::if_then_else(),
{
tcode == kTVMArgBool,
f_arg_value(DataType::Bool(), i),
cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)),
});
arg_value = cast(DataType::Bool(), f_arg_value(DataType::Int(64), i));

} else if (t.is_int() || t.is_uint()) {
std::ostringstream msg;
msg << name_hint << ": Expect arg[" << i << "] to be int";
seq_init.emplace_back(
AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop));

arg_value = Call(t, builtin::if_then_else(),
{
tcode == kTVMArgInt,
f_arg_value(t, i),
cast(t, f_arg_value(DataType::Bool(), i)),
});
arg_value = f_arg_value(t, i);
} else {
ICHECK(t.is_float());
std::ostringstream msg;
Expand Down
16 changes: 16 additions & 0 deletions tests/python/codegen/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,5 +1179,21 @@ def func(arg: T.bool) -> T.int32:
assert output == 20


def test_bool_return_value():
"""Booleans may be returned from a PrimFunc"""

@T.prim_func
def func(value: T.int32) -> T.bool:
T.func_attr({"target": T.target("llvm")})
return value < 10

built = tvm.build(func)
assert isinstance(built(0), bool)
assert built(0)

assert isinstance(built(15), bool)
assert not built(15)


if __name__ == "__main__":
tvm.testing.main()
12 changes: 2 additions & 10 deletions tests/python/tir-transform/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,11 +444,7 @@ def main(
arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids)
arg_code: T.int32 = arg_type_ids_1[0]
assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int"
arg: T.int32 = T.if_then_else(
arg_code == 0,
T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")),
T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")),
)
arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64"))
with T.attr(0, "compute_scope", "main_compute_"):
out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,))
out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,))
Expand Down Expand Up @@ -510,11 +506,7 @@ def main(
arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids)
arg_code: T.int32 = arg_type_ids_1[0]
assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean"
arg: T.bool = T.if_then_else(
arg_code == 15,
T.tvm_struct_get(args, 0, 12, "bool"),
T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")),
)
arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64"))
with T.attr(0, "compute_scope", "main_compute_"):
out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,))
out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,))
Expand Down

0 comments on commit 0f037a6

Please sign in to comment.