Skip to content

Commit

Permalink
Update rust bindings for bool in TVMValue
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Feb 14, 2024
1 parent 6c9aa9d commit d8b9ab9
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 38 deletions.
2 changes: 1 addition & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ typedef enum {
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
kTVMBool = 15U,
kTVMArgBool = 15U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ class TVMPODValue_ {
// Helper function to reduce duplication in the variable integer
// conversions. This is publicly exposed, as it can be useful in
// specializations of PackedFuncValueConverter.
if (type_code_ == kTVMBool) {
if (type_code_ == kTVMArgBool) {
return value_.v_bool;
} else {
return std::nullopt;
Expand Down Expand Up @@ -1013,7 +1013,7 @@ class TVMRetValue : public TVMPODValue_CRTP_<TVMRetValue> {
}
TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); }
TVMRetValue& operator=(bool value) {
this->SwitchToPOD(kTVMBool);
this->SwitchToPOD(kTVMArgBool);
value_.v_bool = value;
return *this;
}
Expand Down Expand Up @@ -1379,7 +1379,7 @@ inline const char* ArgTypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt:
return "int";
case kTVMBool:
case kTVMArgBool:
return "bool";
case kDLUInt:
return "uint";
Expand Down Expand Up @@ -1804,7 +1804,7 @@ class TVMArgsSetter {
}
TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const {
values_[i].v_bool = value;
type_codes_[i] = kTVMBool;
type_codes_[i] = kTVMArgBool;
}
TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
Expand Down Expand Up @@ -2115,7 +2115,7 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
if (std::is_base_of_v<Bool::ContainerType, ContainerType> ||
ptr->IsInstance<Bool::ContainerType>()) {
values_[i].v_bool = static_cast<Bool::ContainerType*>(ptr)->value;
type_codes_[i] = kTVMBool;
type_codes_[i] = kTVMArgBool;
return;
}
}
Expand Down Expand Up @@ -2298,7 +2298,7 @@ inline TObjectRef TVMPODValue_CRTP_<Derived>::AsObjectRef() const {
}

if constexpr (std::is_base_of_v<TObjectRef, Bool>) {
if (type_code_ == kTVMBool) {
if (type_code_ == kTVMArgBool) {
return Bool(value_.v_bool);
}
}
Expand Down
35 changes: 4 additions & 31 deletions rust/tvm-sys/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ macro_rules! TVMPODValue {
Int(i64),
UInt(i64),
Float(f64),
Bool(bool),
Null,
DataType(DLDataType),
String(*mut c_char),
Expand All @@ -95,6 +96,7 @@ macro_rules! TVMPODValue {
DLDataTypeCode_kDLInt => Int($value.v_int64),
DLDataTypeCode_kDLUInt => UInt($value.v_int64),
DLDataTypeCode_kDLFloat => Float($value.v_float64),
TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool),
TVMArgTypeCode_kTVMNullptr => Null,
TVMArgTypeCode_kTVMDataType => DataType($value.v_type),
TVMArgTypeCode_kDLDevice => Device($value.v_device),
Expand All @@ -117,6 +119,7 @@ macro_rules! TVMPODValue {
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool),
Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr),
DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType),
Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice),
Expand Down Expand Up @@ -263,6 +266,7 @@ macro_rules! impl_pod_value {
impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]);
impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]);
impl_pod_value!(Float, f64, [f32, f64]);
impl_pod_value!(Bool, bool, [bool]);
impl_pod_value!(DataType, DLDataType, [DLDataType]);
impl_pod_value!(Device, DLDevice, [DLDevice]);

Expand Down Expand Up @@ -380,37 +384,6 @@ impl TryFrom<RetValue> for std::ffi::CString {
}
}

// Implementations for bool.

impl<'a> From<&bool> for ArgValue<'a> {
fn from(s: &bool) -> Self {
(*s as i64).into()
}
}

impl From<bool> for RetValue {
fn from(s: bool) -> Self {
(s as i64).into()
}
}

impl TryFrom<RetValue> for bool {
type Error = ValueDowncastError;

fn try_from(val: RetValue) -> Result<bool, Self::Error> {
try_downcast!(val -> bool,
|RetValue::Int(val)| { !(val == 0) })
}
}

impl<'a> TryFrom<ArgValue<'a>> for bool {
type Error = ValueDowncastError;

fn try_from(val: ArgValue<'a>) -> Result<bool, Self::Error> {
try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) })
}
}

impl From<()> for RetValue {
fn from(_: ()) -> Self {
RetValue::Null
Expand Down

0 comments on commit d8b9ab9

Please sign in to comment.