From d8b9ab93dfeba71ba7518c63081e6e670e904da6 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 14 Feb 2024 08:24:14 -0600 Subject: [PATCH] Update rust bindings for bool in TVMValue --- include/tvm/runtime/c_runtime_api.h | 2 +- include/tvm/runtime/packed_func.h | 12 +++++----- rust/tvm-sys/src/packed_func.rs | 35 ++++------------------------- 3 files changed, 11 insertions(+), 38 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 7390324bafa3c..b2ed13638bbac 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -187,7 +187,7 @@ typedef enum { kTVMBytes = 12U, kTVMNDArrayHandle = 13U, kTVMObjectRValueRefArg = 14U, - kTVMBool = 15U, + kTVMArgBool = 15U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 2c3c85168dde8..ec82e36bc35c7 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -641,7 +641,7 @@ class TVMPODValue_ { // Helper function to reduce duplication in the variable integer // conversions. This is publicly exposed, as it can be useful in // specializations of PackedFuncValueConverter. - if (type_code_ == kTVMBool) { + if (type_code_ == kTVMArgBool) { return value_.v_bool; } else { return std::nullopt; @@ -1013,7 +1013,7 @@ class TVMRetValue : public TVMPODValue_CRTP_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kTVMBool); + this->SwitchToPOD(kTVMArgBool); value_.v_bool = value; return *this; } @@ -1379,7 +1379,7 @@ inline const char* ArgTypeCode2Str(int type_code) { switch (type_code) { case kDLInt: return "int"; - case kTVMBool: + case kTVMArgBool: return "bool"; case kDLUInt: return "uint"; @@ -1804,7 +1804,7 @@ class TVMArgsSetter { } TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { values_[i].v_bool = value; - type_codes_[i] = kTVMBool; + type_codes_[i] = kTVMArgBool; } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); @@ -2115,7 +2115,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { if (std::is_base_of_v || ptr->IsInstance()) { values_[i].v_bool = static_cast(ptr)->value; - type_codes_[i] = kTVMBool; + type_codes_[i] = kTVMArgBool; return; } } @@ -2298,7 +2298,7 @@ inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { } if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMBool) { + if (type_code_ == kTVMArgBool) { return Bool(value_.v_bool); } } diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index a74cbe318e2d8..2c1f7db6adb0c 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -73,6 +73,7 @@ macro_rules! TVMPODValue { Int(i64), UInt(i64), Float(f64), + Bool(bool), Null, DataType(DLDataType), String(*mut c_char), @@ -95,6 +96,7 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -117,6 +119,7 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), @@ -263,6 +266,7 @@ macro_rules! impl_pod_value { impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(Bool, bool, [bool]); impl_pod_value!(DataType, DLDataType, [DLDataType]); impl_pod_value!(Device, DLDevice, [DLDevice]); @@ -380,37 +384,6 @@ impl TryFrom for std::ffi::CString { } } -// Implementations for bool. - -impl<'a> From<&bool> for ArgValue<'a> { - fn from(s: &bool) -> Self { - (*s as i64).into() - } -} - -impl From for RetValue { - fn from(s: bool) -> Self { - (s as i64).into() - } -} - -impl TryFrom for bool { - type Error = ValueDowncastError; - - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> bool, - |RetValue::Int(val)| { !(val == 0) }) - } -} - -impl<'a> TryFrom> for bool { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) - } -} - impl From<()> for RetValue { fn from(_: ()) -> Self { RetValue::Null