diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 69e1355f7d130..0f7e5fcae6bdc 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -16,6 +16,7 @@ # under the License. from ..base import raise_last_ffi_error +from libcpp cimport bool as bool_t from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -38,7 +39,8 @@ cdef enum TVMArgTypeCode: kTVMBytes = 12 kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 - kTVMExtBegin = 15 + kTVMArgBool = 15 + kTVMExtBegin = 16 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: @@ -66,6 +68,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct TVMValue: int64_t v_int64 + bool_t v_bool double v_float64 void* v_handle const char* v_str diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index a4baed09bb95d..7977f37d0be57 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -121,10 +121,8 @@ cdef inline int make_arg(object arg, elif isinstance(arg, bool): # A python `bool` is a subclass of `int`, so this check # must occur before `Integral`. - arg = _FUNC_CONVERT_TO_OBJECT(arg) - value[0].v_handle = (arg).chandle - tcode[0] = kTVMObjectHandle - temp_args.append(arg) + value[0].v_bool = arg + tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg tcode[0] = kInt @@ -216,6 +214,8 @@ cdef inline object make_ret(TVMValue value, int tcode): return make_ret_object(value.v_handle) elif tcode == kTVMNullptr: return None + elif tcode == kTVMArgBool: + return value.v_bool elif tcode == kInt: return value.v_int64 elif tcode == kFloat: