diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 48b46343c694d..2494809119668 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -781,7 +781,8 @@ namespace runtime { template <> struct PackedFuncValueConverter { - static Optional TryFrom(const TVMPODValue_& val) { + template + static Optional TryFrom(const PODSubclass& val) { if (auto opt = val.TryAsInt()) { int64_t value = opt.value(); auto dtype = @@ -796,31 +797,34 @@ struct PackedFuncValueConverter { } } - static tvm::IntImm From(const TVMPODValue_& val) { + template + static tvm::IntImm From(const PODSubclass& val) { if (auto opt = TryFrom(val)) { return opt.value(); } else { - return val.AsObjectRef(); + return val.template AsObjectRef(); } } }; template <> struct PackedFuncValueConverter { - static tvm::Integer From(const TVMPODValue_& val) { + template + static tvm::Integer From(const PODSubclass& val) { if (auto opt = val.TryAsInt()) { return Integer(opt.value()); } else if (auto opt = val.TryAsBool()) { return Integer(opt.value()); } else { - return val.AsObjectRef(); + return val.template AsObjectRef(); } } }; template <> struct PackedFuncValueConverter { - static Optional TryFrom(const TVMPODValue_& val) { + template + static Optional TryFrom(const PODSubclass& val) { if (auto opt = val.TryAsBool()) { return tvm::Bool(opt.value()); } else if (auto opt = val.TryAsInt()) { @@ -833,11 +837,12 @@ struct PackedFuncValueConverter { } } - static tvm::Bool From(const TVMPODValue_& val) { + template + static tvm::Bool From(const PODSubclass& val) { if (auto opt = TryFrom(val)) { return opt.value(); } else { - return val.AsObjectRef(); + return val.template AsObjectRef(); } } }; @@ -852,11 +857,12 @@ struct PackedFuncValueConverter { } } - static tvm::FloatImm From(const TVMPODValue_& val) { + template + static tvm::FloatImm From(const PODSubclass& val) { if (auto opt = TryFrom(val)) { return opt.value(); } else { - return val.AsObjectRef(); + return val.template AsObjectRef(); } } }; @@ -873,8 +879,8 @@ struct PackedFuncValueConverter { */ template <> struct PackedFuncValueConverter { - template - static runtime::Int From(const PODType& val) { + template + static runtime::Int From(const PODSubclass& val) { if (val.template IsObjectRef()) { return runtime::Int(val.template AsObjectRef()->value); } else { diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 8400344bf5597..64a5a90fa7441 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -209,6 +209,8 @@ class NDArray : public ObjectRef { protected: friend class TVMPODValue_; + template + friend class TVMPODValue_CRTP_; friend class TVMRetValue; friend class TVMArgsSetter; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index e624a51c29541..a8747c59d6659 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -587,48 +587,6 @@ struct ObjectTypeChecker> { */ class TVMPODValue_ { public: - operator double() const { - // Allow automatic conversion from int to float - // This avoids errors when user pass in int from - // the frontend while the API expects a float. - if (auto opt = TryAsFloat()) { - return opt.value(); - } else if (auto opt = TryAsInt()) { - return opt.value(); - } else if (auto opt = TryAsBool()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); - } - } - operator int64_t() const { - if (auto opt = TryAsInt()) { - return opt.value(); - } else if (auto opt = TryAsBool()) { - return opt.value(); - } else if (IsObjectRef()) { - auto obj = AsObjectRef(); - LOG(FATAL) << "Expected integer, but found object with type key " << obj->GetTypeKey(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); - } - } - operator uint64_t() const { return operator int64_t(); } - operator int() const { - int64_t value = operator int64_t(); - ICHECK_LE(value, std::numeric_limits::max()); - ICHECK_GE(value, std::numeric_limits::min()); - return value; - } - operator bool() const { - if (auto opt = TryAsBool()) { - return opt.value(); - } else if (auto opt = TryAsInt()) { - return opt.value(); - } else { - LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); - } - } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; @@ -678,12 +636,6 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } - // ObjectRef handling - template ::value>::type> - inline bool IsObjectRef() const; - template - inline TObjectRef AsObjectRef() const; std::optional TryAsInt() const { // Helper function to reduce duplication in the variable integer @@ -707,16 +659,6 @@ class TVMPODValue_ { } } - std::optional TryAsBool() const { - // Booleans may be kept distinct from Int by using Box and - // Box. - if (IsObjectRef()) { - return AsObjectRef()->value; - } else { - return std::nullopt; - } - } - protected: friend class TVMArgsSetter; friend class TVMRetValue; @@ -730,13 +672,100 @@ class TVMPODValue_ { int type_code_; }; +/*! \brief A utility class that adds methods useful for each POD type + * + * These cannot be provided in the base PODValue_ class, because + * TVMArgValue and TVMRetValue have different semantics for kTVMStr + * and kTVMBytes. + * + * kTVMStr: + * + * For `TVMArgValue`, the active variant is `v_str`, a `const + * char*`. For `TVMRetValue`, the active variant is `v_handle`, + * and should be cast from `void*` to `std::string*`. + * + * kTVMBytes: + * + * The active variant is `v_handle`, a `void*`. For + * `TVMArgValue`, should be cast to `TVMByteArray*`. For + * `TVMRetValue`, should be cast to `std::string*`. + * + * When converting into an `ObjectRef`, a string may be used to build + * a `tvm::runtime::String`. Because TVMArgValue and TVMRetValue use + * different representations for strings, any utility funciton which + * might attempt a conversion to an `ObjectRef` must be performed + * within a context that is aware of the derived class. + */ +template +class TVMPODValue_CRTP_ : public TVMPODValue_ { + public: + using TVMPODValue_::TVMPODValue_; + + // ObjectRef handling + template ::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; + + std::optional TryAsBool() const { + // Booleans may be kept distinct from Int by using Box and + // Box. + if (IsObjectRef()) { + return AsObjectRef()->value; + } else { + return std::nullopt; + } + } + + operator double() const { + // Allow automatic conversion from int to float + // This avoids errors when user pass in int from + // the frontend while the API expects a float. + if (auto opt = TryAsFloat()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); + } + } + operator int64_t() const { + if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } + operator uint64_t() const { return operator int64_t(); } + operator int() const { + int64_t value = operator int64_t(); + ICHECK_LE(value, std::numeric_limits::max()); + ICHECK_GE(value, std::numeric_limits::min()); + return value; + } + operator bool() const { + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } +}; + /*! * \brief A single argument value to PackedFunc. * Containing both type_code and TVMValue * * Provides utilities to do type cast into other types. */ -class TVMArgValue : public TVMPODValue_ { +class TVMArgValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMArgValue() {} @@ -745,21 +774,21 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; // conversion operator. operator std::string() const { @@ -796,15 +825,15 @@ class TVMArgValue : public TVMPODValue_ { * * \note For internal development purpose only. */ -class TVMMovableArgValue_ : public TVMPODValue_ { +class TVMMovableArgValue_ : public TVMPODValue_CRTP_ { public: - TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; @@ -886,7 +915,7 @@ class TVMMovableArgValueWithContext_ { * TVMRetValue holds value and will manage the underlying containers * when it stores a complicated data type. */ -class TVMRetValue : public TVMPODValue_ { +class TVMRetValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMRetValue() {} @@ -894,28 +923,28 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_CRTP_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -2131,8 +2160,9 @@ inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { } } +template template -inline bool TVMPODValue_::IsObjectRef() const { +inline bool TVMPODValue_CRTP_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { @@ -2162,8 +2192,9 @@ inline bool TVMPODValue_::IsObjectRef() const { ObjectTypeChecker::Check(static_cast(value_.v_handle))); } +template template -inline TObjectRef TVMPODValue_::AsObjectRef() const { +inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { static_assert(std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; @@ -2255,8 +2286,17 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { } if constexpr (std::is_base_of_v) { - if (type_code_ == kTVMStr) { - return String(value_.v_str); + if (type_code_ == kTVMStr || type_code_ == kTVMBytes) { + // This step is the reason why `AsObjectRef` cannot be provided + // in the base `TVMPODValue_` class. Because `TVMArgValue` and + // `TVMRetValue` have different implementations of `operator + // std::string`, with different interpretations of `kTVMStr` and + // `kTVMBytes`, we must delegate to those implementations. + // + // This could be done with a pure virtual method in + // `TVMPODValue_`, but that would require a vtable lookup during + // FFI conversions, imposing a runtime overhead. + return String(static_cast(this)->operator std::string()); } } @@ -2373,17 +2413,10 @@ inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { // specializations of PackedFuncValueConverter template <> struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); - } - } - - static String From(const TVMRetValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); + template + static String From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return val.template AsObjectRef(); } else { return tvm::runtime::String(val.operator std::string()); } @@ -2525,7 +2558,7 @@ struct PackedFuncValueConverter> { return opt.value(); } - if (auto opt = TryValueConverter(val)) { + if (auto opt = TryValueConverter(val)) { return opt.value(); } @@ -2536,10 +2569,10 @@ struct PackedFuncValueConverter> { << " but got " << ArgTypeCode2Str(val.type_code()); } - template - static Optional TryAsObjectRef(const TVMPODValue_& val) { - if (val.IsObjectRef()) { - return VType(val.AsObjectRef()); + template + static Optional TryAsObjectRef(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return VType(val.template AsObjectRef()); } else if constexpr (sizeof...(VarRest)) { return TryAsObjectRef(val); } else { @@ -2547,7 +2580,7 @@ struct PackedFuncValueConverter> { } } - template + template static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); @@ -2555,7 +2588,7 @@ struct PackedFuncValueConverter> { } if constexpr (sizeof...(VarRest)) { - return TryValueConverter(val); + return TryValueConverter(val); } else { return NullOpt; }