diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index b4c653a0a59e..d26c95e4f53c 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -209,7 +209,6 @@ typedef DLTensor* TVMArrayHandle; */ typedef union { int64_t v_int64; - bool v_bool; double v_float64; void* v_handle; const char* v_str; diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 91e53055b708..7c1b08e49002 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -669,7 +669,7 @@ class TVMPODValue_ { // conversions. This is publicly exposed, as it can be useful in // specializations of PackedFuncValueConverter. if (type_code_ == kTVMArgBool) { - return value_.v_bool; + return static_cast(value_.v_int64); } else { return std::nullopt; } @@ -1041,7 +1041,7 @@ class TVMRetValue : public TVMPODValue_CRTP_ { TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { this->SwitchToPOD(kTVMArgBool); - value_.v_bool = value; + value_.v_int64 = value; return *this; } TVMRetValue& operator=(std::string value) { @@ -1831,7 +1831,7 @@ class TVMArgsSetter { type_codes_[i] = kDLInt; } TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { - values_[i].v_bool = value; + values_[i].v_int64 = value; type_codes_[i] = kTVMArgBool; } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { @@ -2142,7 +2142,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { std::is_base_of_v) { if (std::is_base_of_v || ptr->IsInstance()) { - values_[i].v_bool = static_cast(ptr)->value; + values_[i].v_int64 = static_cast(ptr)->value; type_codes_[i] = kTVMArgBool; return; } @@ -2327,7 +2327,7 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { if constexpr (std::is_base_of_v) { if (type_code_ == kTVMArgBool) { - return Bool(value_.v_bool); + return Bool(value_.v_int64); } } diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 7977f37d0be5..6e062ab5f199 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -121,7 +121,7 @@ cdef inline int make_arg(object arg, elif isinstance(arg, bool): # A python `bool` is a subclass of `int`, so this check # must occur before `Integral`. - value[0].v_bool = arg + value[0].v_int64 = arg tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg @@ -215,7 +215,7 @@ cdef inline object make_ret(TVMValue value, int tcode): elif tcode == kTVMNullptr: return None elif tcode == kTVMArgBool: - return value.v_bool + return bool(value.v_int64) elif tcode == kInt: return value.v_int64 elif tcode == kFloat: diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index 2c1f7db6adb0..3d78ce52d621 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -96,7 +96,7 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), - TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), + TVMArgTypeCode_kTVMArgBool => Bool($value.v_int64 != 0), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -119,7 +119,7 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), - Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), + Bool(val) => (TVMValue { v_int64: *val as i64 }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 04d36ad8bcab..2df37205b89c 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -362,10 +362,8 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r return kTvmErrorFunctionCallWrongArgType; } - if (type_codes[2] == kDLInt) { + if (type_codes[2] == kDLInt || type_codes[2] == kTVMArgBool) { query_imports = args[2].v_int64 != 0; - } else if (type_codes[2] == kTVMArgBool) { - query_imports = args[2].v_bool; } else { TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer"); return kTvmErrorFunctionCallWrongArgType; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 485ebdb449da..13c1fa4b38d3 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -326,7 +326,7 @@ struct RPCReference { break; } case kTVMArgBool: { - channel->template Write(value.v_bool); + channel->template Write(value.v_int64); break; } case kTVMDataType: { @@ -437,7 +437,7 @@ struct RPCReference { break; } case kTVMArgBool: { - channel->template Read(&(value.v_bool)); + channel->template Read(&(value.v_int64)); break; } case kTVMDataType: { diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 21899a12c4b0..b9e18bc4f8d2 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -1379,7 +1379,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr); if (op->dtype == DataType::Bool()) { - struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value); + struct_value = CreateCast(DataType::Int(64), op->dtype, struct_value); } return struct_value; diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 2948773321dd..05345aab8628 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -155,8 +155,7 @@ inline DataType APIType(DataType t) { ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; - if (t.is_bool()) return DataType::Bool(); - if (t.is_uint() || t.is_int()) return DataType::Int(64); + if (t.is_bool() || t.is_uint() || t.is_int()) return DataType::Int(64); ICHECK(t.is_float()); return DataType::Float(64); } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index 9f2f1295fece..cf388630fcf6 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -81,7 +81,11 @@ class ReturnRewriter : public StmtMutator { // convert val's data type to FFI data type, return type code DataType dtype = val.dtype(); - if (dtype.is_int() || dtype.is_uint()) { + if (dtype.is_bool()) { + info.tcode = kTVMArgBool; + info.expr = Cast(DataType::Int(64), val); + + } else if (dtype.is_int() || dtype.is_uint()) { info.tcode = kTVMArgInt; info.expr = Cast(DataType::Int(64), val); } else if (dtype.is_float()) { @@ -340,12 +344,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back( AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); - arg_value = Call(t, builtin::if_then_else(), - { - tcode == kTVMArgBool, - f_arg_value(DataType::Bool(), i), - cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)), - }); + arg_value = cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)); } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; @@ -353,12 +352,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back( AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop)); - arg_value = Call(t, builtin::if_then_else(), - { - tcode == kTVMArgInt, - f_arg_value(t, i), - cast(t, f_arg_value(DataType::Bool(), i)), - }); + arg_value = f_arg_value(t, i); } else { ICHECK(t.is_float()); std::ostringstream msg; diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index d9a6fd6e62d1..e8036467ffb6 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1179,5 +1179,21 @@ def func(arg: T.bool) -> T.int32: assert output == 20 +def test_bool_return_value(): + """Booleans may be returned from a PrimFunc""" + + @T.prim_func + def func(value: T.int32) -> T.bool: + T.func_attr({"target": T.target("llvm")}) + return value < 10 + + built = tvm.build(func) + assert isinstance(built(0), bool) + assert built(0) + + assert isinstance(built(15), bool) + assert not built(15) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 0b43db56f300..f783ab2fcef1 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -444,11 +444,7 @@ def main( arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) arg_code: T.int32 = arg_type_ids_1[0] assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" - arg: T.int32 = T.if_then_else( - arg_code == 0, - T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")), - T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")), - ) + arg: T.int32 = T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")) with T.attr(0, "compute_scope", "main_compute_"): out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) @@ -510,11 +506,7 @@ def main( arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) arg_code: T.int32 = arg_type_ids_1[0] assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" - arg: T.bool = T.if_then_else( - arg_code == 15, - T.tvm_struct_get(args, 0, 12, "bool"), - T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")), - ) + arg: T.bool = T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")) with T.attr(0, "compute_scope", "main_compute_"): out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,))