Skip to content

Commit

Permalink
[FFI][Runtime] Use TVMValue::v_int64 to represent boolean values
Browse files Browse the repository at this point in the history
This is a follow-up to apache#16183, which
added handling of boolean values in the TVM FFI.  The initial
implementation added both a new type code (`kTVMArgBool`) and a new
`TVMValue::v_bool` variant.  This commit removes the
`TVMValue::v_bool` variant, since the `kTVMArgBool` type code is
sufficient to handle boolean arguments.

Removing the `TVMValue::v_bool` variant also makes all `TVMValue`
variants be 64-bit (assuming a 64-bit CPU).  This can simplify
debugging in some cases, since it prevents partial values from
inactive variants from being present in memory.
  • Loading branch information
Lunderberg committed Aug 5, 2024
1 parent 5f22be4 commit 133290a
Show file tree
Hide file tree
Showing 9 changed files with 33 additions and 17 deletions.
1 change: 0 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ typedef DLTensor* TVMArrayHandle;
*/
typedef union {
int64_t v_int64;
bool v_bool;
double v_float64;
void* v_handle;
const char* v_str;
Expand Down
10 changes: 5 additions & 5 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ class TVMPODValue_ {
// conversions. This is publicly exposed, as it can be useful in
// specializations of PackedFuncValueConverter.
if (type_code_ == kTVMArgBool) {
return value_.v_bool;
return static_cast<bool>(value_.v_int64);
} else {
return std::nullopt;
}
Expand Down Expand Up @@ -1014,7 +1014,7 @@ class TVMRetValue : public TVMPODValue_CRTP_<TVMRetValue> {
TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kTVMArgBool);
value_.v_bool = value;
value_.v_int64 = value;
return *this;
}
TVMRetValue& operator=(std::string value) {
Expand Down Expand Up @@ -1804,7 +1804,7 @@ class TVMArgsSetter {
type_codes_[i] = kDLInt;
}
TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const {
values_[i].v_bool = value;
values_[i].v_int64 = value;
type_codes_[i] = kTVMArgBool;
}
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
Expand Down Expand Up @@ -2115,7 +2115,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
std::is_base_of_v<ContainerType, Bool::ContainerType>) {
if (std::is_base_of_v<Bool::ContainerType, ContainerType> ||
ptr->IsInstance<Bool::ContainerType>()) {
values_[i].v_bool = static_cast<Bool::ContainerType*>(ptr)->value;
values_[i].v_int64 = static_cast<Bool::ContainerType*>(ptr)->value;
type_codes_[i] = kTVMArgBool;
return;
}
Expand Down Expand Up @@ -2300,7 +2300,7 @@ inline TObjectRef TVMPODValue_CRTP_<Derived>::AsObjectRef() const {

if constexpr (std::is_base_of_v<TObjectRef, Bool>) {
if (type_code_ == kTVMArgBool) {
return Bool(value_.v_bool);
return Bool(value_.v_int64);
}
}

Expand Down
4 changes: 2 additions & 2 deletions rust/tvm-sys/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ macro_rules! TVMPODValue {
DLDataTypeCode_kDLInt => Int($value.v_int64),
DLDataTypeCode_kDLUInt => UInt($value.v_int64),
DLDataTypeCode_kDLFloat => Float($value.v_float64),
TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool),
TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0),
TVMArgTypeCode_kTVMNullptr => Null,
TVMArgTypeCode_kTVMDataType => DataType($value.v_type),
TVMArgTypeCode_kDLDevice => Device($value.v_device),
Expand All @@ -119,7 +119,7 @@ macro_rules! TVMPODValue {
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool),
Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool),
Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr),
DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType),
Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice),
Expand Down
4 changes: 1 addition & 3 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -362,10 +362,8 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
return kTvmErrorFunctionCallWrongArgType;
}

if (type_codes[2] == kDLInt) {
if (type_codes[2] == kDLInt || type_codes[2] == kTVMArgBool) {
query_imports = args[2].v_int64 != 0;
} else if (type_codes[2] == kTVMArgBool) {
query_imports = args[2].v_bool;
} else {
TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer");
return kTvmErrorFunctionCallWrongArgType;
Expand Down
4 changes: 2 additions & 2 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ struct RPCReference {
break;
}
case kTVMArgBool: {
channel->template Write<bool>(value.v_bool);
channel->template Write<int64_t>(value.v_int64);
break;
}
case kTVMDataType: {
Expand Down Expand Up @@ -437,7 +437,7 @@ struct RPCReference {
break;
}
case kTVMArgBool: {
channel->template Read<bool>(&(value.v_bool));
channel->template Read<int64_t>(&(value.v_int64));
break;
}
case kTVMDataType: {
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1379,7 +1379,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) {
llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr);

if (op->dtype == DataType::Bool()) {
struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value);
struct_value = CreateCast(DataType::Int(64), op->dtype, struct_value);
}

return struct_value;
Expand Down
3 changes: 1 addition & 2 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,7 @@ inline DataType APIType(DataType t) {
ICHECK(!t.is_void()) << "Cannot pass void type through packed API.";
if (t.is_handle()) return t;
ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API.";
if (t.is_bool()) return DataType::Bool();
if (t.is_uint() || t.is_int()) return DataType::Int(64);
if (t.is_bool() || t.is_uint() || t.is_int()) return DataType::Int(64);
ICHECK(t.is_float());
return DataType::Float(64);
}
Expand Down
6 changes: 5 additions & 1 deletion src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ class ReturnRewriter : public StmtMutator {

// convert val's data type to FFI data type, return type code
DataType dtype = val.dtype();
if (dtype.is_int() || dtype.is_uint()) {
if (dtype.is_bool()) {
info.tcode = kTVMArgBool;
info.expr = Cast(DataType::Int(64), val);

} else if (dtype.is_int() || dtype.is_uint()) {
info.tcode = kTVMArgInt;
info.expr = Cast(DataType::Int(64), val);
} else if (dtype.is_float()) {
Expand Down
16 changes: 16 additions & 0 deletions tests/python/codegen/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,5 +1179,21 @@ def func(arg: T.bool) -> T.int32:
assert output == 20


def test_bool_return_value():
"""Booleans may be returned from a PrimFunc"""

@T.prim_func
def func(value: T.int32) -> T.bool:
T.func_attr({"target": T.target("llvm")})
return value < 10

built = tvm.build(func)
assert isinstance(built(0), bool)
assert built(0)

assert isinstance(built(15), bool)
assert not built(15)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 133290a

Please sign in to comment.