Skip to content

Commit

Permalink
[FFI][RUNTIME] Introduce runtime boxed types for int/float/bool (#16183)
Browse files Browse the repository at this point in the history
* [Container] Support non-nullable types in Array::Map

Prior to this commit, the `Array::Map` member function could only be
applied to nullable object types.  This was due to the internal use of
`U()` as the default value for initializing the output `ArrayNode`, where
`U` is the return type of the mapping function.  This default
constructor is only available for nullable types, and would result in
a compile-time failure for non-nullable types.

This commit replaces `U()` with `ObjectRef()` in `Array::Map`,
removing this limitation.  Since all items in the output array are
overwritten before returning to the calling scope, initializing the
output array with `ObjectRef()` does not violate type safety.

* [FFI] Separate runtime types from IR types for int/float/bool

Prior to this commit, `int`, `float`, and `bool` arguments from Python
were converted to `IntImm`, `FloatImm`, and `Bool`.  These are
subtypes of `PrimExpr`, and should only be used at compile-time.  By
automatically applying this conversion as part of the FFI, these types
are required to be present whenever a primitive is converted to a
`tvm::ObjectRef`.

This can become especially fragile for an end-user when storing
objects into a TVM container.  Because TVM containers require all
contents to be `ObjectRef` subclasses, an automatic conversion may be
applied on storing into a container, resulting in an unexpected type
being retrieved from the container.  For example, this currently
occurs in Relax when extracting a `R.Prim` from a `R.Tuple`.

This commit introduces a `Box<T>` type for storage of boxed primitives
at runtime, distinct from the IR types.

* Primitive arguments provided to a PackedFunc that requires an
  `ObjectRef` will be converted to the corresponding boxed type.
  (e.g. Passing a Python `int` to a C++ function accepting `ObjectRef`
  produces a `Box<int64_t>`.

* Boxed primitives provided to a PackedFunc that requires an unboxed
  primitive will be converted to the corresponding primitive.

* PackedFunc return values of `ObjectRef` are converted to the
  corresponding primitive, if present.  (e.g. If a `tuple_getitem`
  with static return type `ObjectRef` returns a `Box<int64_t>`, it
  will be unwrapped to a python `int`.)

Together, these three rules provide backwards compatibility for
existing PackedFunc definitions, while avoiding exposing the user to
any container-induced type conversions betweeen primitive types and
`ObjectRef`.

* Fix unit test failure after merge

* Fix breakage in new unit test
  • Loading branch information
Lunderberg authored Aug 5, 2024
1 parent 5a67a00 commit 5f22be4
Show file tree
Hide file tree
Showing 184 changed files with 3,215 additions and 1,221 deletions.
76 changes: 58 additions & 18 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,16 @@ class DictAttrs : public Attrs {

auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
// For backwards compatibility, return through TVMRetValue.
// This triggers any automatic conversions registered with
// PackedFuncValueConverter. Importantly, this allows use of
// `GetAttr<Integer>` and `GetAttr<Bool>` for properties that
// are stored internally as `runtime::Box<int64_t>` and
// `runtime::Box<bool>`.
TVMRetValue ret;
ret = (*it).second;
Optional<TObjectRef> obj = ret;
return obj;
} else {
return default_value;
}
Expand Down Expand Up @@ -315,6 +324,46 @@ inline TAttrs AttrsWithDefaultValues() {
return TAttrs(n);
}

/*!
* \brief Copy the DictAttrs, but overrides attributes with the
* entries from \p attrs.
*
* \param attrs The DictAttrs to update
*
* \param new_attrs Key/values attributes to add to \p attrs.
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttrs(DictAttrs attrs, Map<String, ObjectRef> new_attrs);

/*!
* \brief Copy the DictAttrs, but overrides a single attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The update to insert or update.
*
* \param value The new value of the attribute
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value);

inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) {
return WithAttr(std::move(attrs), String(key), std::move(value));
}

/*!
* \brief Copy the DictAttrs, but without a specific attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The key to remove
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);

/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
Expand Down Expand Up @@ -347,12 +396,8 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);

return input;
}

Expand All @@ -371,13 +416,9 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
for (const auto& pair : attrs) {
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
}
} else {
node->attrs = DictAttrs(std::move(attrs));
}

node->attrs = WithAttrs(std::move(node->attrs), attrs);

return input;
}

Expand Down Expand Up @@ -412,10 +453,9 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

if (input->attrs.defined()) {
TNode* node = input.CopyOnWrite();
node->attrs.CopyOnWrite()->dict.erase(attr_key);
}
TNode* node = input.CopyOnWrite();
node->attrs = WithoutAttr(std::move(node->attrs), attr_key);

return input;
}

Expand Down
130 changes: 99 additions & 31 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// common rule for RetValue and ArgValue

// Automatic conversion into IntImm, Integer, and Bool, when called
// through the FFI. Automatic conversions into PrimExpr are
// registered in "tvm/tir/expr.h", as it includes conversions to the
// TIR-only StringImm.
//
// While the FFI only requires the From() method, these
// implementations also define a TryFrom() method to avoid duplicate
// logic in the PrimExpr conversion.

template <>
struct PackedFuncValueConverter<PrimExpr> {
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
int64_t value = val.operator int64_t();
if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
return IntImm(runtime::DataType::Int(64), value);
}
return IntImm(runtime::DataType::Int(32), val.operator int());
}
if (val.type_code() == kDLFloat) {
return FloatImm(runtime::DataType::Float(32), val.operator double());
struct PackedFuncValueConverter<tvm::IntImm> {
template <typename PODSubclass>
static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsInt()) {
int64_t value = opt.value();
auto dtype =
(value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
? DataType::Int(64)
: DataType::Int(32);
return IntImm(dtype, value);
} else if (auto opt = val.TryAsBool()) {
return IntImm(DataType::Int(32), opt.value());
} else {
return NullOpt;
}
}

return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
template <typename PODSubclass>
static tvm::IntImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::IntImm>();
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
template <typename PODSubclass>
static tvm::Integer From(const PODSubclass& val) {
if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
return Integer(opt.value());
} else {
return val.template AsObjectRef<tvm::Integer>();
}
if (val.type_code() == kTVMArgInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
static tvm::Bool From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Bool(ObjectPtr<Object>(nullptr));
template <typename PODSubclass>
static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsBool()) {
return tvm::Bool(opt.value());
} else if (auto opt = val.TryAsInt()) {
int value = opt.value();
ICHECK(value == 0 || value == 1)
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
return tvm::Bool(static_cast<bool>(value));
} else {
return NullOpt;
}
}

template <typename PODSubclass>
static tvm::Bool From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::Bool>();
}
if (val.type_code() == kTVMArgInt) {
int v = val.operator int();
ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
return Bool(static_cast<bool>(v));
}
};

template <>
struct PackedFuncValueConverter<tvm::FloatImm> {
static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsFloat()) {
return FloatImm(runtime::DataType::Float(32), opt.value());
} else {
return NullOpt;
}
}

template <typename PODSubclass>
static tvm::FloatImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::FloatImm>();
}
}
};

/* \brief Backwards compatibility wrapper for IntImm arguments
*
* In previous versions of TVM, IntImm was the default FFI type for
* integer arguments, instead of runtime::Int. For backwards
* compatibility where the callee has been updated to expected a
* runtime::Int, the caller has not been updated to provide a
* runtime::Int (e.g. relay script parsing), and the auto-unboxing of
* runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
* allow the IntImm to be generated.
*/
template <>
struct PackedFuncValueConverter<runtime::Int> {
template <typename PODSubclass>
static runtime::Int From(const PODSubclass& val) {
if (val.template IsObjectRef<tvm::IntImm>()) {
return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
} else {
return val.template AsObjectRef<runtime::Int>();
}
return val.AsObjectRef<tvm::Bool>();
}
};

Expand Down
34 changes: 32 additions & 2 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,36 @@ class PassContext : public ObjectRef {
using ValueNodeType = typename ValueType::ContainerType;
// NOTE: we could further update the function later.
uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
RegisterConfigOption(key, tindex);
auto type_key = runtime::Object::TypeIndex2Key(tindex);

auto* reflection = ReflectionVTable::Global();

auto legalization = [=](ObjectRef obj) -> ObjectRef {
if (obj->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
return reflection->CreateObject(type_key, Downcast<Map<String, ObjectRef>>(obj));
} else {
// Backwards compatibility for config options defined prior to
// https://github.com/apache/tvm/pull/16183. This commit
// changed the default FFI conversion of python integers from
// `tvm::IntImm` to `runtime::Int`.
//
// This backwards compatibility fix can be removed when all
// options registered with TVM_REGISTER_PASS_CONFIG_OPTION are
// updated to use `runtime::Int` and `runtime::Bool`.
TVMRetValue ret;
ret = obj;
try {
ValueType legalized = ret;
return legalized;
} catch (Error& err) {
LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key
<< ", but received error when converting to this type.\n"
<< err.what();
}
}
};

RegisterConfigOption(key, tindex, legalization);
return tindex;
}

Expand All @@ -285,7 +314,8 @@ class PassContext : public ObjectRef {
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Register configuration key value type.
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
std::function<ObjectRef(ObjectRef)> legalization);

// Classes to get the Python `with` like syntax.
friend class Internal;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef {
* \param thread_extents Candidates of thread axis extent (values are required to be positive).
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
TVM_DLL static ScheduleRule CrossThreadReduction(Array<runtime::Int> thread_extents);
/*!
* \brief A rule that randomly select a compute-at location for a free block
* \return The schedule rule created
Expand All @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef {
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<Integer> unroll_max_steps, //
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<runtime::Int> unroll_max_steps, //
bool unroll_explicit);
/*!
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}; // struct SqueezeAttrs

struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
ObjectRef indices_or_sections;
Variant<runtime::Int, Array<runtime::Int>> indices_or_sections;
int axis;

TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

Expand Down Expand Up @@ -186,11 +187,12 @@ typedef enum {
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
kTVMArgBool = 15U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kTVMExtBegin = 15U,
kTVMExtBegin = 16U,
kTVMNNVMFirst = 16U,
kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types.
Expand All @@ -207,6 +209,7 @@ typedef DLTensor* TVMArrayHandle;
*/
typedef union {
int64_t v_int64;
bool v_bool;
double v_float64;
void* v_handle;
const char* v_str;
Expand Down
Loading

0 comments on commit 5f22be4

Please sign in to comment.