Skip to content

Commit

Permalink
Handle conversion to runtime::String for both Arg/Ret value
Browse files Browse the repository at this point in the history
The `TVMArgValue` and `TVMRetValue` types have different
interpretations for the `kTVMStr` and `kTVMBytes`.  Therefore,
since `pod_value.AsObjectRef<ObjectRef>()` may require converting a
`kTVMStr` or `kTVMBytes` into a `tvm::runtime::String`, the
`AsObjectRef` method must be implemented in a context that knows which
derived POD class is being used.

This commit moves the `AsObjectRef` method from `TVMPODValue_` to
`TVMPodValue_CRTP_<Derived>`.  The POD subclasses now inherit from
`TVMPodValue_CRTP_<Derived>`, which itself inherits from
`TVMPODValue_`.
  • Loading branch information
Lunderberg committed Feb 8, 2024
1 parent a7eaecf commit e0f7408
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 118 deletions.
30 changes: 18 additions & 12 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,8 @@ namespace runtime {

template <>
struct PackedFuncValueConverter<tvm::IntImm> {
static Optional<tvm::IntImm> TryFrom(const TVMPODValue_& val) {
template <typename PODSubclass>
static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsInt()) {
int64_t value = opt.value();
auto dtype =
Expand All @@ -796,31 +797,34 @@ struct PackedFuncValueConverter<tvm::IntImm> {
}
}

static tvm::IntImm From(const TVMPODValue_& val) {
template <typename PODSubclass>
static tvm::IntImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.AsObjectRef<tvm::IntImm>();
return val.template AsObjectRef<tvm::IntImm>();
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
template <typename PODSubclass>
static tvm::Integer From(const PODSubclass& val) {
if (auto opt = val.TryAsInt()) {
return Integer(opt.value());
} else if (auto opt = val.TryAsBool()) {
return Integer(opt.value());
} else {
return val.AsObjectRef<tvm::Integer>();
return val.template AsObjectRef<tvm::Integer>();
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
static Optional<tvm::Bool> TryFrom(const TVMPODValue_& val) {
template <typename PODSubclass>
static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsBool()) {
return tvm::Bool(opt.value());
} else if (auto opt = val.TryAsInt()) {
Expand All @@ -833,11 +837,12 @@ struct PackedFuncValueConverter<tvm::Bool> {
}
}

static tvm::Bool From(const TVMPODValue_& val) {
template <typename PODSubclass>
static tvm::Bool From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.AsObjectRef<tvm::Bool>();
return val.template AsObjectRef<tvm::Bool>();
}
}
};
Expand All @@ -852,11 +857,12 @@ struct PackedFuncValueConverter<tvm::FloatImm> {
}
}

static tvm::FloatImm From(const TVMPODValue_& val) {
template <typename PODSubclass>
static tvm::FloatImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.AsObjectRef<tvm::FloatImm>();
return val.template AsObjectRef<tvm::FloatImm>();
}
}
};
Expand All @@ -873,8 +879,8 @@ struct PackedFuncValueConverter<tvm::FloatImm> {
*/
template <>
struct PackedFuncValueConverter<runtime::Int> {
template <typename PODType>
static runtime::Int From(const PODType& val) {
template <typename PODSubclass>
static runtime::Int From(const PODSubclass& val) {
if (val.template IsObjectRef<tvm::IntImm>()) {
return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
} else {
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ class NDArray : public ObjectRef {

protected:
friend class TVMPODValue_;
template <typename Derived>
friend class TVMPODValue_CRTP_;
friend class TVMRetValue;
friend class TVMArgsSetter;
/*!
Expand Down
Loading

0 comments on commit e0f7408

Please sign in to comment.