Skip to content

Commit

Permalink
Rename runtime::BoxInt to runtime::Int, similar for float/bool
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Nov 29, 2023
1 parent 1e5078f commit 0f90e61
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 130 deletions.
4 changes: 2 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -813,12 +813,12 @@ template <>
struct PackedFuncValueConverter<tvm::Bool> {
static Optional<tvm::Bool> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsBool()) {
return Bool(opt.value());
return tvm::Bool(opt.value());
} else if (auto opt = val.TryAsInt()) {
int value = opt.value();
ICHECK(value == 0 || value == 1)
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
return Bool(static_cast<bool>(value));
return tvm::Bool(static_cast<bool>(value));
} else {
return NullOpt;
}
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/runtime/container/boxed_primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ class Box : public ObjectRef {
* Can be used to store POD integer values as a TVM ObjectRef. Used
* for FFI handling, and for storing POD types inside TVM containers.
*/
using BoxInt = Box<int64_t>;
using Int = Box<int64_t>;

/*! \brief Boxed version of C++ double
*
* Can be used to store POD floating-point values as a TVM ObjectRef.
* Used for FFI handling, and for storing POD types inside TVM
* containers.
*/
using BoxFloat = Box<double>;
using Float = Box<double>;

/*! \brief Boxed version of C++ bool
*
Expand All @@ -118,7 +118,7 @@ using BoxFloat = Box<double>;
* hold the object, a Python to C++ to Python round trip will preserve
* the distinction between bool and int.
*/
using BoxBool = Box<bool>;
using Bool = Box<bool>;

namespace detail {
template <>
Expand Down
43 changes: 20 additions & 23 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -2044,23 +2044,23 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const {
// `TVMRetValue`. Instead, this conversion is checked in the FFI
// return value, to ensure that boxing/unboxing is applied
// consistently.
if constexpr (std::is_base_of_v<BoxInt::ContainerType, ContainerType> ||
std::is_base_of_v<ContainerType, BoxInt::ContainerType>) {
if (std::is_base_of_v<BoxInt::ContainerType, ContainerType> ||
ptr->IsInstance<BoxInt::ContainerType>()) {
values_[i].v_int64 = static_cast<BoxInt::ContainerType*>(ptr)->value;
if constexpr (std::is_base_of_v<Int::ContainerType, ContainerType> ||
std::is_base_of_v<ContainerType, Int::ContainerType>) {
if (std::is_base_of_v<Int::ContainerType, ContainerType> ||
ptr->IsInstance<Int::ContainerType>()) {
values_[i].v_int64 = static_cast<Int::ContainerType*>(ptr)->value;
type_codes_[i] = kTVMArgInt;
return;
}
}

// Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt
// explanation for more detail.
if constexpr (std::is_base_of_v<BoxFloat::ContainerType, ContainerType> ||
std::is_base_of_v<ContainerType, BoxFloat::ContainerType>) {
if (std::is_base_of_v<BoxFloat::ContainerType, ContainerType> ||
ptr->IsInstance<BoxFloat::ContainerType>()) {
values_[i].v_float64 = static_cast<BoxFloat::ContainerType*>(ptr)->value;
if constexpr (std::is_base_of_v<Float::ContainerType, ContainerType> ||
std::is_base_of_v<ContainerType, Float::ContainerType>) {
if (std::is_base_of_v<Float::ContainerType, ContainerType> ||
ptr->IsInstance<Float::ContainerType>()) {
values_[i].v_float64 = static_cast<Float::ContainerType*>(ptr)->value;
type_codes_[i] = kTVMArgFloat;
return;
}
Expand Down Expand Up @@ -2192,15 +2192,15 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const {
}
}

if constexpr (std::is_base_of_v<TObjectRef, BoxInt>) {
if constexpr (std::is_base_of_v<TObjectRef, Int>) {
if (type_code_ == kTVMArgInt) {
return BoxInt(value_.v_int64);
return Int(value_.v_int64);
}
}

if constexpr (std::is_base_of_v<TObjectRef, BoxFloat>) {
if constexpr (std::is_base_of_v<TObjectRef, Float>) {
if (type_code_ == kTVMArgFloat) {
return BoxFloat(value_.v_float64);
return Float(value_.v_float64);
}
}

Expand Down Expand Up @@ -2243,19 +2243,16 @@ inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) {
}
}

if constexpr (std::is_base_of_v<BoxInt, TObjectRef> || std::is_base_of_v<TObjectRef, BoxInt>) {
if (ptr &&
(std::is_base_of_v<BoxInt, TObjectRef> || ptr->IsInstance<BoxInt::ContainerType>())) {
int64_t value = static_cast<const BoxInt::ContainerType*>(ptr)->value;
if constexpr (std::is_base_of_v<Int, TObjectRef> || std::is_base_of_v<TObjectRef, Int>) {
if (ptr && (std::is_base_of_v<Int, TObjectRef> || ptr->IsInstance<Int::ContainerType>())) {
int64_t value = static_cast<const Int::ContainerType*>(ptr)->value;
return operator=(value);
}
}

if constexpr (std::is_base_of_v<BoxFloat, TObjectRef> ||
std::is_base_of_v<TObjectRef, BoxFloat>) {
if (ptr &&
(std::is_base_of_v<BoxFloat, TObjectRef> || ptr->IsInstance<BoxFloat::ContainerType>())) {
double value = static_cast<const BoxFloat::ContainerType*>(ptr)->value;
if constexpr (std::is_base_of_v<Float, TObjectRef> || std::is_base_of_v<TObjectRef, Float>) {
if (ptr && (std::is_base_of_v<Float, TObjectRef> || ptr->IsInstance<Float::ContainerType>())) {
double value = static_cast<const Float::ContainerType*>(ptr)->value;
return operator=(value);
}
}
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR";
.add_attr_option<String>("model") \
.add_attr_option<Array<String>>("libs") \
.add_attr_option<Target>("host") \
.add_attr_option<runtime::BoxInt>("from_device") \
.add_attr_option<runtime::BoxInt>("target_device_type")
.add_attr_option<runtime::Int>("from_device") \
.add_attr_option<runtime::Int>("target_device_type")

} // namespace tvm

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1191,7 +1191,7 @@ struct PackedFuncValueConverter<PrimExpr> {
if (auto opt = val.TryAsBool()) {
// Check against val.TryAsBool directly, to avoid the
// bounds-checking in PackedFuncValueConverter<Bool>::TryFrom.
return Bool(opt.value());
return tvm::Bool(opt.value());
} else if (auto opt = PackedFuncValueConverter<IntImm>::TryFrom(val)) {
return opt.value();
} else if (auto opt = PackedFuncValueConverter<FloatImm>::TryFrom(val)) {
Expand Down
14 changes: 7 additions & 7 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
n->binding_names.push_back(Downcast<String>(v));
}
if (auto v = config_dict.Get("show_meta")) {
n->show_meta = Downcast<runtime::BoxBool>(v)->value;
n->show_meta = Downcast<runtime::Bool>(v)->value;
}
if (auto v = config_dict.Get("ir_prefix")) {
n->ir_prefix = Downcast<String>(v);
Expand All @@ -60,16 +60,16 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
n->float_dtype = DataType(runtime::String2DLDataType(Downcast<String>(v)));
}
if (auto v = config_dict.Get("verbose_expr")) {
n->verbose_expr = Downcast<runtime::BoxBool>(v)->value;
n->verbose_expr = Downcast<runtime::Bool>(v)->value;
}
if (auto v = config_dict.Get("indent_spaces")) {
n->indent_spaces = Downcast<runtime::BoxInt>(v)->value;
n->indent_spaces = Downcast<runtime::Int>(v)->value;
}
if (auto v = config_dict.Get("print_line_numbers")) {
n->print_line_numbers = Downcast<runtime::BoxBool>(v)->value;
n->print_line_numbers = Downcast<runtime::Bool>(v)->value;
}
if (auto v = config_dict.Get("num_context_lines")) {
n->num_context_lines = Downcast<runtime::BoxInt>(v)->value;
n->num_context_lines = Downcast<runtime::Int>(v)->value;
}
if (auto v = config_dict.Get("path_to_underline")) {
n->path_to_underline = Downcast<Optional<Array<ObjectPath>>>(v).value_or(Array<ObjectPath>());
Expand All @@ -86,10 +86,10 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
Downcast<Optional<Map<ObjectRef, String>>>(v).value_or(Map<ObjectRef, String>());
}
if (auto v = config_dict.Get("syntax_sugar")) {
n->syntax_sugar = Downcast<runtime::BoxBool>(v)->value;
n->syntax_sugar = Downcast<runtime::Bool>(v)->value;
}
if (auto v = config_dict.Get("show_object_address")) {
n->show_object_address = Downcast<runtime::BoxBool>(v)->value;
n->show_object_address = Downcast<runtime::Bool>(v)->value;
}

this->data_ = std::move(n);
Expand Down
24 changes: 12 additions & 12 deletions src/target/tag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,36 +75,36 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64")
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("cortex-a72")},
{"mattr", Array<String>{"+neon"}},
{"num-cores", runtime::BoxInt(4)},
{"num-cores", runtime::Int(4)},
{"host", Map<String, ObjectRef>{{"kind", String("llvm")},
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("cortex-a72")},
{"mattr", Array<String>{"+neon"}},
{"num-cores", runtime::BoxInt(4)}}}});
{"num-cores", runtime::Int(4)}}}});

TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier")
.set_config({{"kind", String("cuda")},
{"arch", String("sm_72")},
{"max_shared_memory_per_block", runtime::BoxInt(49152)},
{"max_threads_per_block", runtime::BoxInt(1024)},
{"thread_warp_size", runtime::BoxInt(32)},
{"registers_per_block", runtime::BoxInt(65536)},
{"max_shared_memory_per_block", runtime::Int(49152)},
{"max_threads_per_block", runtime::Int(1024)},
{"thread_warp_size", runtime::Int(32)},
{"registers_per_block", runtime::Int(65536)},
{"host", Map<String, ObjectRef>{{"kind", String("llvm")},
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("carmel")},
{"num-cores", runtime::BoxInt(8)}}}});
{"num-cores", runtime::Int(8)}}}});

TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano")
.set_config({{"kind", String("cuda")},
{"arch", String("sm_87")},
{"max_shared_memory_per_block", runtime::BoxInt(49152)},
{"max_threads_per_block", runtime::BoxInt(1024)},
{"thread_warp_size", runtime::BoxInt(32)},
{"registers_per_block", runtime::BoxInt(65536)},
{"max_shared_memory_per_block", runtime::Int(49152)},
{"max_threads_per_block", runtime::Int(1024)},
{"thread_warp_size", runtime::Int(32)},
{"registers_per_block", runtime::Int(65536)},
{"host", Map<String, ObjectRef>{{"kind", String("llvm")},
{"mtriple", String("aarch64-linux-gnu")},
{"mcpu", String("carmel")},
{"num-cores", runtime::BoxInt(6)}}}});
{"num-cores", runtime::Int(6)}}}});

#define TVM_REGISTER_CUDA_TAG(Name, Arch, SharedMem, RegPerBlock) \
TVM_REGISTER_TARGET_TAG(Name).set_config({ \
Expand Down
21 changes: 10 additions & 11 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi
ObjectRef TargetInternal::ParseType(const std::string& str,
const TargetKindNode::ValueTypeInfo& info) {
std::string interp_str = Interpret(str);
if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex() ||
info.type_index == runtime::BoxBool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex() ||
info.type_index == runtime::Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer or boolean
std::istringstream is(interp_str);
int v;
Expand All @@ -379,10 +379,10 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
}
}

if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
return runtime::BoxInt(v);
if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
return runtime::Int(v);
} else {
return runtime::BoxBool(v);
return runtime::Bool(v);
}
} else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing string, strip leading/trailing spaces, and enclosing quotes if any
Expand Down Expand Up @@ -417,10 +417,9 @@ ObjectRef TargetInternal::ParseType(const std::string& str,

ObjectRef TargetInternal::ParseType(const ObjectRef& obj,
const TargetKindNode::ValueTypeInfo& info) {
if (info.type_index == runtime::BoxInt::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer
return GetRef<runtime::BoxInt>(
ObjTypeCheck<runtime::BoxInt::ContainerType>(obj, "runtime.BoxInt"));
return GetRef<runtime::Int>(ObjTypeCheck<runtime::Int::ContainerType>(obj, "runtime.BoxInt"));
} else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) {
// Parsing string
return GetRef<String>(ObjTypeCheck<StringObj>(obj, "String"));
Expand Down Expand Up @@ -491,9 +490,9 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj,
/********** Stringifying **********/

std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) {
if (const auto* p = obj.as<runtime::BoxNode<int64_t>>()) {
if (const auto* p = obj.as<runtime::Int::ContainerType>()) {
return std::to_string(p->value);
} else if (const auto* p = obj.as<runtime::BoxNode<bool>>()) {
} else if (const auto* p = obj.as<runtime::Bool::ContainerType>()) {
return std::to_string(p->value);
}
if (auto tvm_str = obj.as<String>()) {
Expand Down Expand Up @@ -963,7 +962,7 @@ ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
// If requested, query attributes from the device. User-specified
// parameters take precedence over queried parameters.
if (attrs.count("from_device")) {
int device_id = Downcast<runtime::BoxInt>(attrs.at("from_device"))->value;
int device_id = Downcast<runtime::Int>(attrs.at("from_device"))->value;
attrs.erase("from_device");
auto device_params = QueryDevice(device_id, target.get());

Expand Down
Loading

0 comments on commit 0f90e61

Please sign in to comment.