From 5f22be4d83ca698e316ac342f32f5b4d38155ca8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 5 Aug 2024 08:19:20 -0500 Subject: [PATCH] [FFI][RUNTIME] Introduce runtime boxed types for int/float/bool (#16183) * [Container] Support non-nullable types in Array::Map Prior to this commit, the `Array::Map` member function could only be applied to nullable object types. This was due to the internal use of `U()` as the default value for initializing the output `ArrayNode`, where `U` is the return type of the mapping function. This default constructor is only available for nullable types, and would result in a compile-time failure for non-nullable types. This commit replaces `U()` with `ObjectRef()` in `Array::Map`, removing this limitation. Since all items in the output array are overwritten before returning to the calling scope, initializing the output array with `ObjectRef()` does not violate type safety. * [FFI] Separate runtime types from IR types for int/float/bool Prior to this commit, `int`, `float`, and `bool` arguments from Python were converted to `IntImm`, `FloatImm`, and `Bool`. These are subtypes of `PrimExpr`, and should only be used at compile-time. By automatically applying this conversion as part of the FFI, these types are required to be present whenever a primitive is converted to a `tvm::ObjectRef`. This can become especially fragile for an end-user when storing objects into a TVM container. Because TVM containers require all contents to be `ObjectRef` subclasses, an automatic conversion may be applied on storing into a container, resulting in an unexpected type being retrieved from the container. For example, this currently occurs in Relax when extracting a `R.Prim` from a `R.Tuple`. This commit introduces a `Box` type for storage of boxed primitives at runtime, distinct from the IR types. * Primitive arguments provided to a PackedFunc that requires an `ObjectRef` will be converted to the corresponding boxed type. (e.g. Passing a Python `int` to a C++ function accepting `ObjectRef` produces a `Box`. * Boxed primitives provided to a PackedFunc that requires an unboxed primitive will be converted to the corresponding primitive. * PackedFunc return values of `ObjectRef` are converted to the corresponding primitive, if present. (e.g. If a `tuple_getitem` with static return type `ObjectRef` returns a `Box`, it will be unwrapped to a python `int`.) Together, these three rules provide backwards compatibility for existing PackedFunc definitions, while avoiding exposing the user to any container-induced type conversions betweeen primitive types and `ObjectRef`. * Fix unit test failure after merge * Fix breakage in new unit test --- include/tvm/ir/attrs.h | 76 +- include/tvm/ir/expr.h | 130 +++- include/tvm/ir/transform.h | 34 +- include/tvm/meta_schedule/schedule_rule.h | 8 +- include/tvm/relay/attrs/transform.h | 2 +- include/tvm/runtime/c_runtime_api.h | 5 +- .../tvm/runtime/container/boxed_primitive.h | 143 ++++ include/tvm/runtime/container/variant.h | 2 +- include/tvm/runtime/ndarray.h | 2 + include/tvm/runtime/packed_func.h | 689 ++++++++++++++---- include/tvm/target/target.h | 10 +- include/tvm/target/target_kind.h | 4 +- include/tvm/tir/expr.h | 57 ++ include/tvm/tir/function.h | 2 +- include/tvm/tir/schedule/schedule.h | 5 +- python/tvm/_ffi/_ctypes/object.py | 22 + python/tvm/_ffi/_ctypes/packed_func.py | 7 +- python/tvm/_ffi/_ctypes/types.py | 3 + python/tvm/_ffi/_cython/base.pxi | 5 +- python/tvm/_ffi/_cython/object.pxi | 10 + python/tvm/_ffi/_cython/packed_func.pxi | 9 +- python/tvm/_ffi/runtime_ctypes.py | 3 +- python/tvm/driver/tvmc/registry.py | 22 +- python/tvm/ir/attrs.py | 2 +- python/tvm/ir/expr.py | 5 +- python/tvm/meta_schedule/tune_context.py | 3 +- python/tvm/relax/op/statistical.py | 22 +- python/tvm/relax/testing/ast_printer.py | 18 +- python/tvm/relax/training/setup_trainer.py | 4 +- python/tvm/relax/utils.py | 3 + .../relay/backend/contrib/ethosu/legalize.py | 2 +- python/tvm/relay/op/_tensor_grad.py | 3 + python/tvm/relay/op/_transform.py | 8 +- python/tvm/relay/op/contrib/ethosu.py | 4 +- python/tvm/relay/op/transform.py | 25 +- .../transform/fake_quantization_to_integer.py | 5 +- python/tvm/runtime/__init__.py | 4 +- python/tvm/runtime/container.py | 38 + python/tvm/runtime/object_generic.py | 75 +- python/tvm/script/parser/tir/parser.py | 2 + python/tvm/te/hybrid/calls.py | 12 +- python/tvm/te/hybrid/parser.py | 4 +- python/tvm/te/hybrid/utils.py | 28 +- python/tvm/te/operation.py | 1 - python/tvm/te/tensor.py | 11 +- python/tvm/tir/__init__.py | 1 + python/tvm/tir/expr.py | 4 + python/tvm/tir/ir_builder.py | 6 +- python/tvm/tir/op.py | 151 ++-- python/tvm/tir/schedule/trace.py | 15 +- python/tvm/topi/arm_cpu/conv2d_gemm.py | 2 +- python/tvm/topi/cuda/batch_matmul.py | 8 +- rust/tvm-rt/src/module.rs | 5 +- rust/tvm-sys/src/packed_func.rs | 35 +- src/auto_scheduler/compute_dag.cc | 16 +- .../search_policy/sketch_policy_rules.cc | 3 +- src/auto_scheduler/search_policy/utils.h | 12 +- .../msc/core/printer/msc_base_printer.cc | 9 + .../msc/core/printer/prototxt_printer.cc | 4 + src/contrib/msc/core/utils.cc | 4 + src/driver/driver_api.cc | 5 +- src/ir/attrs.cc | 89 +++ src/ir/expr.cc | 17 +- src/ir/transform.cc | 41 +- src/meta_schedule/database/database_utils.cc | 10 +- src/meta_schedule/database/json_database.cc | 4 +- .../mutator/mutate_thread_binding.cc | 2 +- src/meta_schedule/mutator/mutate_tile_size.cc | 6 +- src/meta_schedule/mutator/mutate_unroll.cc | 6 +- .../schedule/cuda/thread_bind.cc | 6 +- .../schedule_rule/cross_thread_reduction.cc | 8 +- .../schedule_rule/multi_level_tiling.cc | 5 +- .../parallel_vectorize_unroll.cc | 6 +- .../schedule_rule/schedule_rule.cc | 12 +- src/meta_schedule/utils.h | 38 +- src/node/boxed_primitive.cc | 134 ++++ src/node/script_printer.cc | 16 +- src/node/structural_equal.cc | 37 +- src/relax/backend/vm/codegen_vm.cc | 2 + src/relax/backend/vm/codegen_vm_tir.cc | 30 +- src/relax/op/tensor/create.cc | 2 +- src/relax/op/tensor/create.h | 2 +- src/relax/op/tensor/manipulate.cc | 6 +- src/relax/op/tensor/manipulate.h | 4 +- .../backend/contrib/cmsisnn/compiler_attrs.cc | 2 +- src/relay/backend/contrib/cmsisnn/target.cc | 2 +- src/relay/backend/contrib/cutlass/target.cc | 18 +- .../backend/contrib/ethosn/ethosn_api.cc | 6 +- src/relay/backend/contrib/ethosu/codegen.cc | 3 +- .../backend/contrib/ethosu/preprocess.cc | 4 +- .../contrib/example_target_hooks/target.cc | 2 +- src/relay/backend/contrib/tensorrt/codegen.cc | 4 +- src/relay/backend/contrib/tensorrt/target.cc | 14 +- src/relay/backend/contrib/uma/targets.cc | 7 +- src/relay/backend/executor.cc | 10 +- src/relay/backend/runtime.cc | 4 +- src/relay/ir/dataflow_matcher.cc | 36 + src/relay/op/make_op.h | 2 +- src/relay/op/tensor/transform.cc | 48 +- .../transforms/combine_parallel_op_batch.cc | 2 +- src/relay/transforms/fold_constant.cc | 2 +- src/relay/transforms/higher_order_gradient.cc | 2 - src/relay/transforms/to_mixed_precision.cc | 4 +- src/runtime/boxed_primitive.cc | 65 ++ src/runtime/crt/common/crt_runtime_api.c | 8 +- src/runtime/disco/bcast_session.cc | 8 +- src/runtime/minrpc/rpc_reference.h | 8 + src/runtime/relax_vm/builtin.cc | 10 +- .../printer/doc_printer/python_doc_printer.cc | 23 +- src/script/printer/ir/misc.cc | 15 + src/script/printer/relax/tir.cc | 6 +- src/support/array.h | 52 +- src/support/ffi_testing.cc | 52 ++ src/target/llvm/codegen_cpu.cc | 29 +- src/target/llvm/llvm_instance.cc | 14 +- src/target/tag.cc | 66 +- src/target/target.cc | 66 +- src/target/target_kind.cc | 137 ++-- src/te/operation/compute_op.cc | 26 +- src/te/operation/create_primfunc.cc | 15 +- src/te/operation/placeholder_op.cc | 12 +- src/te/schedule/schedule_dataflow_rewrite.cc | 7 +- .../analysis/calculate_allocated_memory.cc | 2 +- src/tir/ir/expr.cc | 20 +- src/tir/ir/function.cc | 7 + src/tir/ir/specialize.cc | 2 +- src/tir/ir/stmt.cc | 32 +- src/tir/ir/utils.cc | 68 ++ src/tir/ir/utils.h | 51 ++ src/tir/op/op.cc | 16 +- src/tir/schedule/concrete_schedule.cc | 14 +- src/tir/schedule/concrete_schedule.h | 5 +- src/tir/schedule/instruction_traits.h | 5 + src/tir/schedule/primitive.h | 5 +- src/tir/schedule/primitive/annotate.cc | 3 + src/tir/schedule/primitive/sampling.cc | 36 +- src/tir/schedule/trace.cc | 12 +- src/tir/schedule/traced_schedule.cc | 6 +- src/tir/schedule/traced_schedule.h | 5 +- .../transforms/inline_private_functions.cc | 2 +- src/tir/transforms/ir_utils.h | 1 + src/tir/transforms/lower_tvm_builtin.cc | 2 + src/tir/transforms/make_packed_api.cc | 45 +- tests/cpp/relay/backend/runtime_test.cc | 10 +- tests/cpp/target_test.cc | 56 +- .../test_runtime_packed_func.py | 18 +- .../arith/test_arith_canonical_simplify.py | 23 +- .../arith/test_arith_iter_affine_map.py | 35 +- .../test_arith_narrow_predicate_expression.py | 21 +- .../arith/test_arith_rewrite_simplify.py | 63 +- .../test_arith_solve_linear_equations.py | 15 +- .../test_arith_solve_linear_inequality.py | 11 +- .../codegen/test_target_codegen_cuda.py | 2 +- .../codegen/test_target_codegen_llvm.py | 41 ++ .../ir/test_container_structural_equal.py | 30 +- tests/python/ir/test_ir_container.py | 15 +- tests/python/ir/test_ir_type.py | 9 +- .../test_distributed_tvmscript_printer.py | 4 +- tests/python/relax/test_ast_printer.py | 2 +- .../relax/test_backend_dispatch_sort_scan.py | 10 +- .../relax/test_tvmscript_printer_relax.py | 6 +- tests/python/relax/test_vm_build.py | 2 +- tests/python/relax/test_vm_codegen_tir.py | 5 +- tests/python/relay/test_dataflow_pattern.py | 3 +- tests/python/relay/test_executor.py | 2 +- tests/python/relay/test_runtime.py | 4 +- tests/python/relay/test_type_infer.py | 65 +- .../python/runtime/test_runtime_container.py | 130 +++- tests/python/te/test_te_schedule_tensorize.py | 20 +- tests/python/te/test_te_tag.py | 10 +- tests/python/tir-base/test_lower_build.py | 2 +- tests/python/tir-base/test_tir_buffer.py | 17 +- tests/python/tir-base/test_tir_index_map.py | 55 +- tests/python/tir-base/test_tir_nodes.py | 27 +- .../test_tir_schedule_sampling.py | 2 +- .../tir-schedule/test_tir_schedule_state.py | 4 +- ...est_tir_transform_compact_buffer_region.py | 71 +- ...tir_transform_instrument_bound_checkers.py | 8 +- .../test_tir_transform_make_packed_api.py | 139 ++++ .../test_tir_transform_storage_rewrite.py | 4 +- .../tvmscript/test_tvmscript_error_report.py | 17 +- .../tvmscript/test_tvmscript_printer_tir.py | 12 +- .../tvmscript/test_tvmscript_roundtrip.py | 31 +- vta/python/vta/transform.py | 13 +- 184 files changed, 3215 insertions(+), 1221 deletions(-) create mode 100644 include/tvm/runtime/container/boxed_primitive.h create mode 100644 src/node/boxed_primitive.cc create mode 100644 src/runtime/boxed_primitive.cc create mode 100644 src/tir/ir/utils.cc create mode 100644 src/tir/ir/utils.h diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 81611b1a535a..d038d5f59a5f 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -265,7 +265,16 @@ class DictAttrs : public Attrs { auto it = node->dict.find(attr_key); if (it != node->dict.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + Optional obj = ret; + return obj; } else { return default_value; } @@ -315,6 +324,46 @@ inline TAttrs AttrsWithDefaultValues() { return TAttrs(n); } +/*! + * \brief Copy the DictAttrs, but overrides attributes with the + * entries from \p attrs. + * + * \param attrs The DictAttrs to update + * + * \param new_attrs Key/values attributes to add to \p attrs. + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs); + +/*! + * \brief Copy the DictAttrs, but overrides a single attribute. + * + * \param attrs The DictAttrs to update + * + * \param key The update to insert or update. + * + * \param value The new value of the attribute + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value); + +inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) { + return WithAttr(std::move(attrs), String(key), std::move(value)); +} + +/*! + * \brief Copy the DictAttrs, but without a specific attribute. + * + * \param attrs The DictAttrs to update + * + * \param key The key to remove + * + * \returns The new DictAttrs with updated attributes. + */ +DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key); + /*! * \brief Copy the function or module, but overrides * the attribute value key with the value. @@ -347,12 +396,8 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value); - } else { - Map dict = {{attr_key, attr_value}}; - node->attrs = DictAttrs(dict); - } + node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value); + return input; } @@ -371,13 +416,9 @@ inline TFunc WithAttrs(TFunc input, Map attrs) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); TNode* node = input.CopyOnWrite(); - if (node->attrs.defined()) { - for (const auto& pair : attrs) { - node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second); - } - } else { - node->attrs = DictAttrs(std::move(attrs)); - } + + node->attrs = WithAttrs(std::move(node->attrs), attrs); + return input; } @@ -412,10 +453,9 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) { using TNode = typename TFunc::ContainerType; static_assert(TNode::_type_final, "Can only operate on the leaf nodes"); - if (input->attrs.defined()) { - TNode* node = input.CopyOnWrite(); - node->attrs.CopyOnWrite()->dict.erase(attr_key); - } + TNode* node = input.CopyOnWrite(); + node->attrs = WithoutAttr(std::move(node->attrs), attr_key); + return input; } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 9b522389227a..efde52385177 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const { namespace tvm { namespace runtime { -// common rule for RetValue and ArgValue + +// Automatic conversion into IntImm, Integer, and Bool, when called +// through the FFI. Automatic conversions into PrimExpr are +// registered in "tvm/tir/expr.h", as it includes conversions to the +// TIR-only StringImm. +// +// While the FFI only requires the From() method, these +// implementations also define a TryFrom() method to avoid duplicate +// logic in the PrimExpr conversion. + template <> -struct PackedFuncValueConverter { - static PrimExpr From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return PrimExpr(ObjectPtr(nullptr)); - } - if (val.type_code() == kDLInt) { - int64_t value = val.operator int64_t(); - if (value > std::numeric_limits::max() || value < std::numeric_limits::min()) { - return IntImm(runtime::DataType::Int(64), value); - } - return IntImm(runtime::DataType::Int(32), val.operator int()); - } - if (val.type_code() == kDLFloat) { - return FloatImm(runtime::DataType::Float(32), val.operator double()); +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + if (auto opt = val.TryAsInt()) { + int64_t value = opt.value(); + auto dtype = + (value > std::numeric_limits::max() || value < std::numeric_limits::min()) + ? DataType::Int(64) + : DataType::Int(32); + return IntImm(dtype, value); + } else if (auto opt = val.TryAsBool()) { + return IntImm(DataType::Int(32), opt.value()); + } else { + return NullOpt; } + } - return PrimExpr::FromObject_(val.AsObjectRef()); + template + static tvm::IntImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } } }; template <> struct PackedFuncValueConverter { - static tvm::Integer From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Integer(ObjectPtr(nullptr)); + template + static tvm::Integer From(const PODSubclass& val) { + if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return Integer(opt.value()); + } else { + return val.template AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - return Integer(val.operator int()); - } - return val.AsObjectRef(); } }; template <> struct PackedFuncValueConverter { - static tvm::Bool From(const TVMPODValue_& val) { - if (val.type_code() == kTVMNullptr) { - return Bool(ObjectPtr(nullptr)); + template + static Optional TryFrom(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + return tvm::Bool(opt.value()); + } else if (auto opt = val.TryAsInt()) { + int value = opt.value(); + ICHECK(value == 0 || value == 1) + << "ValueError: boolean value can only be 0 or 1, but get " << value; + return tvm::Bool(static_cast(value)); + } else { + return NullOpt; + } + } + + template + static tvm::Bool From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); } - if (val.type_code() == kTVMArgInt) { - int v = val.operator int(); - ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v; - return Bool(static_cast(v)); + } +}; + +template <> +struct PackedFuncValueConverter { + static Optional TryFrom(const TVMPODValue_& val) { + if (auto opt = val.TryAsFloat()) { + return FloatImm(runtime::DataType::Float(32), opt.value()); + } else { + return NullOpt; + } + } + + template + static tvm::FloatImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +/* \brief Backwards compatibility wrapper for IntImm arguments + * + * In previous versions of TVM, IntImm was the default FFI type for + * integer arguments, instead of runtime::Int. For backwards + * compatibility where the callee has been updated to expected a + * runtime::Int, the caller has not been updated to provide a + * runtime::Int (e.g. relay script parsing), and the auto-unboxing of + * runtime::Int does not apply (e.g. making an `Array`), + * allow the IntImm to be generated. + */ +template <> +struct PackedFuncValueConverter { + template + static runtime::Int From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return runtime::Int(val.template AsObjectRef()->value); + } else { + return val.template AsObjectRef(); } - return val.AsObjectRef(); } }; diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index adf332525020..5828d98206ad 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -271,7 +271,36 @@ class PassContext : public ObjectRef { using ValueNodeType = typename ValueType::ContainerType; // NOTE: we could further update the function later. uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex(); - RegisterConfigOption(key, tindex); + auto type_key = runtime::Object::TypeIndex2Key(tindex); + + auto* reflection = ReflectionVTable::Global(); + + auto legalization = [=](ObjectRef obj) -> ObjectRef { + if (obj->IsInstance::ContainerType>()) { + return reflection->CreateObject(type_key, Downcast>(obj)); + } else { + // Backwards compatibility for config options defined prior to + // https://github.com/apache/tvm/pull/16183. This commit + // changed the default FFI conversion of python integers from + // `tvm::IntImm` to `runtime::Int`. + // + // This backwards compatibility fix can be removed when all + // options registered with TVM_REGISTER_PASS_CONFIG_OPTION are + // updated to use `runtime::Int` and `runtime::Bool`. + TVMRetValue ret; + ret = obj; + try { + ValueType legalized = ret; + return legalized; + } catch (Error& err) { + LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key + << ", but received error when converting to this type.\n" + << err.what(); + } + } + }; + + RegisterConfigOption(key, tindex, legalization); return tindex; } @@ -285,7 +314,8 @@ class PassContext : public ObjectRef { // The exit of a pass context scope. TVM_DLL void ExitWithScope(); // Register configuration key value type. - TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index); + TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index, + std::function legalization); // Classes to get the Python `with` like syntax. friend class Internal; diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index d91812fb55cb..90aec05187eb 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef { * \param thread_extents Candidates of thread axis extent (values are required to be positive). * \return The schedule rule created */ - TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); + TVM_DLL static ScheduleRule CrossThreadReduction(Array thread_extents); /*! * \brief A rule that randomly select a compute-at location for a free block * \return The schedule rule created @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef { * \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma. * \return The schedule rule created */ - TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // - int max_vectorize_extent, // - Array unroll_max_steps, // + TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, // + int max_vectorize_extent, // + Array unroll_max_steps, // bool unroll_explicit); /*! * \brief Auto bind loops around the block to BlockIdx and ThreadIdx diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 249b9cd0e50d..91020fc7443b 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode { }; // struct SqueezeAttrs struct SplitAttrs : public tvm::AttrsNode { - ObjectRef indices_or_sections; + Variant> indices_or_sections; int axis; TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") { diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f1046ef24266..b4c653a0a59e 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -81,6 +81,7 @@ #ifdef __cplusplus extern "C" { #endif +#include #include #include @@ -186,11 +187,12 @@ typedef enum { kTVMBytes = 12U, kTVMNDArrayHandle = 13U, kTVMObjectRValueRefArg = 14U, + kTVMArgBool = 15U, // Extension codes for other frameworks to integrate TVM PackedFunc. // To make sure each framework's id do not conflict, use first and // last sections to mark ranges. // Open an issue at the repo if you need a section of code. - kTVMExtBegin = 15U, + kTVMExtBegin = 16U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, // The following section of code is used for non-reserved types. @@ -207,6 +209,7 @@ typedef DLTensor* TVMArrayHandle; */ typedef union { int64_t v_int64; + bool v_bool; double v_float64; void* v_handle; const char* v_str; diff --git a/include/tvm/runtime/container/boxed_primitive.h b/include/tvm/runtime/container/boxed_primitive.h new file mode 100644 index 000000000000..8d01b5dc17b5 --- /dev/null +++ b/include/tvm/runtime/container/boxed_primitive.h @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/container/boxed_primitive.h + * \brief Runtime container types for primitives stored as ObjectRef. + */ +#ifndef TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ +#define TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +namespace detail { +/* \brief Provide the BoxNode type traits in templated contexts + * + * The Box class is used in many templated contexts, and is easier + * to have templated over the primitive type. + * + * However, much of the TVM type system depends on classes having a + * unique name. For example, the use of `Object::IsInstance` depends + * on `Object::GetOrAllocRuntimeTypeIndex`. Any duplicate names will + * result in duplicate indices, and invalid downcasting. Furthermore, + * the name must be specified in the Python FFI using + * `tvm._ffi.register_object`. This prevents use of + * `typeid(T)::name()` to build a unique name, as the name is not + * required to be human-readable or consistent across compilers. + * + * This utility struct should be specialized over the primitive type + * held by the box, to allow explicit listing of the `_type_key` and + * other similar tratis. + * + * Note: This should only contain traits that are required at runtime, + * and should *not* contain extensions for features that are only + * available at compile-time. For integration with compile-time-only + * functionality (e.g. StructuralHash, StructuralEqual), see + * `BoxNodeCompileTimeTraits` in `src/node/boxed_primitive.cc`. + */ +template +struct BoxNodeRuntimeTraits; + +} // namespace detail + +template +class BoxNode : public Object { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + explicit BoxNode(Prim value) : value(value) {} + + /*! \brief The boxed value */ + Prim value; + + static constexpr const char* _type_key = detail::BoxNodeRuntimeTraits::_type_key; + static constexpr bool _type_has_method_visit_attrs = false; + TVM_DECLARE_FINAL_OBJECT_INFO(BoxNode, Object); +}; + +template +class Box : public ObjectRef { + public: + /*! \brief Constructor + * + * \param value The value to be boxed + */ + Box(Prim value) : ObjectRef(make_object>(value)) {} // NOLINT(*) + + operator Prim() const { return (*this)->value; } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Box, ObjectRef, BoxNode); +}; + +/*! \brief Boxed version of C++ int64_t + * + * Can be used to store POD integer values as a TVM ObjectRef. Used + * for FFI handling, and for storing POD types inside TVM containers. + */ +using Int = Box; + +/*! \brief Boxed version of C++ double + * + * Can be used to store POD floating-point values as a TVM ObjectRef. + * Used for FFI handling, and for storing POD types inside TVM + * containers. + */ +using Float = Box; + +/*! \brief Boxed version of C++ bool + * + * Can be used to store POD boolean values as a TVM ObjectRef. Used + * for FFI handling, and for storing POD types inside TVM containers. + * + * When passing from Python to C++, TVM PackedFunc conversion follow + * C++ conversion rules, and allow bool->int and int->bool + * conversions. When passing from C++ to Python, the types are + * returned as bool or int. If the C++ function uses ObjectRef to + * hold the object, a Python to C++ to Python round trip will preserve + * the distinction between bool and int. + */ +using Bool = Box; + +namespace detail { +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxInt"; +}; + +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxFloat"; +}; + +template <> +struct BoxNodeRuntimeTraits { + static constexpr const char* _type_key = "runtime.BoxBool"; +}; +} // namespace detail + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CONTAINER_BOXED_PRIMITIVE_H_ diff --git a/include/tvm/runtime/container/variant.h b/include/tvm/runtime/container/variant.h index 7953ac47c1cf..e8defa4e6fee 100644 --- a/include/tvm/runtime/container/variant.h +++ b/include/tvm/runtime/container/variant.h @@ -82,7 +82,7 @@ class Variant : public ObjectRef { public: /* \brief Helper utility to check if the type is part of the variant */ template - static constexpr bool is_variant = (std::is_same_v || ...); + static constexpr bool is_variant = (std::is_base_of_v || ...); /* \brief Helper utility for SFINAE if the type is part of the variant */ template diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 3eb225fccffe..fef61a753103 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -226,6 +226,8 @@ class NDArray : public ObjectRef { protected: friend class TVMPODValue_; + template + friend class TVMPODValue_CRTP_; friend class TVMRetValue; friend class TVMArgsSetter; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 7266f8c4a50a..98196c13af7f 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -37,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -429,9 +431,11 @@ inline const char* ArgTypeCode2Str(int type_code); inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) +#define TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) \ + "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) + // macro to check type code. -#define TVM_CHECK_TYPE_CODE(CODE, T) \ - ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) +#define TVM_CHECK_TYPE_CODE(CODE, T) ICHECK_EQ(CODE, T) << TVM_LOG_INCORRECT_TYPE_CODE(CODE, T) /*! * \brief Type traits for runtime type check during FFI conversion. @@ -510,6 +514,7 @@ struct ObjectTypeChecker> { } static std::string TypeName() { return "Array[" + ObjectTypeChecker::TypeName() + "]"; } }; + template struct ObjectTypeChecker> { static Optional CheckAndGetMismatch(const Object* ptr) { @@ -545,40 +550,43 @@ struct ObjectTypeChecker> { } }; +template +struct ObjectTypeChecker> { + static Optional CheckAndGetMismatch(const Object* ptr) { + return ObjectTypeChecker::CheckAndGetMismatch(ptr); + } + static bool Check(const Object* ptr) { return ObjectTypeChecker::Check(ptr); } + static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } + static std::string VariantNames() { return ObjectTypeChecker::TypeName(); } +}; + +template +struct ObjectTypeChecker> { + static Optional CheckAndGetMismatch(const Object* ptr) { + auto try_first = ObjectTypeChecker::CheckAndGetMismatch(ptr); + if (!try_first.defined()) { + return try_first; + } + + return ObjectTypeChecker>::CheckAndGetMismatch(ptr); + } + static bool Check(const Object* ptr) { + return ObjectTypeChecker::Check(ptr) || + ObjectTypeChecker>::Check(ptr); + } + static std::string TypeName() { return "Variant[" + VariantNames() + "]"; } + static std::string VariantNames() { + return ObjectTypeChecker::TypeName() + ", " + + ObjectTypeChecker>::VariantNames(); + } +}; + /*! * \brief Internal base class to * handle conversion to POD values. */ class TVMPODValue_ { public: - operator double() const { - // Allow automatic conversion from int to float - // This avoids errors when user pass in int from - // the frontend while the API expects a float. - if (type_code_ == kDLInt) { - return static_cast(value_.v_int64); - } - TVM_CHECK_TYPE_CODE(type_code_, kDLFloat); - return value_.v_float64; - } - operator int64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator uint64_t() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64; - } - operator int() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - ICHECK_LE(value_.v_int64, std::numeric_limits::max()); - ICHECK_GE(value_.v_int64, std::numeric_limits::min()); - return static_cast(value_.v_int64); - } - operator bool() const { - TVM_CHECK_TYPE_CODE(type_code_, kDLInt); - return value_.v_int64 != 0; - } operator void*() const { if (type_code_ == kTVMNullptr) return nullptr; if (type_code_ == kTVMDLTensorHandle) return value_.v_handle; @@ -628,12 +636,39 @@ class TVMPODValue_ { T* ptr() const { return static_cast(value_.v_handle); } - // ObjectRef handling - template ::value>::type> - inline bool IsObjectRef() const; - template - inline TObjectRef AsObjectRef() const; + + std::optional TryAsBool() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kTVMArgBool) { + return value_.v_bool; + } else { + return std::nullopt; + } + } + + std::optional TryAsInt() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kDLInt) { + return value_.v_int64; + } else { + return std::nullopt; + } + } + + std::optional TryAsFloat() const { + // Helper function to reduce duplication in the variable integer + // conversions. This is publicly exposed, as it can be useful in + // specializations of PackedFuncValueConverter. + if (type_code_ == kDLFloat) { + return value_.v_float64; + } else { + return std::nullopt; + } + } protected: friend class TVMArgsSetter; @@ -648,13 +683,90 @@ class TVMPODValue_ { int type_code_; }; +/*! \brief A utility class that adds methods useful for each POD type + * + * These cannot be provided in the base PODValue_ class, because + * TVMArgValue and TVMRetValue have different semantics for kTVMStr + * and kTVMBytes. + * + * kTVMStr: + * + * For `TVMArgValue`, the active variant is `v_str`, a `const + * char*`. For `TVMRetValue`, the active variant is `v_handle`, + * and should be cast from `void*` to `std::string*`. + * + * kTVMBytes: + * + * The active variant is `v_handle`, a `void*`. For + * `TVMArgValue`, should be cast to `TVMByteArray*`. For + * `TVMRetValue`, should be cast to `std::string*`. + * + * When converting into an `ObjectRef`, a string may be used to build + * a `tvm::runtime::String`. Because TVMArgValue and TVMRetValue use + * different representations for strings, any utility funciton which + * might attempt a conversion to an `ObjectRef` must be performed + * within a context that is aware of the derived class. + */ +template +class TVMPODValue_CRTP_ : public TVMPODValue_ { + public: + using TVMPODValue_::TVMPODValue_; + + // ObjectRef handling + template ::value>::type> + inline bool IsObjectRef() const; + template + inline TObjectRef AsObjectRef() const; + + operator double() const { + // Allow automatic conversion from int to float + // This avoids errors when user pass in int from + // the frontend while the API expects a float. + if (auto opt = TryAsFloat()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLFloat); + } + } + operator int64_t() const { + if (auto opt = TryAsInt()) { + return opt.value(); + } else if (auto opt = TryAsBool()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } + operator uint64_t() const { return operator int64_t(); } + operator int() const { + int64_t value = operator int64_t(); + ICHECK_LE(value, std::numeric_limits::max()); + ICHECK_GE(value, std::numeric_limits::min()); + return value; + } + operator bool() const { + if (auto opt = TryAsBool()) { + return opt.value(); + } else if (auto opt = TryAsInt()) { + return opt.value(); + } else { + LOG(FATAL) << TVM_LOG_INCORRECT_TYPE_CODE(type_code_, kDLInt); + } + } +}; + /*! * \brief A single argument value to PackedFunc. * Containing both type_code and TVMValue * * Provides utilities to do type cast into other types. */ -class TVMArgValue : public TVMPODValue_ { +class TVMArgValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMArgValue() {} @@ -663,21 +775,21 @@ class TVMArgValue : public TVMPODValue_ { * \param value of the function * \param type_code The type code. */ - TVMArgValue(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMArgValue(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Device; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; // conversion operator. operator std::string() const { @@ -714,15 +826,15 @@ class TVMArgValue : public TVMPODValue_ { * * \note For internal development purpose only. */ -class TVMMovableArgValue_ : public TVMPODValue_ { +class TVMMovableArgValue_ : public TVMPODValue_CRTP_ { public: - TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_(value, type_code) {} + TVMMovableArgValue_(TVMValue value, int type_code) : TVMPODValue_CRTP_(value, type_code) {} // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator NDArray; @@ -804,7 +916,7 @@ class TVMMovableArgValueWithContext_ { * TVMRetValue holds value and will manage the underlying containers * when it stores a complicated data type. */ -class TVMRetValue : public TVMPODValue_ { +class TVMRetValue : public TVMPODValue_CRTP_ { public: /*! \brief default constructor */ TVMRetValue() {} @@ -812,28 +924,28 @@ class TVMRetValue : public TVMPODValue_ { * \brief move constructor from another return value. * \param other The other return value. */ - TVMRetValue(TVMRetValue&& other) : TVMPODValue_(other.value_, other.type_code_) { + TVMRetValue(TVMRetValue&& other) : TVMPODValue_CRTP_(other.value_, other.type_code_) { other.value_.v_handle = nullptr; other.type_code_ = kTVMNullptr; } /*! \brief destructor */ ~TVMRetValue() { this->Clear(); } // reuse converter from parent - using TVMPODValue_::operator double; - using TVMPODValue_::operator int64_t; - using TVMPODValue_::operator uint64_t; - using TVMPODValue_::operator int; - using TVMPODValue_::operator bool; + using TVMPODValue_CRTP_::operator double; + using TVMPODValue_CRTP_::operator int64_t; + using TVMPODValue_CRTP_::operator uint64_t; + using TVMPODValue_CRTP_::operator int; + using TVMPODValue_CRTP_::operator bool; using TVMPODValue_::operator void*; using TVMPODValue_::operator DLTensor*; using TVMPODValue_::operator Device; using TVMPODValue_::operator NDArray; using TVMPODValue_::operator Module; using TVMPODValue_::operator PackedFunc; - using TVMPODValue_::AsObjectRef; - using TVMPODValue_::IsObjectRef; + using TVMPODValue_CRTP_::AsObjectRef; + using TVMPODValue_CRTP_::IsObjectRef; - TVMRetValue(const TVMRetValue& other) : TVMPODValue_() { this->Assign(other); } + TVMRetValue(const TVMRetValue& other) : TVMPODValue_CRTP_() { this->Assign(other); } // conversion operators operator std::string() const { if (type_code_ == kTVMDataType) { @@ -901,8 +1013,8 @@ class TVMRetValue : public TVMPODValue_ { } TVMRetValue& operator=(const DataType& other) { return operator=(other.operator DLDataType()); } TVMRetValue& operator=(bool value) { - this->SwitchToPOD(kDLInt); - value_.v_int64 = value; + this->SwitchToPOD(kTVMArgBool); + value_.v_bool = value; return *this; } TVMRetValue& operator=(std::string value) { @@ -974,7 +1086,8 @@ class TVMRetValue : public TVMPODValue_ { */ static TVMRetValue MoveFromCHost(TVMValue value, int type_code) { // Can move POD and everything under the object system. - ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle); + ICHECK(type_code <= kTVMPackedFuncHandle || type_code == kTVMNDArrayHandle || + type_code == kTVMArgBool); TVMRetValue ret; ret.value_ = value; ret.type_code_ = type_code; @@ -989,9 +1102,9 @@ class TVMRetValue : public TVMPODValue_ { } // ObjectRef handling template ::value>::type> + typename = typename std::enable_if_t>> inline TVMRetValue& operator=(TObjectRef other); - template ::value>::type> + template >> inline operator T() const; private: @@ -1019,9 +1132,11 @@ class TVMRetValue : public TVMPODValue_ { break; } case kTVMObjectHandle: { - // Avoid operator ObjectRef as we already know it is not NDArray/Module - SwitchToObject(kTVMObjectHandle, - GetObjectPtr(static_cast(other.value_.v_handle))); + // We already known it is not NDArray/Module, but + // operator=(ObjectRef) also handles conversions from wrappers + // around primitive types. For NDArray/Module, the duplicate + // checks are removed with if constexpr. + operator=(other.operator ObjectRef()); break; } case kTVMObjectRValueRefArg: { @@ -1265,6 +1380,8 @@ inline const char* ArgTypeCode2Str(int type_code) { switch (type_code) { case kDLInt: return "int"; + case kTVMArgBool: + return "bool"; case kDLUInt: return "uint"; case kDLFloat: @@ -1686,6 +1803,10 @@ class TVMArgsSetter { values_[i].v_int64 = static_cast(value); type_codes_[i] = kDLInt; } + TVM_ALWAYS_INLINE void operator()(size_t i, bool value) const { + values_[i].v_bool = value; + type_codes_[i] = kTVMArgBool; + } TVM_ALWAYS_INLINE void operator()(size_t i, uint64_t value) const { values_[i].v_int64 = static_cast(value); ICHECK_LE(value, static_cast(std::numeric_limits::max())); @@ -1951,38 +2072,110 @@ inline T TVMArgs::At(int i) const { template inline void TVMArgsSetter::SetObject(size_t i, T&& value) const { using ContainerType = typename std::remove_reference::type::ContainerType; - if (value.defined()) { - Object* ptr = value.data_.data_; - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + if (!value.defined()) { + type_codes_[i] = kTVMNullptr; + values_[i].v_handle = nullptr; + return; + } + + Object* ptr = value.data_.data_; + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = NDArray::FFIGetHandle(value); type_codes_[i] = kTVMNDArrayHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMModuleHandle; - } else if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { + return; + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { values_[i].v_handle = ptr; type_codes_[i] = kTVMPackedFuncHandle; - } else if (std::is_rvalue_reference::value) { - values_[i].v_handle = const_cast(&(value.data_.data_)); - type_codes_[i] = kTVMObjectRValueRefArg; - } else { - values_[i].v_handle = value.data_.data_; - type_codes_[i] = kTVMObjectHandle; + return; + } + } + + // Like with BoxInt, unwrap any BoxBool instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_bool = static_cast(ptr)->value; + type_codes_[i] = kTVMArgBool; + return; + } + } + + // If a boxed integer is being returned, always unbox it to the + // primitive type. This must be checked at the PackedFunc level to + // ensure that a boxed primitive argument is round-tripped correctly + // when the boxing is no longer required. + // + // For example, consider a PackedFunc with signature `ObjectRef + // func(Array)`, and returns the first element of that + // array. When passing a Python array `[5, 17.5, "hello"]`, the + // items are converted to `[Box(5), Box(17.5), + // String("hello")]` in order to provide an `Array`. + // + // If we had no additional conversions, the caller would receive the + // return value as a `Box(5)`, which would be unexpected and + // require additional unwrapping. We could perform this check + // inside the PackedFunc, but that would require a large amount of + // duplicated checked, and would require explicit handling of + // `TVMRetValue`. Instead, this conversion is checked in the FFI + // return value, to ensure that boxing/unboxing is applied + // consistently. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_int64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgInt; + return; + } + } + + // Like with BoxInt, unwrap any BoxFloat instances. See the BoxInt + // explanation for more detail. + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + values_[i].v_float64 = static_cast(ptr)->value; + type_codes_[i] = kTVMArgFloat; + return; } + } + + // Final fallback, if the ObjectRef has no special cases that must + // be expressed within the TVMRetValue. + if constexpr (std::is_rvalue_reference_v) { + values_[i].v_handle = const_cast(&(value.data_.data_)); + type_codes_[i] = kTVMObjectRValueRefArg; } else { - type_codes_[i] = kTVMNullptr; - values_[i].v_handle = nullptr; + values_[i].v_handle = value.data_.data_; + type_codes_[i] = kTVMObjectHandle; } } +template template -inline bool TVMPODValue_::IsObjectRef() const { +inline bool TVMPODValue_CRTP_::IsObjectRef() const { using ContainerType = typename TObjectRef::ContainerType; // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { @@ -2012,8 +2205,9 @@ inline bool TVMPODValue_::IsObjectRef() const { ObjectTypeChecker::Check(static_cast(value_.v_handle))); } +template template -inline TObjectRef TVMPODValue_::AsObjectRef() const { +inline TObjectRef TVMPODValue_CRTP_::AsObjectRef() const { static_assert(std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; @@ -2023,8 +2217,10 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expect a not null value of " << ContainerType::_type_key; return TObjectRef(ObjectPtr(nullptr)); } - // NOTE: the following code can be optimized by constant folding. - if (std::is_base_of::value) { + + // NOTE: The following code uses "if constexpr" wherever possible to + // minimize the number of runtime checks. + if constexpr (std::is_base_of_v) { // Casting to a sub-class of NDArray TVM_CHECK_TYPE_CODE(type_code_, kTVMNDArrayHandle); ObjectPtr data = @@ -2033,7 +2229,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of Module TVM_CHECK_TYPE_CODE(type_code_, kTVMModuleHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2041,7 +2238,8 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } - if (std::is_base_of::value) { + + if constexpr (std::is_base_of_v) { // Casting to a sub-class of PackedFunc TVM_CHECK_TYPE_CODE(type_code_, kTVMPackedFuncHandle); ObjectPtr data = GetObjectPtr(static_cast(value_.v_handle)); @@ -2049,6 +2247,7 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { << "Expected " << ContainerType::_type_key << " but got " << data->GetTypeKey(); return TObjectRef(data); } + if (type_code_ == kTVMObjectHandle) { // normal object type check. Object* ptr = static_cast(value_.v_handle); @@ -2062,51 +2261,152 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker::TypeName() << ", but got " << checked_type.value(); return TObjectRef(GetObjectPtr(ptr)); - } else if (std::is_base_of::value && - type_code_ == kTVMNDArrayHandle) { - // Casting to a base class that NDArray can sub-class - ObjectPtr data = - NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); - return TObjectRef(data); - } else if (std::is_base_of::value && - type_code_ == kTVMModuleHandle) { - // Casting to a base class that Module can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else if (std::is_base_of::value && - type_code_ == kTVMPackedFuncHandle) { - // Casting to a base class that PackedFunc can sub-class - return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); - } else { - TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); - return TObjectRef(ObjectPtr(nullptr)); } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMNDArrayHandle) { + // Casting to a base class that NDArray can sub-class + ObjectPtr data = + NDArray::FFIDataFromHandle(static_cast(value_.v_handle)); + return TObjectRef(data); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMModuleHandle) { + // Casting to a base class that Module can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMPackedFuncHandle) { + // Casting to a base class that PackedFunc can sub-class + return TObjectRef(GetObjectPtr(static_cast(value_.v_handle))); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgInt) { + return Int(value_.v_int64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgFloat) { + return Float(value_.v_float64); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMArgBool) { + return Bool(value_.v_bool); + } + } + + if constexpr (std::is_base_of_v) { + if (type_code_ == kTVMStr || type_code_ == kTVMBytes) { + // This step is the reason why `AsObjectRef` cannot be provided + // in the base `TVMPODValue_` class. Because `TVMArgValue` and + // `TVMRetValue` have different implementations of `operator + // std::string`, with different interpretations of `kTVMStr` and + // `kTVMBytes`, we must delegate to those implementations. + // + // This could be done with a pure virtual method in + // `TVMPODValue_`, but that would require a vtable lookup during + // FFI conversions, imposing a runtime overhead. + return String(static_cast(this)->operator std::string()); + } + } + + TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle); + return TObjectRef(ObjectPtr(nullptr)); } template inline TVMRetValue& TVMRetValue::operator=(TObjectRef other) { using ContainerType = typename TObjectRef::ContainerType; const Object* ptr = other.get(); - if (ptr != nullptr) { - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(NDArray(std::move(other.data_))); - } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(Module(std::move(other.data_))); - } - if (std::is_base_of::value || - (std::is_base_of::value && - ptr->IsInstance())) { - return operator=(PackedFunc(std::move(other.data_))); + + if (ptr) { + // Check for special cases of ObjectRef that have explicit + // representation within the TVMRetValue structure. + // (e.g. Unboxing of `runtime::Int` into a primitive integer + // with type code kTVMArgInt.) The checks below are written to + // handle three distinct cases. + // + // 1. If TObjectRef is a subclass of TSpecialCase, the special + // case applies, and can be handled without a runtime check. + // No runtime checks should be performed. + // + // 2. If TSpecialCase is a subclass of TObjectRef, the special + // case might apply, and requires a runtime check. + // + // 3. If neither TObjectRef nor TSpecialCase is a subclass of + // the other, then the special case does not apply. No + // runtime checks should be performed. + // + // Use of `if constexpr` ensures that the C++ subclass checks + // are applied when compiling TVM, and runtime overhead are only + // present when they may be applicable. + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(NDArray(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(Module(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || + std::is_base_of_v) { + if (std::is_base_of_v || + ptr->IsInstance()) { + return operator=(PackedFunc(std::move(other.data_))); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + bool value = static_cast(ptr)->value; + return operator=(value); + } } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + int64_t value = static_cast(ptr)->value; + return operator=(value); + } + } + + if constexpr (std::is_base_of_v || std::is_base_of_v) { + if (std::is_base_of_v || ptr->IsInstance()) { + double value = static_cast(ptr)->value; + return operator=(value); + } + } + + // If the object being stored is not one of the special cases, + // it is stored as an ObjectRef. SwitchToObject(kTVMObjectHandle, std::move(other.data_)); + } else { + // No object is present, set to an explicitly null handle. When + // returning to a Python callee, this will be converted to + // `None`. SwitchToPOD(kTVMNullptr); value_.v_handle = nullptr; } + return *this; } @@ -2139,20 +2439,123 @@ inline PackedFunc Module::GetFunction(const String& name, bool query_imports) { // specializations of PackedFuncValueConverter template <> struct PackedFuncValueConverter<::tvm::runtime::String> { - static String From(const TVMArgValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); + template + static String From(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return val.template AsObjectRef(); } else { return tvm::runtime::String(val.operator std::string()); } } +}; - static String From(const TVMRetValue& val) { - if (val.IsObjectRef()) { - return val.AsObjectRef(); - } else { - return tvm::runtime::String(val.operator std::string()); +template +struct PackedFuncValueConverter> { + static Array From(const TVMArgValue& val) { + auto untyped_array = val.AsObjectRef>(); + + // Attempt to convert each item of the array into the desired + // type. If the items do not require a conversion, no copies are + // made. + return untyped_array.Map([](ObjectRef item) { + // Recursively apply any conversions that have been registered + // with TVM's FFI. + // + // For example, a function that accepts `Array` may + // be called from python with argument `[1,2]`. By the time + // `PackedFuncValueConverter::From` is called, the python list + // has been converted to `Array`, with contents + // converted into `runtime::Int`. Converting the `ObjectRef` + // to `TVMArgValue` unboxes the `runtime::Int` back into a + // primitive with type code `kTVMArgInt`. This primitive can + // then be converted to a PrimExpr using + // `PackedFuncValueConverter::From`. + // + // The use of two conversions, first from python `int` to + // `runtime::Int` and then from `runtime::Int` to `PrimExpr`, + // is a result of the split between `libtvm_runtime.so` and + // `libtvm.so`. The FFI must function correctly in both + // cases, and so conversions applied by default in the Python + // FFI implementation may only produce types that are + // available in both libraries. In the C++ FFI implementation + // (i.e. this file), libtvm.so may apply additional + // conversions that are not present in libtvm_runtime.so. + TVMValue value; + int type_code; + TVMArgsSetter setter(&value, &type_code); + setter(0, item); + TVMArgValue arg(value, type_code); + return PackedFuncValueConverter::From(arg); + }); + } + static Array From(const TVMRetValue& val) { + auto untyped_array = val.AsObjectRef>(); + + return untyped_array.Map([](ObjectRef item) { + TVMRetValue item_val; + item_val = std::move(item); + return PackedFuncValueConverter::From(item_val); + }); + } +}; + +template +struct PackedFuncValueConverter> { + static Map From(const TVMArgValue& val) { + auto untyped_map = val.AsObjectRef>(); + + if (ObjectTypeChecker>::Check(untyped_map.get())) { + // Early bail-out for common case where no type conversions are + // required. + return Downcast>(untyped_map); + } + + Map output; + for (const auto& kv : untyped_map) { + T new_key = [&]() { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.first); + TVMArgValue pod_arg(pod_value, type_code); + return PackedFuncValueConverter::From(pod_arg); + }(); + U new_value = [&]() { + TVMValue pod_value; + int type_code; + TVMArgsSetter setter(&pod_value, &type_code); + setter(0, kv.second); + TVMArgValue key_arg(pod_value, type_code); + return PackedFuncValueConverter::From(key_arg); + }(); + output.Set(new_key, new_value); + } + return output; + } + static Map From(const TVMRetValue& val) { + auto untyped_map = val.AsObjectRef>(); + + if (ObjectTypeChecker>::Check(untyped_map.get())) { + // Early bail-out for common case where no type conversions are + // required. + return Downcast>(untyped_map); + } + + Map output; + for (const auto& kv : untyped_map) { + T new_key = [&]() { + TVMRetValue pod; + pod = kv.first; + return PackedFuncValueConverter::From(pod); + }(); + U new_value = [&]() { + TVMRetValue pod; + pod = kv.second; + return PackedFuncValueConverter::From(pod); + }(); + output.Set(new_key, new_value); } + return output; } }; @@ -2181,7 +2584,7 @@ struct PackedFuncValueConverter> { return opt.value(); } - if (auto opt = TryValueConverter(val)) { + if (auto opt = TryValueConverter(val)) { return opt.value(); } @@ -2192,10 +2595,10 @@ struct PackedFuncValueConverter> { << " but got " << ArgTypeCode2Str(val.type_code()); } - template - static Optional TryAsObjectRef(const TVMPODValue_& val) { - if (val.IsObjectRef()) { - return VType(val.AsObjectRef()); + template + static Optional TryAsObjectRef(const PODSubclass& val) { + if (val.template IsObjectRef()) { + return VType(val.template AsObjectRef()); } else if constexpr (sizeof...(VarRest)) { return TryAsObjectRef(val); } else { @@ -2203,15 +2606,15 @@ struct PackedFuncValueConverter> { } } - template + template static Optional TryValueConverter(const PODSubclass& val) { try { return VType(PackedFuncValueConverter::From(val)); - } catch (const InternalError&) { + } catch (const Error&) { } if constexpr (sizeof...(VarRest)) { - return TryValueConverter(val); + return TryValueConverter(val); } else { return NullOpt; } diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index d47ac94e067e..4c1d1fc1f3d2 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -113,7 +113,15 @@ class TargetNode : public Object { "Can only call GetAttr with ObjectRef types."); auto it = attrs.find(attr_key); if (it != attrs.end()) { - return Downcast>((*it).second); + // For backwards compatibility, return through TVMRetValue. + // This triggers any automatic conversions registered with + // PackedFuncValueConverter. Importantly, this allows use of + // `GetAttr` and `GetAttr` for properties that + // are stored internally as `runtime::Box` and + // `runtime::Box`. + TVMRetValue ret; + ret = (*it).second; + return ret; } else { return default_value; } diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 130aea32f844..6b3b9c31a645 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -445,8 +445,8 @@ constexpr const char* kRelayToTIR = "RelayToTIR"; .add_attr_option("model") \ .add_attr_option>("libs") \ .add_attr_option("host") \ - .add_attr_option("from_device") \ - .add_attr_option("target_device_type") + .add_attr_option("from_device") \ + .add_attr_option("target_device_type") } // namespace tvm diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index d9b65dc8745c..28cb022151d2 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -1155,6 +1155,63 @@ inline std::unordered_map as_unordered_map(const Map& dmap) { } // namespace tir } // namespace tvm +namespace tvm { +namespace runtime { + +// Automatic conversion into PrimExpr, when called through the FFI. +// Automatic conversions into IntImm, Integer, and Bool are registered +// in "tvm/ir/expr.h", as they are currently in use outside of TIR. + +template <> +struct PackedFuncValueConverter { + template + static Optional TryFrom(const PODSubclass& val) { + auto type_code = val.type_code(); + bool can_convert = type_code == kTVMDataType || type_code == kTVMBytes || + type_code == kTVMStr || val.template IsObjectRef(); + if (can_convert) { + return tvm::tir::StringImm(PackedFuncValueConverter::From(val)); + } else { + return NullOpt; + } + } + + template + static tvm::tir::StringImm From(const PODSubclass& val) { + if (auto opt = TryFrom(val)) { + return opt.value(); + } else { + return val.template AsObjectRef(); + } + } +}; + +template <> +struct PackedFuncValueConverter { + // Common rule for RetValue and ArgValue. Templated to ensure + // correct delegation to `operator std::string()` for either + // TVMArgValue or TVMRetValue. + template + static PrimExpr From(const PODSubclass& val) { + if (auto opt = val.TryAsBool()) { + // Check against val.TryAsBool directly, to avoid the + // bounds-checking in PackedFuncValueConverter::TryFrom. + return tvm::Bool(opt.value()); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else if (auto opt = PackedFuncValueConverter::TryFrom(val)) { + return opt.value(); + } else { + return PrimExpr::FromObject_(val.template AsObjectRef()); + } + } +}; + +} // namespace runtime +} // namespace tvm + namespace std { template <> struct hash<::tvm::tir::IterVar> : public ::tvm::ObjectPtrHash {}; diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 274ebd0a6558..1d218c6a7c61 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef { * B[vi, vj] = A[vi, vj] * \endcode */ -PrimFunc Specialize(PrimFunc func, const Map& param_map); +PrimFunc Specialize(PrimFunc func, const Map>& param_map); /*! * \brief PrimFunc specific attribute names. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9b23973b6f8f..092bd52d5634 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -224,8 +224,9 @@ class ScheduleNode : public runtime::Object { * \param decision The sampling decision * \return The random variable sampled from candidates */ - virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) = 0; + virtual ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) = 0; /*! * \brief Sample the factors to perfect tile a specific loop * \param loop_rv The loop to be tiled diff --git a/python/tvm/_ffi/_ctypes/object.py b/python/tvm/_ffi/_ctypes/object.py index 520e0e42ebbe..8f674eea2ec6 100644 --- a/python/tvm/_ffi/_ctypes/object.py +++ b/python/tvm/_ffi/_ctypes/object.py @@ -60,14 +60,36 @@ def _return_object(x): tindex = ctypes.c_uint() check_call(_LIB.TVMObjectGetTypeIndex(handle, ctypes.byref(tindex))) cls = OBJECT_TYPE.get(tindex.value, _CLASS_OBJECT) + + # Handle return values that subclass from both TVM objects and + # python native objects (e.g. runtime.String, a subclass of str). if issubclass(cls, PyNativeObject): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) obj.handle = handle return cls.__from_tvm_object__(cls, obj) + # Avoid calling __init__ of cls, instead directly call __new__ # This allows child class to implement their own __init__ obj = cls.__new__(cls) obj.handle = handle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + # + # The `hasattr` check is done on the object's class, not the + # object itself, to avoid edge cases that can result in invalid + # error messages. If a C++ `LOG(FATAL) << nested_obj;` statement + # requires C++ to Python conversions in order to print + # `nested_obj`, then the `AttributeError` used internally by + # `hasattr` may overwrite the text being collected by + # `LOG(FATAL)`. By checking for the method on the class instead + # of the instance, we avoid throwing the `AttributeError`. + # if hasattr(type(obj), "__into_pynative_object__"): + # return obj.__into_pynative_object__() + return obj diff --git a/python/tvm/_ffi/_ctypes/packed_func.py b/python/tvm/_ffi/_ctypes/packed_func.py index 5f3aa04914be..6dab1a5db1f4 100644 --- a/python/tvm/_ffi/_ctypes/packed_func.py +++ b/python/tvm/_ffi/_ctypes/packed_func.py @@ -134,6 +134,11 @@ def _make_tvm_args(args, temp_args): elif isinstance(arg, _nd._TVM_COMPATS): values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) type_codes[i] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + values[i].v_bool = arg + type_codes[i] = ArgTypeCode.BOOL elif isinstance(arg, Integral): values[i].v_int64 = arg type_codes[i] = ArgTypeCode.INT @@ -147,7 +152,7 @@ def _make_tvm_args(args, temp_args): values[i].v_int64 = _device_to_int64(arg) type_codes[i] = ArgTypeCode.DLDEVICE elif isinstance(arg, (bytearray, bytes)): - # from_buffer only taeks in bytearray. + # from_buffer only takes in bytearray. if isinstance(arg, bytes): byte_arr = bytearray(arg) temp_args.append(byte_arr) diff --git a/python/tvm/_ffi/_ctypes/types.py b/python/tvm/_ffi/_ctypes/types.py index 38d3cd72b55d..45f36eafd78a 100644 --- a/python/tvm/_ffi/_ctypes/types.py +++ b/python/tvm/_ffi/_ctypes/types.py @@ -27,6 +27,7 @@ class TVMValue(ctypes.Union): _fields_ = [ ("v_int64", ctypes.c_int64), + ("v_bool", ctypes.c_bool), ("v_float64", ctypes.c_double), ("v_handle", ctypes.c_void_p), ("v_str", ctypes.c_char_p), @@ -94,6 +95,7 @@ def _device_to_int64(dev): RETURN_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, @@ -104,6 +106,7 @@ def _device_to_int64(dev): C_TO_PY_ARG_SWITCH = { ArgTypeCode.INT: lambda x: x.v_int64, + ArgTypeCode.BOOL: lambda x: x.v_bool, ArgTypeCode.FLOAT: lambda x: x.v_float64, ArgTypeCode.HANDLE: _return_handle, ArgTypeCode.NULL: lambda x: None, diff --git a/python/tvm/_ffi/_cython/base.pxi b/python/tvm/_ffi/_cython/base.pxi index 69e1355f7d13..0f7e5fcae6bd 100644 --- a/python/tvm/_ffi/_cython/base.pxi +++ b/python/tvm/_ffi/_cython/base.pxi @@ -16,6 +16,7 @@ # under the License. from ..base import raise_last_ffi_error +from libcpp cimport bool as bool_t from libcpp.vector cimport vector from cpython.version cimport PY_MAJOR_VERSION from cpython cimport pycapsule @@ -38,7 +39,8 @@ cdef enum TVMArgTypeCode: kTVMBytes = 12 kTVMNDArrayHandle = 13 kTVMObjectRefArg = 14 - kTVMExtBegin = 15 + kTVMArgBool = 15 + kTVMExtBegin = 16 cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct DLDataType: @@ -66,6 +68,7 @@ cdef extern from "tvm/runtime/c_runtime_api.h": ctypedef struct TVMValue: int64_t v_int64 + bool_t v_bool double v_float64 void* v_handle const char* v_str diff --git a/python/tvm/_ffi/_cython/object.pxi b/python/tvm/_ffi/_cython/object.pxi index 94a9310d7815..ff38cd3d0ec2 100644 --- a/python/tvm/_ffi/_cython/object.pxi +++ b/python/tvm/_ffi/_cython/object.pxi @@ -60,7 +60,17 @@ cdef inline object make_ret_object(void* chandle): obj = _CLASS_OBJECT.__new__(_CLASS_OBJECT) (obj).chandle = chandle + + # Handle return values that must be converted from the TVM object + # to a python native object. This should be used in cases where + # subclassing the python native object is forbidden. For example, + # `runtime.BoxBool` cannot be a subclass of `bool`, as `bool` does + # not allow any subclasses. + # if hasattr(obj, '__into_pynative_object__'): + # return obj.__into_pynative_object__) + return obj + # return obj.__into_pynative_object__() class PyNativeObject: diff --git a/python/tvm/_ffi/_cython/packed_func.pxi b/python/tvm/_ffi/_cython/packed_func.pxi index 3d1e87bf563d..7977f37d0be5 100644 --- a/python/tvm/_ffi/_cython/packed_func.pxi +++ b/python/tvm/_ffi/_cython/packed_func.pxi @@ -45,7 +45,7 @@ cdef int tvm_callback(TVMValue* args, tcode == kTVMModuleHandle or tcode == kTVMNDArrayHandle or tcode == kTVMObjectRefArg or - tcode > kTVMExtBegin): + tcode >= kTVMExtBegin): CHECK_CALL(TVMCbArgToReturn(&value, &tcode)) if tcode != kTVMDLTensorHandle: @@ -118,6 +118,11 @@ cdef inline int make_arg(object arg, ptr = arg._tvm_handle value[0].v_handle = (ptr) tcode[0] = arg.__class__._tvm_tcode + elif isinstance(arg, bool): + # A python `bool` is a subclass of `int`, so this check + # must occur before `Integral`. + value[0].v_bool = arg + tcode[0] = kTVMArgBool elif isinstance(arg, Integral): value[0].v_int64 = arg tcode[0] = kInt @@ -209,6 +214,8 @@ cdef inline object make_ret(TVMValue value, int tcode): return make_ret_object(value.v_handle) elif tcode == kTVMNullptr: return None + elif tcode == kTVMArgBool: + return value.v_bool elif tcode == kInt: return value.v_int64 elif tcode == kFloat: diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index f148e26f3fcb..03dc18ea6e0b 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -48,7 +48,8 @@ class ArgTypeCode(object): BYTES = 12 NDARRAY_HANDLE = 13 OBJECT_RVALUE_REF_ARG = 14 - EXT_BEGIN = 15 + BOOL = 15 + EXT_BEGIN = 16 class TVMByteArray(ctypes.Structure): diff --git a/python/tvm/driver/tvmc/registry.py b/python/tvm/driver/tvmc/registry.py index c2e74eb1935e..b76202a730a2 100644 --- a/python/tvm/driver/tvmc/registry.py +++ b/python/tvm/driver/tvmc/registry.py @@ -20,11 +20,23 @@ from tvm.driver.tvmc import TVMCException -# We can't tell the type inside an Array but all current options are strings so -# it can default to that. Bool is used alongside Integer but aren't distinguished -# between as both are represented by IntImm -INTERNAL_TO_NATIVE_TYPE = {"runtime.String": str, "IntImm": int, "Array": str} -INTERNAL_TO_HELP = {"runtime.String": " string", "IntImm": "", "Array": " options"} +# We can't tell the type inside an Array but all current options are +# strings so it can default to that. runtime.BoxBool is used to +# distinguish from runtime.BoxInt. +INTERNAL_TO_NATIVE_TYPE = { + "runtime.String": str, + "runtime.BoxBool": bool, + "runtime.BoxFloat": float, + "runtime.BoxInt": int, + "Array": str, +} +INTERNAL_TO_HELP = { + "runtime.String": " string", + "runtime.BoxBool": " bool", + "runtime.BoxInt": " int", + "runtime.BoxFloat": " float", + "Array": " options", +} def _generate_registry_option_args(parser, registry, name): diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 6f0a6dd7d155..6afb383c9f04 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -61,7 +61,7 @@ def get_int_tuple(self, key): ------- value: Tuple of int """ - return tuple(x.value for x in self.__getattr__(key)) + return tuple(x if isinstance(x, int) else x.value for x in self.__getattr__(key)) def get_int(self, key): """Get a python int value of a key diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index c70ac2acc71b..263976fa98ff 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -20,7 +20,7 @@ import tvm._ffi -from ..runtime import Object, Scriptable, const, convert +from ..runtime import Object, Scriptable from . import _ffi_api from .base import Node, Span from .type import Type @@ -184,9 +184,6 @@ class Range(Node, Scriptable): def __init__( self, begin: PrimExpr, end: Optional[PrimExpr] = None, span: Optional[Span] = None ) -> None: - if end is None: - end = convert(begin) - begin = const(0, dtype=end.dtype, span=span) self.__init_handle_by_constructor__(_ffi_api.Range, begin, end, span) @staticmethod diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 6f76452a57b5..51d9a013d8b3 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -28,6 +28,7 @@ from tvm.runtime import Object from tvm.target import Target from tvm.tir import PrimFunc, Schedule +from tvm.script import tir as T from . import _ffi_api from .logging import Logger, get_logger, get_logging_func @@ -47,7 +48,7 @@ def _normalize_mod(mod: Union[PrimFunc, IRModule]) -> IRModule: if isinstance(mod, PrimFunc): if not (mod.attrs and "global_symbol" in mod.attrs): mod = mod.with_attr("global_symbol", "main") - mod = mod.with_attr("tir.noalias", True) + mod = mod.with_attr("tir.noalias", T.bool(True)) mod = IRModule({"main": mod}) if not isinstance(mod, IRModule): raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") diff --git a/python/tvm/relax/op/statistical.py b/python/tvm/relax/op/statistical.py index eb44696871eb..502d058ffdf6 100644 --- a/python/tvm/relax/op/statistical.py +++ b/python/tvm/relax/op/statistical.py @@ -195,7 +195,7 @@ def cumprod( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: Optional[bool] = None, + exclusive: bool = False, ): """Numpy style cumprod op. Return the cumulative product of the elements along a given axis. @@ -213,9 +213,9 @@ def cumprod( Type of the returned array and of the accumulator in which the elements are computed. If dtype is not specified, it defaults to the dtype of data. - exclusive : Optional[bool] - If true will return exclusive sum in which the first element is not - included. + exclusive : bool + If false (default), all elements are included in the product. If + true, the first element is excluded from the product. Returns ------- @@ -247,6 +247,9 @@ def cumprod( cumprod(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 1, 0, 0, 0, 0] """ + if exclusive is None: + exclusive = False + return _ffi_api.cumprod(data, axis, dtype, exclusive) # type: ignore @@ -254,7 +257,7 @@ def cumsum( data: Expr, axis: Optional[int] = None, dtype: Optional[Union[str, DataType]] = None, - exclusive: Optional[bool] = None, + exclusive: bool = False, ): """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along a given axis. @@ -272,9 +275,9 @@ def cumsum( Type of the returned array and of the accumulator in which the elements are summed. If dtype is not specified, it defaults to the dtype of data. - exclusive : Optional[bool] - If true will return exclusive sum in which the first element is not - included. + exclusive : bool + If false (default), all elements are included in the sum. If + true, the first element is excluded from the sum. Returns ------- @@ -306,6 +309,9 @@ def cumsum( cumsum(a, dtype=int32) # dtype should be provided to get the expected results -> [1, 1, 2, 2, 3, 4, 4] """ + if exclusive is None: + exclusive = False + return _ffi_api.cumsum(data, axis, dtype, exclusive) # type: ignore diff --git a/python/tvm/relax/testing/ast_printer.py b/python/tvm/relax/testing/ast_printer.py index 1ed16363b20a..4c670bbe74b2 100644 --- a/python/tvm/relax/testing/ast_printer.py +++ b/python/tvm/relax/testing/ast_printer.py @@ -171,11 +171,19 @@ def visit_call_(self, op: relax.Call) -> str: def display_attrs(attr_key): attr_val = op.attrs[attr_key] - # attrs can be strings but also other types; - # we want to wrap strings in quotes - # (__repr__ would work but it uses single quotes) - attr_str = wrap_quotes(attr_val) if isinstance(attr_val, str) else str(attr_val) - return f"{wrap_quotes(attr_key)}: {attr_str}" + + if isinstance(attr_val, str): + # attrs can be strings but also other types; + # we want to wrap strings in quotes + # (__repr__ would work but it uses single quotes) + attr_val = wrap_quotes(attr_val) + elif isinstance(attr_val, tvm.tir.IntImm): + if attr_val.dtype == "bool": + attr_val = bool(attr_val.value) + else: + attr_val = int(attr_val.value) + + return f"{wrap_quotes(attr_key)}: {attr_val}" fields["attrs"] = self.build_list( map(display_attrs, op.attrs.keys()), diff --git a/python/tvm/relax/training/setup_trainer.py b/python/tvm/relax/training/setup_trainer.py index 71bf8509a63e..aba7ae912c54 100644 --- a/python/tvm/relax/training/setup_trainer.py +++ b/python/tvm/relax/training/setup_trainer.py @@ -139,14 +139,14 @@ def _check_well_formed(self, mod: IRModule): # Check function attrs if not self.PARAM_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.PARAM_NUM_ATTR_KEY], IntImm + mod.attrs[self.PARAM_NUM_ATTR_KEY], (IntImm, int) ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " f"{self.PARAM_NUM_ATTR_KEY}" ) if not self.STATE_NUM_ATTR_KEY in mod.attrs or not isinstance( - mod.attrs[self.STATE_NUM_ATTR_KEY], IntImm + mod.attrs[self.STATE_NUM_ATTR_KEY], (IntImm, int) ): raise ValueError( f"SetupTrainer: The backbone module should has an integer attribute named " diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 9323bc40da69..e1cab4cbd53b 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -97,6 +97,9 @@ def convert_to_expr(value: Any) -> Expr: if isinstance(value, int): return PrimValue(tir.IntImm("int64", value)) + if isinstance(value, float): + return PrimValue(tir.FloatImm("float64", value)) + tvm_value = convert_to_object(value) # Case 1 if isinstance(tvm_value, Expr): # type: ignore diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 97d7cfa93c8d..199193f75939 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -76,7 +76,7 @@ def get_section_begin_coords(split: tvm.relay.Expr) -> List[int]: # 0 is the beginning of the first section. return [0] + list(indices_or_sections) split_axis_len = input_shape[split_axis].value - section_length = split_axis_len // indices_or_sections.value + section_length = split_axis_len // indices_or_sections return list(range(0, split_axis_len, section_length)) def callback( diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 6b9b311c83b5..dca7b995b22d 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Gradient definitions for Relay operators""" +import tvm from tvm.topi.nn.utils import get_pad_tuple from tvm.topi.utils import get_const_tuple from tvm.error import OpError @@ -383,6 +384,8 @@ def concatenate_grad(orig, grad): axis_dims = [ty.shape[orig.attrs.axis] for ty in t.checked_type.fields] splits, cumsum = [], 0 for dim in axis_dims[:-1]: + if isinstance(dim, tvm.tir.IntImm): + dim = dim.value cumsum += dim splits.append(cumsum) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 93df67ff6b99..8bca72655491 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1057,10 +1057,10 @@ def split_shape_func(attrs, inputs, _): return [ _split_shape_func( inputs[0], - convert(i), - convert(indices_or_sections), - convert(param_is_indices), - convert(axis), + i, + indices_or_sections, + param_is_indices, + axis, ) for i in range(num_out) ] diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index dd04d613079b..c4eff3fcc9e0 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -1630,10 +1630,10 @@ def __init__(self, func_body): def convert_indices_or_sections(self, indices_or_sections): # split_v if isinstance(indices_or_sections, tvm.ir.container.Array): - values = [i.value for i in indices_or_sections] + values = [int(i) for i in indices_or_sections] # split else: - values = indices_or_sections.value + values = int(indices_or_sections) return values def is_valid(self): diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index ef1cdb3afdd8..dd9c670e2a37 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -18,6 +18,8 @@ # pylint: disable=import-outside-toplevel """Transform operators.""" +from typing import Optional + from ...tir import expr as _expr from ..expr import Constant, Expr, Tuple, TupleWrapper, const from . import _make @@ -855,13 +857,14 @@ def broadcast_to(data, shape): The resulting tensor. """ if isinstance(shape, Constant): - shape = list(shape.data.numpy()) - if isinstance(shape, Expr): + shape = shape.data.numpy() + shape = [_expr.IntImm(str(shape.dtype), int(value)) for value in shape] + elif isinstance(shape, Expr): return _dyn_make.broadcast_to(data, shape) + if isinstance(shape, int): shape = [shape] - if isinstance(shape, (list, tuple)): - shape = list(shape) + return _make.broadcast_to(data, shape) @@ -1938,9 +1941,8 @@ def stft( return _make.stft(data, n_fft, hop_length, win_length, window, normalized, onesided) -def dft(re_data, im_data, inverse=False): - """ - Computes the discrete Fourier transform of input (calculation along the last axis). +def dft(re_data, im_data, inverse: Optional[bool] = False): + """Computes the discrete Fourier transform of input (calculation along the last axis). This gives frequency components of the signal as they change over time. Parameters @@ -1952,8 +1954,11 @@ def dft(re_data, im_data, inverse=False): N-D tensor, imaginary part of the input signal. If the signal is real, then the values of this tensor are zeros. - inverse : bool + inverse : Optional[bool] + Whether to perform the inverse discrete fourier transform. + Providing None is equivalent to False, and is maintained for + compatibility. Returns ------- @@ -1961,7 +1966,11 @@ def dft(re_data, im_data, inverse=False): The Fourier Transform of the input (Real part). im_output : relay.Expr The Fourier Transform of the input (Imaginary part). + """ + if inverse is None: + inverse = False + return TupleWrapper(_make.dft(re_data, im_data, inverse), 2) diff --git a/python/tvm/relay/transform/fake_quantization_to_integer.py b/python/tvm/relay/transform/fake_quantization_to_integer.py index 7ad838895c9f..6eef6ff3ffae 100644 --- a/python/tvm/relay/transform/fake_quantization_to_integer.py +++ b/python/tvm/relay/transform/fake_quantization_to_integer.py @@ -364,9 +364,8 @@ def split(expr, type_map): arg = expr.args[0] t = type_map[arg] attrs = {**expr.attrs} - if isinstance(attrs["indices_or_sections"], tvm.tir.IntImm): - num_split = attrs["indices_or_sections"].value - attrs["indices_or_sections"] = num_split + if isinstance(attrs["indices_or_sections"], int): + num_split = attrs["indices_or_sections"] else: num_split = len(attrs["indices_or_sections"]) + 1 return [expr, TupleAffineType([t] * num_split)] diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index f182cd9bfd2f..301f0ef66286 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -27,11 +27,11 @@ from .profiling import Report # function exposures -from .object_generic import convert_to_object, convert, const from .ndarray import device, cpu, cuda, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, ext_dev from .module import load_module, enabled, system_lib, load_static_library -from .container import String, ShapeTuple +from .container import String, ShapeTuple # , BoxBool +from .object_generic import convert_to_object, convert, const from .params import ( save_param_dict, load_param_dict, diff --git a/python/tvm/runtime/container.py b/python/tvm/runtime/container.py index 686b4a26c80c..f1a0706a387d 100644 --- a/python/tvm/runtime/container.py +++ b/python/tvm/runtime/container.py @@ -172,3 +172,41 @@ def __eq__(self, other): return False return True + + +# @tvm._ffi.register_object("runtime.BoxBool") +# class BoxBool(Object): +# """A boolean wrapped as a tvm Object + +# Parameters +# ---------- +# value: bool + +# The value to hold +# """ + +# def __init__(self, value: bool): +# # Convert to int to avoid an infinite recursion, because +# # BoxBool may be constructed in _make_tvm_args, and calling +# # the packed func `_ffi_api.BoxBool` internally calls +# # `_make_tvm_args`. +# self.__init_handle_by_constructor__(_ffi_api.BoxBool, int(value)) + +# def __into_pynative_object__(self) -> bool: +# return self.value + +# @property +# def value(self) -> bool: +# """Unwrap the boxed value. + +# This is implemented explicitly rather than using the usual +# PackedFunc handling or AttrVisitor mechanics for two reasons. +# First, because the PackedFunc handling would require ambiguous +# representations between `True`/`1` and `False`/`0`. Second, +# because the boxing/unboxing must be available in +# `libtvm_runtime.so`, and AttrVisitor is only available in +# `libtvm.so`. +# """ +# unboxed_bool = _ffi_api.UnBoxBool(self) +# assert unboxed_bool is not None +# return bool(unboxed_bool) diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index 887c2faaeb2b..20909c53c787 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,65 +38,62 @@ def asobject(self): ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PackedFuncBase, PyNativeObject) -def convert_to_object(value, span=None): +def convert_to_object(value): """Convert a Python value to corresponding object type. + Type conversions performed by this function must *only* produce + types that are supported by `libtvm_runtime.so`. This function + must be usable in environments where only TVM runtime support is + present. Automatic conversions to compile-time representations + (e.g. `tir.IntImm` or `relax.PrimValue`) should not be done as + part of this conversion, as these types are not available in + `libtvm_runtime.so`. + Parameters ---------- value : str The value to be inspected. - span : Optional[Span] - The location of this itervar in the source code. - Returns ------- obj : Object The corresponding object value. + """ + if isinstance(value, ObjectTypes): return value - if isinstance(value, bool): - return const(value, "uint1x1", span=span) - if isinstance(value, Number): - return const(value, span=span) - if isinstance(value, string_types): + elif isinstance(value, (bool, int, float)): + return value + elif isinstance(value, string_types): return _ffi_api.String(value) - if isinstance(value, (list, tuple)): - value = [convert_to_object(x) for x in value] + elif isinstance(value, (list, tuple)): + # The call to _ffi_api.Array will convert its own arguments, + # so we don't need to apply any explicit conversions here. return _ffi_api.Array(*value) - if isinstance(value, dict): - vlist = [] - for item in value.items(): - if ( - not isinstance(item[0], ObjectTypes) - and not isinstance(item[0], string_types) - and not isinstance(item[0], Number) - ): - raise ValueError("key of map must already been a container type") - vlist.append(convert_to_object(item[0])) - vlist.append(convert_to_object(item[1])) + elif isinstance(value, dict): + if any(not isinstance(key, (ObjectTypes, string_types, Number)) for key in value): + raise ValueError("key of map must already been a container type") + + vlist = [kv for item in value.items() for kv in item] return _ffi_api.Map(*vlist) - if isinstance(value, ObjectGeneric): + elif isinstance(value, ObjectGeneric): return value.asobject() - if callable(value): + elif callable(value): return convert_to_tvm_func(value) - if value is None: + elif value is None: return None - - raise ValueError(f"don't know how to convert type {type(value)} to object") + else: + raise TypeError(f"don't know how to convert type {type(value)} to object") -def convert(value, span=None): +def convert(value): """Convert value to TVM object or function. Parameters ---------- value : python value - span : Optional[Span] - The location of this statement in the source code. - Returns ------- tvm_val : Object or Function @@ -107,29 +104,29 @@ def convert(value, span=None): This function is redirected to `convert_to_object` as it is widely used in the codebase. We can choose one to keep and discard the other one later. """ - return convert_to_object(value, span=span) + + return convert_to_object(value) def _scalar_type_inference(value): if hasattr(value, "dtype"): - dtype = str(value.dtype) + return str(value.dtype) elif isinstance(value, bool): - dtype = "bool" + return "bool" elif isinstance(value, float): # We intentionally prefer convert the float to float32 since it's more common in DL. if -3.40282347e38 <= value <= 3.40282347e38: - dtype = "float32" + return "float32" else: - dtype = "float64" + return "float64" elif isinstance(value, int): # We intentionally prefer convert the python int to int32 since it's more common in DL. if -2147483648 <= value <= 2147483647: - dtype = "int32" + return "int32" else: - dtype = "int64" + return "int64" else: raise NotImplementedError(f"Cannot automatically inference the type. value={value}") - return dtype def const(value, dtype=None, span=None): diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index e545bc3a5e53..3107354ac353 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -536,6 +536,8 @@ def visit_return(self: Parser, node: doc.Return) -> None: The doc AST return node. """ value = self.eval_expr(node.value) + if value is None: + self.report_error(node, "Expression to be returned must be a PrimExpr") T.evaluate(tvm.tir.ret(value)) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 462066106a9d..948a0d7665ff 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -96,7 +96,7 @@ def _allocate_tensor(func_id, args): ) shape = args[0] for i in shape: - _internal_assert(isinstance(i, _expr.PrimExpr), "The shape should be an expression") + _internal_assert(isinstance(i, (_expr.PrimExpr, int)), "The shape should be an expression") if n > 1: _internal_assert(isinstance(args[1], str), "The data type should be an str") _internal_assert( @@ -131,9 +131,11 @@ def len(func_id, args): def _cast(func_id, args): _internal_assert( - args.__len__() == 1 and isinstance(args[0], _expr.PrimExpr), - "Only one expression can be cast", + args.__len__() == 1, + f"Casting to {func_id} only supports a single argument", ) + # The FFI can handle any conversion of `args[0]` into PrimExpr, if + # required. return _expr.Cast(func_id, args[0]) @@ -145,9 +147,7 @@ def _cast(func_id, args): def ceil_div(func_id, args): _internal_assert(func_id == "ceil_div", "This function cannot be directly invoked!") _internal_assert(args.__len__() == 2, "2 arguments expected for division!") - _internal_assert(isinstance(args[0], _expr.PrimExpr), "Only expressions can div") - _internal_assert(isinstance(args[1], _expr.PrimExpr), "Only expressions can div") - a, b = args[0], args[1] + a, b = args return (a + b - 1) // b diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 846ef818ea54..bd5a060cd01c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -279,7 +279,7 @@ def visit_Num(self, node): return tvm.runtime.const(node.n, dtype) def visit_NameConstant(self, node): - return tvm.runtime.convert(node.value) + return tvm.tir.const(node.value) def visit_AugAssign(self, node): buf = self.visit(node.target) @@ -376,7 +376,7 @@ def visit_Subscript(self, node): args = [args] arr = self.visit(node.value) - if isinstance(arr, Array): + if isinstance(arr, (Array, list, tuple)): for i in args: if isinstance(i, numbers.Integral): arr = arr[i] diff --git a/python/tvm/te/hybrid/utils.py b/python/tvm/te/hybrid/utils.py index f653b3e83d8b..a515938fa524 100644 --- a/python/tvm/te/hybrid/utils.py +++ b/python/tvm/te/hybrid/utils.py @@ -33,9 +33,9 @@ # pylint: disable=invalid-name -np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) -tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr) -halide_imm_types = (_expr.IntImm, _expr.FloatImm) +np_arg_types = (numpy.ndarray, *numeric_types) +tvm_arg_types = (Tensor, Array, _expr.Var, _expr.ConstExpr, *numeric_types, list, tuple, str) +halide_imm_types = (_expr.IntImm, _expr.FloatImm, *numeric_types) def _internal_assert(cond, err): @@ -91,19 +91,13 @@ def replace(op): def _is_tvm_arg_types(args): """Determine a list of element is either a list of tvm arguments of a list of numpy arguments. If neither is true, raise a value error.""" - if isinstance(args[0], tvm_arg_types): - for elem in args[1:]: - _internal_assert( - isinstance(elem, tvm_arg_types), - f"Expecting a Var, Tensor or ConstExpr instance but {type(elem)} get!", - ) + if all(isinstance(elem, tvm_arg_types) for elem in args): return True - - _internal_assert( - isinstance(args[0], np_arg_types), f"Expect a numpy type but {type(args[0])} get!" - ) - for elem in args[1:]: - _internal_assert( - isinstance(elem, np_arg_types), f"Expect a numpy type but {type(elem)} get!" + elif all(isinstance(elem, np_arg_types) for elem in args): + return False + else: + raise ValueError( + f"Expected arguments to be entirely TVM types, " + f"or entirely numpy types, " + f"but received {[type(elem) for elem in args]}" ) - return False diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index dc2c67849925..64a282dcf755 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -53,7 +53,6 @@ def placeholder(shape, dtype=None, name="placeholder"): tensor: Tensor The created tensor """ - shape = (shape,) if isinstance(shape, tvm.tir.PrimExpr) else shape dtype = "float32" if dtype is None else dtype return _ffi_api.Placeholder(shape, dtype, name) diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index d435e821acf3..930667242e29 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -64,16 +64,7 @@ def __call__(self, *indices): f"Need to provide {ndim} index in tensor but {len(indices)} was provided" ) indices = convert_to_object(indices) - args = [] - for x in indices: - if isinstance(x, _expr.PrimExpr): - args.append(x) - elif isinstance(x, _expr.IterVar): - args.append(x.var) - else: - raise ValueError("The indices must be expression") - - return _expr.ProducerLoad(self, args) + return _expr.ProducerLoad(self, indices) def __getitem__(self, indices): return TensorSlice(self, indices) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index bcfbe6575d52..0c8048d24d8b 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -21,6 +21,7 @@ from .buffer import Buffer, decl_buffer, DataProducer from .data_layout import Layout, BijectiveLayout, bijective_layout, layout +from .expr import convert from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index c78bb9e7ecd0..37976394f831 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -41,6 +41,10 @@ from .buffer import Buffer, DataProducer +def convert(expr) -> PrimExpr: + return _ffi_api.convert(expr) + + def div_ambiguity_error() -> RuntimeError: return RuntimeError( "TVM supports multiple types of integer divisions, " diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 50de995a9145..777d46ec7b0d 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -17,7 +17,7 @@ """Developer API of IR node builder make function.""" import tvm from tvm._ffi.base import string_types -from tvm.runtime import ObjectGeneric, convert, const +from tvm.runtime import ObjectGeneric, const from tvm.ir import container as _container from . import stmt as _stmt @@ -107,7 +107,9 @@ def __getitem__(self, index): def __setitem__(self, index, value): index = self._normalize_index(index) - value = convert(value) + if isinstance(value, (int, bool, float)): + value = tvm.tir.const(value) + value_element = value.dtype.split("x", maxsplit=1)[0] content_element = self._content_type.split("x", maxsplit=1)[0] if value_element != content_element: diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py index 0bc299e403c5..8d9647b60049 100644 --- a/python/tvm/tir/op.py +++ b/python/tvm/tir/op.py @@ -19,13 +19,14 @@ from typing import Any, Optional, Union import tvm._ffi +from tvm import tir from tvm.ir import Array, Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import const, convert +from tvm.runtime import const from . import _ffi_api from .buffer import Buffer -from .expr import Call, CommReducer, IntImm, PrimExprWithOp, StringImm, Var +from .expr import Call, CommReducer, IntImm, PrimExprWithOp, Var def _pack_buffer(buf, span=None): @@ -181,7 +182,7 @@ def call_intrin(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call(dtype, func_name, convert(args), span) + return Call(dtype, func_name, args, span) def call_pure_extern(dtype, func_name, *args, span=None): @@ -206,9 +207,7 @@ def call_pure_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_pure_extern"), convert((StringImm(func_name),) + args), span - ) + return Call(dtype, Op.get("tir.call_pure_extern"), [func_name, *args], span) def call_extern(dtype, func_name, *args, span=None): @@ -233,9 +232,7 @@ def call_extern(dtype, func_name, *args, span=None): call : PrimExpr The call expression. """ - return Call( - dtype, Op.get("tir.call_extern"), convert((StringImm(func_name),) + args), span=span - ) + return Call(dtype, Op.get("tir.call_extern"), [func_name, *args], span=span) def call_llvm_intrin(dtype, name, *args, span=None): @@ -1832,13 +1829,10 @@ def dp4a(vec1, vec2, acc=0): call : PrimExpr The call expression. """ - vec1 = convert(vec1) - vec2 = convert(vec2) - acc = convert(acc) return call_intrin("int32", "tir.dp4a", vec1, vec2, acc) -def ret(val): +def ret(val, span=None): """Create a tir return expression Parameters @@ -1846,14 +1840,16 @@ def ret(val): val : Expr The returned tir expression, whose data type is int, float or void pointer. + span : Optional[Span] + The location of this operator in the source code. + Returns ------- ret : PrimExpr The return expression """ - val = convert(val) - return call_intrin(val.dtype, "tir.ret", val) + return _ffi_api.ret(val, span) def any(*args, span=None): @@ -2038,7 +2034,7 @@ def exp(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp", x) @@ -2055,7 +2051,7 @@ def exp2(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp2", x) @@ -2072,7 +2068,7 @@ def exp10(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.exp10", x) @@ -2089,7 +2085,7 @@ def erf(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.erf", x) @@ -2106,7 +2102,7 @@ def tanh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.tanh", x) @@ -2123,7 +2119,7 @@ def sigmoid(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sigmoid", x) @@ -2140,7 +2136,7 @@ def log(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log", x) @@ -2157,7 +2153,7 @@ def log2(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log2", x) @@ -2174,7 +2170,7 @@ def log10(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log10", x) @@ -2191,7 +2187,7 @@ def log1p(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.log1p", x) @@ -2208,7 +2204,7 @@ def tan(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.tan", x) @@ -2225,7 +2221,7 @@ def cos(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.cos", x) @@ -2242,7 +2238,7 @@ def cosh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.cosh", x) @@ -2259,7 +2255,7 @@ def acos(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.acos", x) @@ -2276,7 +2272,7 @@ def acosh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.acosh", x) @@ -2293,7 +2289,7 @@ def sin(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sin", x) @@ -2310,7 +2306,7 @@ def sinh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sinh", x) @@ -2327,7 +2323,7 @@ def asin(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.asin", x) @@ -2344,7 +2340,7 @@ def asinh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.asinh", x) @@ -2361,7 +2357,7 @@ def atan(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.atan", x) @@ -2378,7 +2374,7 @@ def atanh(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.atanh", x) @@ -2398,8 +2394,8 @@ def atan2(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.atan2", x1, x2) @@ -2416,7 +2412,7 @@ def sqrt(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.sqrt", x) @@ -2433,7 +2429,7 @@ def rsqrt(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.rsqrt", x) @@ -2679,8 +2675,8 @@ def nextafter(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.nextafter", x1, x2) # type: ignore @@ -2700,8 +2696,8 @@ def hypot(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.hypot", x1, x2) # type: ignore @@ -2721,8 +2717,8 @@ def copysign(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.copysign", x1, x2) # type: ignore @@ -2742,8 +2738,8 @@ def ldexp(x1, x2): y : PrimExpr The result. """ - x1 = convert(x1) - x2 = convert(x2) + x1 = tir.convert(x1) + x2 = tir.convert(x2) return call_intrin(x1.dtype, "tir.ldexp", x1, x2) # type: ignore @@ -2862,7 +2858,7 @@ def power(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + return _ffi_api._OpPow(x, y, span) # type: ignore def pow(x, y, span=None): @@ -2884,7 +2880,7 @@ def pow(x, y, span=None): z : PrimExpr The result. """ - return _ffi_api._OpPow(convert(x), convert(y), span) # type: ignore + return _ffi_api._OpPow(x, y, span) # type: ignore def popcount(x): @@ -2900,7 +2896,7 @@ def popcount(x): y : PrimExpr The result. """ - x = convert(x) + x = tir.convert(x) return call_intrin(x.dtype, "tir.popcount", x) @@ -3032,8 +3028,8 @@ def fmod(x, y): z : PrimExpr The result. """ - x = convert(x) - y = convert(y) + x = tir.convert(x) + y = tir.convert(y) return call_intrin(x.dtype, "tir.fmod", x, y) @@ -3067,7 +3063,7 @@ def if_then_else(cond, t, f, span=None): Unlike Select, if_then_else cannot be vectorized if some lanes in the vector have different conditions. """ - return _ffi_api._OpIfThenElse(convert(cond), convert(t), convert(f), span) # type: ignore + return _ffi_api._OpIfThenElse(cond, t, f, span) # type: ignore def div(a, b, span=None): @@ -3314,34 +3310,23 @@ def _reduce_directly(*args): def _make_reduce(expr, axis, where=None, init=None): code = fcombine.__code__ assert fcombine.__code__.co_argcount == 2 - expr = convert(expr) + expr = tir.convert(expr) if init is not None: - init = convert(init) + init = tir.convert(init) if isinstance(expr, Array): size = len(expr) - larr = [] - rarr = [] + lhs = [] + rhs = [] dtypes = [] for i in range(size): dtype = expr[i].dtype dtypes.append(dtype) lname = code.co_varnames[0] + "_" + str(i) - larr.append(Var(lname, dtype)) + lhs.append(Var(lname, dtype)) rname = code.co_varnames[1] + "_" + str(i) - rarr.append(Var(rname, dtype)) - if init is not None: - init = convert(init) - assert isinstance(init, Array) - assert len(init) == size - for init_i in range(size): - init_i = convert(init_i) - assert isinstance( - init_i, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm) - ) - else: - init = convert([]) - lhs = convert(larr) - rhs = convert(rarr) + rhs.append(Var(rname, dtype)) + if init is None: + init = [] result = fcombine(lhs, rhs) id_elem = fidentity(*dtypes) else: @@ -3352,22 +3337,18 @@ def _make_reduce(expr, axis, where=None, init=None): rvar = Var(code.co_varnames[1], dtype) result = [fcombine(lvar, rvar)] id_elem = [fidentity(dtype)] - lhs = convert([lvar]) - rhs = convert([rvar]) - expr = convert([expr]) + lhs = [lvar] + rhs = [rvar] + expr = [expr] if init is not None: - assert isinstance(init, (tvm.tir.ProducerLoad, tvm.tir.IntImm, tvm.tir.FloatImm)) - init = convert([init]) - result = convert(result) - id_elem = convert(id_elem) + init = [init] combiner = CommReducer(lhs, rhs, result, id_elem) - axis = convert(axis if isinstance(axis, (list, tuple)) else [axis]) + if not isinstance(axis, (list, tuple, tvm.ir.Array)): + axis = [axis] if where is None: - where = convert(True) + where = tir.convert(True) if init is None: - outputs = tuple( - tvm.tir.Reduce(combiner, expr, axis, where, i, convert([])) for i in range(size) - ) + outputs = tuple(tvm.tir.Reduce(combiner, expr, axis, where, i, []) for i in range(size)) else: outputs = tuple( tvm.tir.Reduce(combiner, expr, axis, where, i, init) for i in range(size) diff --git a/python/tvm/tir/schedule/trace.py b/python/tvm/tir/schedule/trace.py index cb8d5ce9973e..85377560f1fc 100644 --- a/python/tvm/tir/schedule/trace.py +++ b/python/tvm/tir/schedule/trace.py @@ -39,17 +39,20 @@ def _json_from_tvm(obj): if obj is None: return None - if isinstance(obj, Array): + elif isinstance(obj, (bool, int, float, str)): + return obj + elif isinstance(obj, Array): return [_json_from_tvm(i) for i in obj] - if isinstance(obj, Map): + elif isinstance(obj, Map): return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()} - if isinstance(obj, String): + elif isinstance(obj, String): return str(obj) - if isinstance(obj, (IntImm, FloatImm)): + elif isinstance(obj, (IntImm, FloatImm)): return obj.value - if isinstance(obj, IndexMap): + elif isinstance(obj, IndexMap): return save_json(obj) - raise TypeError("Not supported type: " + str(type(obj))) + else: + raise TypeError("Not supported type: " + str(type(obj))) @_register_object("tir.Trace") diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py index bf6a9c75516f..cc1a28b9dee0 100644 --- a/python/tvm/topi/arm_cpu/conv2d_gemm.py +++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py @@ -468,7 +468,7 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out): C = out.op.input_tensors[0] A = C.op.input_tensors[0] in_type = A.dtype - use_scalable_vectors = out.op.attrs["use_scalable_vectors"].value + use_scalable_vectors = bool(out.op.attrs["use_scalable_vectors"]) tile_M, tile_K = arm_utils.get_tiling_A(False, in_type) tile_N, _ = arm_utils.get_tiling_B_transformed(False, in_type, use_scalable_vectors) diff --git a/python/tvm/topi/cuda/batch_matmul.py b/python/tvm/topi/cuda/batch_matmul.py index 83b000a4b9bb..0a7acfa50444 100644 --- a/python/tvm/topi/cuda/batch_matmul.py +++ b/python/tvm/topi/cuda/batch_matmul.py @@ -295,15 +295,11 @@ def batch_matmul_int8( # pad for _dp4a vectorize pad_x = te.compute( (XB, M, nK), - lambda b, i, j: tvm.te.if_then_else( - j >= XK, tvm.runtime.convert(0).astype(x.dtype), x[b, i, j] - ), + lambda b, i, j: tvm.te.if_then_else(j >= XK, tvm.tir.const(0, x.dtype), x[b, i, j]), ) pad_y = te.compute( (YB, N, nK), - lambda b, i, j: tvm.te.if_then_else( - j >= YK, tvm.runtime.convert(0).astype(y.dtype), y[b, i, j] - ), + lambda b, i, j: tvm.te.if_then_else(j >= YK, tvm.tir.const(0, y.dtype), y[b, i, j]), ) out = te.compute( diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs index 8d59c2a035a9..b98d9c102baa 100644 --- a/rust/tvm-rt/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -48,7 +48,7 @@ pub struct ModuleNode { crate::external! { #[name("runtime.RuntimeEnabled")] - fn runtime_enabled(target: CString) -> i32; + fn runtime_enabled(target: CString) -> bool; #[name("runtime.ModuleLoadFromFile")] fn load_from_file(file_name: CString, format: CString) -> Module; @@ -121,8 +121,7 @@ impl Module { /// Checks if a target device is enabled for a module. pub fn enabled(&self, target: &str) -> bool { let target = CString::new(target).unwrap(); - let enabled = runtime_enabled(target).unwrap(); - enabled != 0 + runtime_enabled(target).unwrap() } /// Returns the underlying module handle. diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index a74cbe318e2d..2c1f7db6adb0 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -73,6 +73,7 @@ macro_rules! TVMPODValue { Int(i64), UInt(i64), Float(f64), + Bool(bool), Null, DataType(DLDataType), String(*mut c_char), @@ -95,6 +96,7 @@ macro_rules! TVMPODValue { DLDataTypeCode_kDLInt => Int($value.v_int64), DLDataTypeCode_kDLUInt => UInt($value.v_int64), DLDataTypeCode_kDLFloat => Float($value.v_float64), + TVMArgTypeCode_kTVMArgBool => Bool($value.v_bool), TVMArgTypeCode_kTVMNullptr => Null, TVMArgTypeCode_kTVMDataType => DataType($value.v_type), TVMArgTypeCode_kDLDevice => Device($value.v_device), @@ -117,6 +119,7 @@ macro_rules! TVMPODValue { Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt), UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt), Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat), + Bool(val) => (TVMValue { v_bool: *val }, TVMArgTypeCode_kTVMArgBool), Null => (TVMValue{ v_int64: 0 },TVMArgTypeCode_kTVMNullptr), DataType(val) => (TVMValue { v_type: *val }, TVMArgTypeCode_kTVMDataType), Device(val) => (TVMValue { v_device: val.clone() }, TVMArgTypeCode_kDLDevice), @@ -263,6 +266,7 @@ macro_rules! impl_pod_value { impl_pod_value!(Int, i64, [i8, i16, i32, i64, isize]); impl_pod_value!(UInt, i64, [u8, u16, u32, u64, usize]); impl_pod_value!(Float, f64, [f32, f64]); +impl_pod_value!(Bool, bool, [bool]); impl_pod_value!(DataType, DLDataType, [DLDataType]); impl_pod_value!(Device, DLDevice, [DLDevice]); @@ -380,37 +384,6 @@ impl TryFrom for std::ffi::CString { } } -// Implementations for bool. - -impl<'a> From<&bool> for ArgValue<'a> { - fn from(s: &bool) -> Self { - (*s as i64).into() - } -} - -impl From for RetValue { - fn from(s: bool) -> Self { - (s as i64).into() - } -} - -impl TryFrom for bool { - type Error = ValueDowncastError; - - fn try_from(val: RetValue) -> Result { - try_downcast!(val -> bool, - |RetValue::Int(val)| { !(val == 0) }) - } -} - -impl<'a> TryFrom> for bool { - type Error = ValueDowncastError; - - fn try_from(val: ArgValue<'a>) -> Result { - try_downcast!(val -> bool, |ArgValue::Int(val)| { !(val == 0) }) - } -} - impl From<()> for RetValue { fn from(_: ()) -> Self { RetValue::Null diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index e03d4302c89f..82e439cddbc2 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -554,9 +554,19 @@ class FlopEstimator : public ExprFunctor { if (auto pop = op.as()) { if (pop->attrs.count("FLOP")) { // Use user-provided FLOP - auto pint = pop->attrs["FLOP"].as(); - ICHECK(pint != nullptr); - ret += pint->value; + ObjectRef annotation = pop->attrs["FLOP"]; + auto value = [&]() -> int64_t { + if (auto runtime_int = annotation.as()) { + return runtime_int->value; + } else if (auto int_imm = annotation.as()) { + return int_imm->value; + } else { + LOG(FATAL) << "FLOP annotation must be an integer, " + << "but was an object of type " << annotation->GetTypeKey(); + } + }(); + + ret += value; } else { // Estimate by parsing the compute body double num_element = AxisLengthProd(pop->axis); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 862e593c9dd3..0bf6da255d2a 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -482,7 +482,8 @@ std::vector> RuleCustomSketch::Apply(const SketchPolicyNod std::vector> ret; for (const auto& item : apply_ret) { CHECK_EQ(item.size(), 2); - auto next = item[1].as(); + auto next = item[1].as(); + ICHECK(next); ret.emplace_back(Downcast(item[0]), next->value); } return ret; diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 76fb77dd9527..cc6b0ab23756 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -101,7 +101,7 @@ inline int OperationToStage(const te::Operation& op, const State& state) { /*! \brief Get an integer from a tvm str Map. */ inline int GetIntParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pint = attr_dict[key].as(); + auto pint = attr_dict[key].as(); ICHECK(pint != nullptr); return pint->value; } @@ -109,7 +109,7 @@ inline int GetIntParam(const Map& attr_dict, const std::strin /*! \brief Get a double from a tvm str Map. */ inline double GetDoubleParam(const Map& attr_dict, const std::string& key) { ICHECK_GT(attr_dict.count(key), 0) << "Cannot find key: \"" << key << "\" in " << attr_dict; - auto pdouble = attr_dict[key].as(); + auto pdouble = attr_dict[key].as(); ICHECK(pdouble != nullptr); return pdouble->value; } @@ -120,10 +120,12 @@ inline std::string GetStringParam(const Map& attr_dict, const const auto& target = attr_dict[key]; if (auto pstr = target.as()) { return pstr->value; + } else if (auto pstr = target.as()) { + return pstr->data; + } else { + LOG(FATAL) << "Could not convert object " << target << " of type " << target->GetTypeKey() + << " to string"; } - auto pstr = target.as(); - ICHECK(pstr != nullptr); - return pstr->data; } /*! \brief Get a iterator name set from a tvm str Map. */ diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index 289c1b79fd66..708fb56c9851 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -100,8 +100,17 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { const ObjectRef& value = doc->value; if (!value.defined()) { output_ << "\"\""; + } else if (const auto* runtime_int = value.as()) { + output_ << runtime_int->value; } else if (const auto* int_imm = value.as()) { output_ << int_imm->value; + } else if (const auto* runtime_float = value.as()) { + output_.precision(config_.float_precision); + if (std::isinf(runtime_float->value) || std::isnan(runtime_float->value)) { + output_ << '"' << runtime_float->value << '"'; + } else { + output_ << runtime_float->value; + } } else if (const auto* float_imm = value.as()) { output_.precision(config_.float_precision); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { diff --git a/src/contrib/msc/core/printer/prototxt_printer.cc b/src/contrib/msc/core/printer/prototxt_printer.cc index 7e96c657a711..99be910bd70a 100644 --- a/src/contrib/msc/core/printer/prototxt_printer.cc +++ b/src/contrib/msc/core/printer/prototxt_printer.cc @@ -33,6 +33,10 @@ namespace msc { LiteralDoc PrototxtPrinter::ToLiteralDoc(const ObjectRef& obj) { if (obj.as()) { return LiteralDoc::Str(Downcast(obj), NullOpt); + } else if (auto ptr = obj.as()) { + return LiteralDoc::Int(ptr->value, NullOpt); + } else if (auto ptr = obj.as()) { + return LiteralDoc::Float(ptr->value, NullOpt); } else if (obj.as()) { return LiteralDoc::Int(Downcast(obj)->value, NullOpt); } else if (obj.as()) { diff --git a/src/contrib/msc/core/utils.cc b/src/contrib/msc/core/utils.cc index f58f95ae53b0..5fcbe924ae1c 100644 --- a/src/contrib/msc/core/utils.cc +++ b/src/contrib/msc/core/utils.cc @@ -263,6 +263,10 @@ const String StringUtils::ToString(const runtime::ObjectRef& obj) { obj_string = ""; } else if (obj.as()) { obj_string = Downcast(obj); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); + } else if (const auto* n = obj.as()) { + obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { obj_string = std::to_string(n->value); } else if (const auto* n = obj.as()) { diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 105ac063e0ea..1e576bc91002 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -171,9 +171,10 @@ Array CreatePassList(bool disable_loop_partition) { // phase passes is of the form // [[phase_number, pass], [phase_number, pass]... ] for (Array phase_pass : add_lower_pass) { - const IntImmNode* phase_num = phase_pass[0].as(); + auto phase_num = phase_pass[0].as(); ICHECK(phase_num) - << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer"; + << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer, " + << "but instead received " << phase_pass[0] << " with type " << phase_pass[0]->GetTypeKey(); int phase_num_val = phase_num->value; CHECK_GE(phase_num_val, 0); diff --git a/src/ir/attrs.cc b/src/ir/attrs.cc index f197ac4416fa..08e7ffc5bf59 100644 --- a/src/ir/attrs.cc +++ b/src/ir/attrs.cc @@ -31,6 +31,91 @@ void DictAttrsNode::VisitAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) { v->Visit("__dict__", &dict); } +namespace { + +/* \brief Normalize attributes from runtime types to Relax IR types + * + * While conversion from `tvm::runtime` types to compile-time IR + * types usually occurs as part of FFI conversions, the attributes + * are not converted, as they are stored in a `Map`. While this is required to allow attribute values to + * contain `ObjectRef` instances that are not IR expressions, the + * conversion should still be applied when possible. + * + * \param obj The IR attribute value to be normalized + * + * \return The normalized attribute value + */ +ObjectRef NormalizeAttr(ObjectRef obj) { + if (auto dict_attrs = obj.as()) { + auto new_dict = Downcast>(NormalizeAttr(dict_attrs->dict)); + if (new_dict.same_as(dict_attrs->dict)) { + return obj; + } else { + return DictAttrs(new_dict); + } + } else if (auto runtime_bool = obj.as()) { + return Bool(runtime_bool->value); + } else if (auto runtime_int = obj.as()) { + return Integer(runtime_int->value); + } else if (auto opt_array = obj.as>()) { + return opt_array.value().Map([](const ObjectRef& inner) { return NormalizeAttr(inner); }); + } else if (auto opt_map = obj.as>()) { + auto map = opt_map.value(); + + Map updates; + for (const auto& [key, inner] : map) { + auto new_inner = NormalizeAttr(inner); + if (!new_inner.same_as(inner)) { + updates.Set(key, new_inner); + } + } + for (const auto& [key, new_inner] : updates) { + map.Set(key, new_inner); + } + + return map; + + } else { + return obj; + } +} +} // namespace + +DictAttrs WithAttrs(DictAttrs attrs, Map new_attrs) { + if (new_attrs.empty()) { + return attrs; + } + + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + + for (const auto& [key, value] : new_attrs) { + attr_dict.Set(key, NormalizeAttr(value)); + } + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + +DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value) { + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + attr_dict.Set(key, NormalizeAttr(value)); + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + +DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key) { + auto* write_ptr = attrs.CopyOnWrite(); + Map attr_dict = std::move(write_ptr->dict); + attr_dict.erase(key); + + write_ptr->dict = std::move(attr_dict); + return attrs; +} + void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) { for (int i = 0; i < args.size(); i += 2) { std::string key = args[i]; @@ -43,11 +128,15 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un dict.Set(key, val.operator PrimExpr()); } } + + dict = Downcast>(NormalizeAttr(dict)); } Array DictAttrsNode::ListFieldInfo() const { return {}; } DictAttrs::DictAttrs(Map dict) { + dict = Downcast>(NormalizeAttr(dict)); + ObjectPtr n = make_object(); n->dict = std::move(dict); data_ = std::move(n); diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 596805f74b24..ded046eafc5d 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -47,6 +47,12 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) { if (auto opt = ref.as()) { return tir::StringImm(opt.value()); } + if (auto opt = ref.as()) { + return Bool(opt.value()); + } + if (auto opt = ref.as()) { + return Integer(opt.value()); + } if (const auto* buffer_region = ref.as()) { Array indices; indices.reserve(buffer_region->region.size()); @@ -155,9 +161,14 @@ Range Range::FromMinExtent(PrimExpr min, PrimExpr extent, Span span) { TVM_REGISTER_GLOBAL("ir.Range_from_min_extent").set_body_typed(Range::FromMinExtent); -TVM_REGISTER_GLOBAL("ir.Range").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = Range(args[0], args[1], args[2]); -}); +TVM_REGISTER_GLOBAL("ir.Range") + .set_body_typed([](PrimExpr begin, Optional end, Span span) -> Range { + if (end.defined()) { + return Range(begin, end.value(), span); + } else { + return Range(IntImm(begin->dtype, 0), begin, span); + } + }); TVM_REGISTER_NODE_TYPE(RangeNode); diff --git a/src/ir/transform.cc b/src/ir/transform.cc index dc67822411c5..f0b879acbc03 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -107,43 +107,42 @@ bool PassContext::PassEnabled(const PassInfo& info) const { class PassConfigManager { public: - void Register(std::string key, uint32_t value_type_index) { + void Register(std::string key, uint32_t value_type_index, + std::function legalization) { ICHECK_EQ(key2vtype_.count(key), 0U); ValueTypeInfo info; info.type_index = value_type_index; info.type_key = runtime::Object::TypeIndex2Key(value_type_index); + info.legalization = legalization; key2vtype_[key] = info; } // Trying to validate and legalize a config. void Legalize(Map* config) { std::vector> update; - auto* reflection = ReflectionVTable::Global(); - - for (auto kv : *config) { - auto it = key2vtype_.find(kv.first); + for (auto [key, obj] : *config) { + auto it = key2vtype_.find(key); if (it == key2vtype_.end()) { std::ostringstream os; - os << "AttributeError: Invalid config option \'" << kv.first << "\' candidates are:"; + os << "AttributeError: Invalid config option \'" << key << "\' candidates are:"; int counter = 0; - for (const auto& kv : key2vtype_) { + for (const auto& [key, obj] : key2vtype_) { os << ' '; if (counter++ != 0) os << ','; - os << kv.first; + os << key; } LOG(FATAL) << os.str(); } const auto& info = it->second; - ICHECK(kv.second.defined()) << "AttributeError: " << kv.first << " is None"; - if (kv.second->IsInstance::ContainerType>()) { - ObjectRef converted = - reflection->CreateObject(info.type_key, Downcast>(kv.second)); - update.emplace_back(kv.first, converted); - } else { - if (!runtime::ObjectInternal::DerivedFrom(kv.second.get(), info.type_index)) { - LOG(FATAL) << "AttributeError: expect config " << kv.first << " to have type " - << info.type_key << " but get " << kv.second->GetTypeKey(); - } + + ICHECK(obj.defined()) << "AttributeError: " << key << " is None"; + + ICHECK(info.legalization) << "AttributeError: " + << "Config option \'" << key + << "\' was defined without a legalization function."; + auto legalized = info.legalization(obj); + if (!legalized.same_as(obj)) { + update.emplace_back(key, legalized); } } for (auto&& kv : update) { @@ -170,13 +169,15 @@ class PassConfigManager { struct ValueTypeInfo { std::string type_key; uint32_t type_index; + std::function legalization; }; std::unordered_map key2vtype_; }; -void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index) { - PassConfigManager::Global()->Register(key, value_type_index); +void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_index, + std::function legalization) { + PassConfigManager::Global()->Register(key, value_type_index, legalization); } Map> PassContext::ListConfigs() { diff --git a/src/meta_schedule/database/database_utils.cc b/src/meta_schedule/database/database_utils.cc index 416753871244..ce025540e496 100644 --- a/src/meta_schedule/database/database_utils.cc +++ b/src/meta_schedule/database/database_utils.cc @@ -39,8 +39,14 @@ void JSONDumps(ObjectRef json_obj, std::ostringstream& os) { } else { os << int_imm->value; } + } else if (const auto* runtime_bool = json_obj.as()) { + os << (runtime_bool->value ? "true" : "false"); + } else if (const auto* runtime_int = json_obj.as()) { + os << runtime_int->value; } else if (const auto* float_imm = json_obj.as()) { os << std::setprecision(20) << float_imm->value; + } else if (const auto* runtime_float = json_obj.as()) { + os << std::setprecision(20) << runtime_float->value; } else if (const auto* str = json_obj.as()) { os << '"' << support::StrEscape(str->data, str->size) << '"'; } else if (const auto* array = json_obj.as()) { @@ -165,7 +171,7 @@ class JSONTokenizer { std::string to_parse(st, cur_); if (!is_float) { try { - *token = Token{TokenType::kInteger, IntImm(DataType::Int(64), std::stoll(to_parse))}; + *token = Token{TokenType::kInteger, runtime::Int(std::stoll(to_parse))}; } catch (const std::invalid_argument& e) { LOG(WARNING) << "ValueError: Invalid argument to std::stoll: " << to_parse << ". Details: " << e.what() << ". Switching to std::stod now."; @@ -178,7 +184,7 @@ class JSONTokenizer { } if (is_float) { try { - *token = Token{TokenType::kFloat, FloatImm(DataType::Float(64), std::stod(to_parse))}; + *token = Token{TokenType::kFloat, runtime::Float(std::stod(to_parse))}; } catch (const std::invalid_argument& e) { LOG(INFO) << "ValueError: Invalid argument to std::stod: " << to_parse << ". Details: " << e.what(); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 53f680f0a666..63af4a684567 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -192,7 +192,9 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, try { const ArrayNode* arr = json_obj.as(); ICHECK_EQ(arr->size(), 2); - workload = workloads[Downcast(arr->at(0)).IntValue()]; + int64_t workload_index = Downcast(arr->at(0)); + ICHECK(workload_index >= 0 && static_cast(workload_index) < workloads.size()); + workload = workloads[workload_index]; records[task_id] = TuningRecord::FromJSON(arr->at(1), workload); } catch (std::runtime_error& e) { LOG(FATAL) << "ValueError: Unable to parse TuningRecord, on line " << (task_id + 1) diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index f5d89a85092b..5b3e6d251d56 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -137,7 +137,7 @@ std::vector MutateThreadBindingNode::FindCan ICHECK(sample_it != sample_insts.end()); const InstructionNode* sample_inst = sample_it->second; - int decision = Downcast(trace->decisions[GetRef(sample_inst)])->value; + int decision = Downcast(trace->decisions[GetRef(sample_inst)]); std::vector probs = support::AsVector(Downcast>(sample_inst->attrs[1])); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index ea4e81c16f0c..a78b829e34ab 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -129,13 +129,13 @@ void FindSampleVectorize(const Trace& trace, std::vector* inst, ICHECK_EQ(inst->outputs.size(), 1); if (annotated.count(inst->outputs[0].get())) { ICHECK_EQ(inst->attrs.size(), 2); - std::vector probs = - support::AsVector(Downcast>(inst->attrs[1])); + std::vector probs = support::AsVector( + Downcast>(inst->attrs[1])); if (probs.size() == 1) { // Skip mutating the sampling instructions who have only single candidate. continue; } - const auto* d = TVM_TYPE_AS(decision, IntImmNode); + const auto* d = TVM_TYPE_AS(decision, runtime::Int::ContainerType); instructions.push_back(inst); decisions.push_back(d->value); } diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 7bbf00343af3..36dc57d80e66 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -114,9 +114,9 @@ bool FindUnrollDecision(const Trace& trace, TRandState* rand_state, ICHECK_EQ(sample_inst->attrs.size(), 2); candidate->inst = GetRef(sample_inst); candidate->decision = - Downcast(trace->decisions[GetRef(sample_inst)])->value; - candidate->probs = - support::AsVector(Downcast>(sample_inst->attrs[1])); + Downcast(trace->decisions[GetRef(sample_inst)])->value; + candidate->probs = support::AsVector( + Downcast>(sample_inst->attrs[1])); return true; } diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index b651b1f401cb..110cae96cb53 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -34,11 +34,11 @@ using namespace tvm::tir; std::function MakeFactorSampler(Schedule sch, Array thread_extents) { return [sch = std::move(sch), thread_extents = std::move(thread_extents)](int64_t max_extent) -> ExprRV { - Array extents; + Array extents; extents.reserve(thread_extents.size()); for (const Integer extent : thread_extents) { if (extent->value <= max_extent) { - extents.push_back(extent); + extents.push_back(runtime::Int(extent->value)); } } int n = extents.size(); @@ -48,7 +48,7 @@ std::function MakeFactorSampler(Schedule sch, Array th if (n == 1) { return Integer(extents[0]); } - Array probs(n, FloatImm(DataType::Float(64), 1.0 / n)); + Array probs(n, runtime::Float(1.0 / n)); return sch->SampleCategorical(extents, probs); }; } diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index e8d821636fd3..4a304cefa6bb 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -73,7 +73,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 3. Try block fusion. int n_candidate = static_cast(thread_extents.size()); - Array probs(n_candidate, FloatImm(DataType::Float(64), 1.0 / n_candidate)); + Array probs(n_candidate, 1.0 / n_candidate); tir::ExprRV thread_extent = tmp_sch->SampleCategorical(thread_extents, probs); if (fusible) { ICHECK(target_block.defined()); @@ -267,7 +267,7 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief The number of threads per warp */ int warp_size; /*! \brief Candidates of thread axis extent (values are required to be positive). */ - Array thread_extents; + Array thread_extents; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("max_threads_per_block", &max_threads_per_block); @@ -279,8 +279,8 @@ class CrossThreadReductionNode : public ScheduleRuleNode { TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); }; -ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { - for (const Integer& extent : thread_extents) { +ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { + for (const auto& extent : thread_extents) { CHECK(extent->value > 0) << "ValueError: The candidates of thread extent must be positive"; } ObjectPtr n = make_object(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index bcaf4343e256..2979e4229bdd 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -383,9 +383,8 @@ void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch, if (!valid_vector_lens.empty()) { int n = valid_vector_lens.size(); double prob = 1.0 / n; - tir::ExprRV vector_load_len = - (*sch)->SampleCategorical(support::AsArray(valid_vector_lens), - Array(n, FloatImm(DataType::Float(64), prob))); + tir::ExprRV vector_load_len = (*sch)->SampleCategorical( + support::AsArray(valid_vector_lens), Array(n, prob)); (*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len); } } diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 045aa85b73ad..8ea2c2d1c6c3 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -68,7 +68,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { if (!unroll_max_steps.empty() && !tir::CheckSpatialPrimFunc(sch, root_rv)) { int n = unroll_max_steps.size(); double prob = 1.0 / n; - Array probs(n, FloatImm(DataType::Float(64), prob)); + Array probs(n, runtime::Float(prob)); PrimExpr max_step = sch->SampleCategorical(unroll_max_steps, probs); if (unroll_explicit) { sch->Annotate(root_rv, tir::attr::meta_schedule_unroll_explicit, max_step); @@ -102,7 +102,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { * \brief The options of the maximum number of unroll steps to be done. * Use an empty array to disable unroll. */ - Array unroll_max_steps; + Array unroll_max_steps; /*! \brief Whether to explicitly unroll the loop, or just add an "unroll" pragma. */ bool unroll_explicit; /*! \brief The number of maximum available jobs in CPU. */ @@ -122,7 +122,7 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, int max_vectorize_extent, - Array unroll_max_steps, + Array unroll_max_steps, bool unroll_explicit) { ObjectPtr n = make_object(); n->max_jobs_per_core = max_jobs_per_core; diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 3be264332461..83f5d073cb32 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -79,7 +79,7 @@ Array ScheduleRule::DefaultLLVM() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -126,7 +126,7 @@ Array ScheduleRule::DefaultX86(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/64, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation(), }; @@ -158,11 +158,11 @@ Array ScheduleRule::DefaultCUDA() { /*require_ordered=*/false, /*disallow_op=*/Array{}), ScheduleRule::CrossThreadReduction( - /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), + /*thread_extents=*/Array{4, 8, 16, 32, 64, 128, 256, 512}), ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/-1, /*max_vectorize_extent=*/-1, - /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, + /*unroll_max_steps=*/Array{0, 16, 64, 512, 1024}, /*unroll_explicit=*/true), ScheduleRule::AutoBind( /*max_threadblocks=*/256, @@ -297,7 +297,7 @@ Array ScheduleRule::DefaultHexagon() { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/16, /*max_vectorize_extent=*/128, - /*unroll_max_steps=*/Array{0, 16, 64, 512}, + /*unroll_max_steps=*/Array{0, 16, 64, 512}, /*unroll_explicit=*/true), }; } @@ -410,7 +410,7 @@ Array ScheduleRule::DefaultARM(const String& type) { ScheduleRule::ParallelizeVectorizeUnroll( /*max_jobs_per_core=*/8, /*max_vectorize_extent=*/32, - /*unroll_max_steps=*/Array{0, 8, 32, 256}, + /*unroll_max_steps=*/Array{0, 8, 32, 256}, /*unroll_explicit=*/true), ScheduleRule::RandomComputeLocation()); } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ceb0356cbcfe..28c45ea7455d 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -424,13 +424,22 @@ inline Array AsFloatArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - if (const auto* int_imm = elem.as()) { - results.push_back(FloatImm(DataType::Float(32), int_imm->value)); - } else if (const auto* float_imm = elem.as()) { - results.push_back(FloatImm(DataType::Float(32), float_imm->value)); - } else { - LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " << elem->GetTypeKey(); - } + auto float_value = [&]() -> double { + if (const auto* int_imm = elem.as()) { + return int_imm->value; + } else if (const auto* runtime_int = elem.as()) { + return runtime_int->value; + } else if (const auto* float_imm = elem.as()) { + return float_imm->value; + } else if (const auto* runtime_float = elem.as()) { + return runtime_float->value; + } else { + LOG(FATAL) << "TypeError: Expect an array of float or int, but gets: " + << elem->GetTypeKey(); + } + }(); + + results.push_back(FloatImm(DataType::Float(32), float_value)); } return results; } @@ -446,11 +455,16 @@ inline Array AsIntArray(const ObjectRef& obj) { Array results; results.reserve(arr->size()); for (const ObjectRef& elem : *arr) { - if (const auto* int_imm = elem.as()) { - results.push_back(Integer(int_imm->value)); - } else { - LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); - } + auto int_value = [&]() -> int64_t { + if (const auto* int_imm = elem.as()) { + return int_imm->value; + } else if (const auto* runtime_int = elem.as()) { + return runtime_int->value; + } else { + LOG(FATAL) << "TypeError: Expect an array of integers, but gets: " << elem->GetTypeKey(); + } + }(); + results.push_back(Integer(int_value)); } return results; } diff --git a/src/node/boxed_primitive.cc b/src/node/boxed_primitive.cc new file mode 100644 index 000000000000..86596fb5ce29 --- /dev/null +++ b/src/node/boxed_primitive.cc @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file node/boxed_primitive.cc + * + * \brief Reflection utilities for runtime-supported classes + * + * The fundamental support for boxing and unboxing of primitives + * during FFI calls is implemented in runtime/boxed_primitive.cc. In + * addition, boxed primitives may be registered with compile-time + * utilities (e.g. reflection, JSON import/export) that can provide + * additional functionality and improved debugging ability. However, + * neither these compile-time utilities nor any registration of + * `Box` into the compile-time utilities should be included as + * part of `libtvm_runtime.so`. + * + * This file contains the registration of the `libtvm_runtime.so` + * class `Box` for utilities that are contained in `libtvm.so`. + */ +#include +#include +#include +#include + +namespace tvm { +namespace runtime_ext { + +using runtime::Box; +using runtime::BoxNode; + +/* \brief Compile-time extension trait for runtime types + * + * Extends the use of boxed primitive during TVM's compilation step. + * + * Most TVM classes define these functions as part of the class + * definition. However, the boxed primitives must be usable at + * runtime, and so the class definition may only refer to types that + * are present in `libtvm_runtime.so`. + */ +template +struct BoxNodeCompileTimeTraits { + static constexpr const std::nullptr_t VisitAttrs = nullptr; + + static void SHashReduce(const BoxNode* node, SHashReducer hash_reduce) { + hash_reduce(node->value); + } + + static bool SEqualReduce(const BoxNode* lhs, const BoxNode* rhs, + SEqualReducer equal) { + return equal(lhs->value, rhs->value); + } +}; + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + int64_t value = std::atoll(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + int64_t value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + if (blob == "true") { + return make_object>(true); + } else if (blob == "false") { + return make_object>(false); + } else { + LOG(FATAL) << "Invalid string '" << blob << "' for boolean"; + } + }) + .set_repr_bytes([](const Object* n) -> std::string { + bool value = GetRef(n).as>().value()->value; + if (value) { + return "true"; + } else { + return "false"; + } + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << (box->value ? "true" : "false") << ")"; + }); + +TVM_REGISTER_REFLECTION_VTABLE(BoxNode, BoxNodeCompileTimeTraits) + .set_creator([](const std::string& blob) -> ObjectPtr { + double value = std::atof(blob.c_str()); + return make_object>(value); + }) + .set_repr_bytes([](const Object* n) -> std::string { + double value = GetRef(n).as>().value()->value; + std::stringstream ss; + ss << value; + return ss.str(); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch>([](const ObjectRef& node, ReprPrinter* p) { + auto box = Downcast>(node); + p->stream << box->GetTypeKey() << "(" << box->value << ")"; + }); + +} // namespace runtime_ext + +} // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 6e7d82ee4a59..b8918b4ea48c 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -57,7 +57,7 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->binding_names.push_back(Downcast(v)); } if (auto v = config_dict.Get("show_meta")) { - n->show_meta = Downcast(v)->value; + n->show_meta = Downcast(v)->value; } if (auto v = config_dict.Get("ir_prefix")) { n->ir_prefix = Downcast(v); @@ -81,16 +81,16 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); } if (auto v = config_dict.Get("verbose_expr")) { - n->verbose_expr = Downcast(v)->value; + n->verbose_expr = Downcast(v)->value; } if (auto v = config_dict.Get("indent_spaces")) { - n->indent_spaces = Downcast(v)->value; + n->indent_spaces = Downcast(v)->value; } if (auto v = config_dict.Get("print_line_numbers")) { - n->print_line_numbers = Downcast(v)->value; + n->print_line_numbers = Downcast(v)->value; } if (auto v = config_dict.Get("num_context_lines")) { - n->num_context_lines = Downcast(v)->value; + n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast>>(v).value_or(Array()); @@ -107,13 +107,13 @@ PrinterConfig::PrinterConfig(Map config_dict) { Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { - n->syntax_sugar = Downcast(v)->value; + n->syntax_sugar = Downcast(v)->value; } if (auto v = config_dict.Get("show_object_address")) { - n->show_object_address = Downcast(v)->value; + n->show_object_address = Downcast(v)->value; } if (auto v = config_dict.Get("show_all_struct_info")) { - n->show_all_struct_info = Downcast(v)->value; + n->show_all_struct_info = Downcast(v)->value; } // Checking prefixes if they are valid Python identifiers. diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 379a75f6109b..614669a412d0 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -65,6 +65,22 @@ bool ReflectionVTable::SEqualReduce(const Object* self, const Object* other, return fsequal_reduce_[tindex](self, other, equal); } +namespace { +ObjectPath GetAttrPath(const ObjectRef& obj, const void* attr_address, const ObjectPath& path) { + if (obj->IsInstance() || + obj->IsInstance() || + obj->IsInstance()) { + // Special case for containers that contain boxed primitives. The + // "value" attribute containing the boxed value should not be part + // of the reported mismatched path. + return path; + } else { + Optional attr_key = GetAttrKeyByAddress(obj.get(), attr_address); + return path->Attr(attr_key); + } +} +} // namespace + struct SEqualReducer::PathTracingData { ObjectPathPair current_paths; ObjectRef lhs_object; @@ -72,10 +88,9 @@ struct SEqualReducer::PathTracingData { Optional* first_mismatch; ObjectPathPair GetPathsForAttrs(const ObjectRef& lhs, const ObjectRef& rhs) const { - Optional lhs_attr_key = GetAttrKeyByAddress(lhs_object.get(), &lhs); - Optional rhs_attr_key = GetAttrKeyByAddress(rhs_object.get(), &rhs); - return ObjectPathPair(current_paths->lhs_path->Attr(lhs_attr_key), - current_paths->rhs_path->Attr(rhs_attr_key)); + ObjectPath lhs_attr_path = GetAttrPath(lhs_object, &lhs, current_paths->lhs_path); + ObjectPath rhs_attr_path = GetAttrPath(rhs_object, &rhs, current_paths->rhs_path); + return ObjectPathPair(lhs_attr_path, rhs_attr_path); } }; @@ -98,13 +113,12 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) { /* static */ void SEqualReducer::GetPathsFromAttrAddressesAndStoreMismatch( const void* lhs_address, const void* rhs_address, const PathTracingData* tracing_data) { if (tracing_data != nullptr && !tracing_data->first_mismatch->defined()) { - Optional lhs_attr_key = - GetAttrKeyByAddress(tracing_data->lhs_object.get(), lhs_address); - Optional rhs_attr_key = - GetAttrKeyByAddress(tracing_data->rhs_object.get(), rhs_address); - *tracing_data->first_mismatch = - ObjectPathPair(tracing_data->current_paths->lhs_path->Attr(lhs_attr_key), - tracing_data->current_paths->rhs_path->Attr(rhs_attr_key)); + ObjectPath lhs_attr_path = + GetAttrPath(tracing_data->lhs_object, lhs_address, tracing_data->current_paths->lhs_path); + ObjectPath rhs_attr_path = + GetAttrPath(tracing_data->rhs_object, rhs_address, tracing_data->current_paths->rhs_path); + + *tracing_data->first_mismatch = ObjectPathPair(lhs_attr_path, rhs_attr_path); } } @@ -200,7 +214,6 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, } // Slow path: tracing object paths for better error reporting - ObjectPathPair new_paths = paths == nullptr ? tracing_data_->GetPathsForAttrs(lhs, rhs) : *paths; if (handler_->SEqualReduce(lhs, rhs, map_free_vars, new_paths)) { diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 334e6e5c9a62..1c795594629e 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -45,6 +45,7 @@ using namespace relax; using namespace tvm::runtime; using namespace tvm::runtime::relax_vm; +namespace { // Helper function to get the function name of the registered packed function implementation of // relax operator. FCallPacked GetPackedFuncName(const Call& call) { @@ -57,6 +58,7 @@ FCallPacked GetPackedFuncName(const Call& call) { } return {}; } +} // namespace /*! * \brief A class to generate VM executable for Relax functions. diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index dd34bc63bb31..5e6a1c3f8442 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -44,6 +44,21 @@ namespace relax_vm { using vm::VMFuncInfo; +namespace { +// Helper function to get the function name of the registered packed function implementation of +// relax operator. +FCallPacked GetPackedFuncName(const Call& call) { + static auto op_map = Op::GetAttrMap("FCallPacked"); + if (call->op.as()) { + Op op = Downcast(call->op); + if (op_map.count(op)) { + return op_map[op]; + } + } + return {}; +} +} // namespace + /*! * \brief A class to generate VMTIR for Relax functions. * @@ -232,7 +247,14 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { } int64_t dst_reg = HasVoidStructInfo(call) ? -1 : NewRegister(); if (call->op.as()) { - if (call_node->op == call_builtin_with_ctx_op_) { + // special case generate for the intrinsics whose attribute fields + // cannot be represented by args in the CallNode + FCallPacked name = GetPackedFuncName(call); + if (name.size()) { + // If the operator has a registered packed function implementation, emit call to that packed + // function. + EmitCallPacked(name, VisitArray(call->args), dst_reg); + } else if (call_node->op == call_builtin_with_ctx_op_) { EmitCallBuiltinWithCtx(call, dst_reg); } else if (call_node->op == alloc_storage_op_) { EmitAllocStorage(call, dst_reg); @@ -260,10 +282,8 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { size_t merge_register = NewRegister(); PrimExpr cond_value = this->VisitExpr(op->cond).value(); - // turn ndarray cond value into scalar. - cond_value = tir::Cast(DataType::Bool(), - tir::Call(DataType::Int(32), tir::builtin::tvm_call_packed(), - {tir::StringImm("vm.builtin.read_if_cond"), cond_value})); + cond_value = tir::Call(DataType::Bool(), tir::builtin::tvm_call_packed(), + {tir::StringImm("vm.builtin.read_if_cond"), cond_value}); tir::Stmt true_branch = WithNewScope([&]() { PrimExpr true_value = this->VisitExpr(op->true_branch).value(); diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index fd6fea6e703c..7aca1470aee4 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -36,7 +36,7 @@ namespace relax { TVM_REGISTER_NODE_TYPE(InitAttrs); /* relax.full */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { +Expr full(Variant> shape, Expr fill_value, DataType dtype) { Expr shape_in_expr{nullptr}; if (const auto* expr = shape.as()) { shape_in_expr = GetRef(expr); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 989eaa12fdbf..6e7c8255238a 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -39,7 +39,7 @@ namespace relax { * If dtype is not given, it will by default use the dtype of fill_value. * \return The result tensor. */ -Expr full(ObjectRef shape, Expr fill_value, DataType dtype); +Expr full(Variant> shape, Expr fill_value, DataType dtype); /*! * \brief Construct a tensor such that diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 07c90756bf90..2b1c6eafb652 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -654,7 +654,7 @@ TVM_REGISTER_OP("relax.permute_dims") .set_attr("FPurity", Bool(true)); /* relax.reshape */ -Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { +Expr ConvertNewShapeToExpr(const Expr& data, const Variant>& shape) { const ArrayNode* array; // Treat shape expressions as constant arrays to handle special values. if (const auto* e = shape.as()) { @@ -747,7 +747,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { return ShapeExpr(array_ref); } -Expr reshape(Expr x, ObjectRef shape) { +Expr reshape(Expr x, Variant> shape) { Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); static const Op& op = Op::Get("relax.reshape"); return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); @@ -812,7 +812,7 @@ TVM_REGISTER_OP("relax.reshape") /* relax.split */ TVM_REGISTER_NODE_TYPE(SplitAttrs); -Expr split(Expr x, ObjectRef indices_or_sections, int axis) { +Expr split(Expr x, Variant> indices_or_sections, int axis) { ObjectPtr attrs = make_object(); if (const auto* indices = indices_or_sections.as()) { for (int i = 0; i < static_cast(indices->size()); ++i) { diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h index 32aa10776894..68622f1359e0 100644 --- a/src/relax/op/tensor/manipulate.h +++ b/src/relax/op/tensor/manipulate.h @@ -90,7 +90,7 @@ Expr permute_dims(Expr x, Optional> axes); * It is required to be either an Array of PrimExpr, or a Shape in Relax * \return The reshaped result. */ -Expr reshape(Expr x, ObjectRef shape); +Expr reshape(Expr x, Variant> shape); /*! * \brief Split input tensor along axis by sections or indices. @@ -105,7 +105,7 @@ Expr reshape(Expr x, ObjectRef shape); * \param axis The axis over which to split. * \return The computed result. */ -Expr split(Expr x, ObjectRef indices_or_sections, int axis); +Expr split(Expr x, Variant> indices_or_sections, int axis); /*! * \brief Squeeze axes in the array. diff --git a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc index 61b6c9ce897f..345e2d0e60da 100644 --- a/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc +++ b/src/relay/backend/contrib/cmsisnn/compiler_attrs.cc @@ -40,7 +40,7 @@ Target CreateTarget(const tvm::transform::PassContext& ctx) { String mcpu = cfg.value()->mcpu; Array mattr = {cfg.value()->mattr}; - Bool debug_last_error = cfg.value()->debug_last_error; + runtime::Bool debug_last_error = cfg.value()->debug_last_error->value; Target cmsis_nn_target(TargetJSON{ {"kind", String("cmsis-nn")}, diff --git a/src/relay/backend/contrib/cmsisnn/target.cc b/src/relay/backend/contrib/cmsisnn/target.cc index 10125bf814ad..00581a089a4a 100644 --- a/src/relay/backend/contrib/cmsisnn/target.cc +++ b/src/relay/backend/contrib/cmsisnn/target.cc @@ -37,7 +37,7 @@ using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc>("mattr") .add_attr_option("mcpu") - .add_attr_option("debug_last_error") + .add_attr_option("debug_last_error") .set_attr(tvm::attr::kRelayToTIR, RelayToTIR()) .set_attr("TIRToRuntime", TIRToRuntime) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); diff --git a/src/relay/backend/contrib/cutlass/target.cc b/src/relay/backend/contrib/cutlass/target.cc index 50c8b84a9069..ea040f6ff56a 100644 --- a/src/relay/backend/contrib/cutlass/target.cc +++ b/src/relay/backend/contrib/cutlass/target.cc @@ -39,32 +39,32 @@ namespace cutlass { * src/relay/backend/contrib/cutlass/codegen.cc */ TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) .set_attr("RelayToTIR", CompileForCutlass()) // An integer specifying the compute capability. For example, 75 for Turing and // 80 or 86 for Ampere. - .add_attr_option("sm", Integer(80)) + .add_attr_option("sm", runtime::Int(80)) // Whether to use slower but very accurate (compared to tf32) 3xtf32 mode for // fp32 inputs on tensorcore. - .add_attr_option("use_3xtf32", Bool(true)) + .add_attr_option("use_3xtf32", runtime::Bool(true)) // Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in // parallel across split-K blocks, and a separate global reduction kernel is launched to // accumulate partial reductions. The profiler will pick the best split-k factor from the // given candidate list. Note that the larger split-K factor requires a larger workspace. // Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d // kinds, split_k_slices is ignored. - .add_attr_option>("split_k_slices", Array({1})) + .add_attr_option>("split_k_slices", Array{runtime::Int(1)}) // When True, profile all kernel variants with smaller alignments than the largest possible. - .add_attr_option("profile_all_alignments", Bool(false)) + .add_attr_option("profile_all_alignments", runtime::Bool(false)) // Whether to profile all candidate kernels, or stop profiling after the first applicable kernel // is found. - .add_attr_option("find_first_valid", Bool(false)) + .add_attr_option("find_first_valid", runtime::Bool(false)) // Whether to compile profiler executables for different kernels in parallel. - .add_attr_option("use_multiprocessing", Bool(false)) + .add_attr_option("use_multiprocessing", runtime::Bool(false)) // Number of threads to use during compilation, or -1 to use number of cpus. - .add_attr_option("threads", Integer(-1)) + .add_attr_option("threads", runtime::Int(-1)) // Whether to replace sigmoid with tanh. - .add_attr_option("use_fast_math", Bool(false)) + .add_attr_option("use_fast_math", runtime::Bool(false)) // A temporary directory where intermediate compiled artifacts will be stored. .add_attr_option("tmp_dir", String("./tmp")); diff --git a/src/relay/backend/contrib/ethosn/ethosn_api.cc b/src/relay/backend/contrib/ethosn/ethosn_api.cc index a3f3e6e1eb6e..0f539d96e919 100644 --- a/src/relay/backend/contrib/ethosn/ethosn_api.cc +++ b/src/relay/backend/contrib/ethosn/ethosn_api.cc @@ -687,14 +687,14 @@ EthosnError EthosnAPI::Split(const Expr& expr, SplitParams* params) { sl::TensorInfo(input_tensor_shape, input_data_type, params->input_info.m_DataFormat, params->input_info.m_QuantizationInfo); params->split_info.m_Axis = attrs->axis; - if (attrs->indices_or_sections->IsInstance()) { - auto sections = Downcast(attrs->indices_or_sections)->value; + if (const auto* sections_ptr = attrs->indices_or_sections.as()) { + auto sections = sections_ptr->value; int size = input_tensor_shape[attrs->axis] / sections; for (int i = 0; i < sections; i++) { params->split_info.m_Sizes.push_back(size); } } else { - auto indices = Downcast>(attrs->indices_or_sections); + auto indices = Downcast>(attrs->indices_or_sections); int last_index = 0; for (const auto& i : indices) { params->split_info.m_Sizes.push_back(i->value - last_index); diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index 54d0595c4634..300372838416 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -307,8 +307,7 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { Array compile_artifacts; for (const auto& kv : mod->functions) { const tir::PrimFunc& prim_func = Downcast(kv.second); - Optional> params = - prim_func->GetAttr>("ethos-u.constants"); + auto params = prim_func->GetAttr>("ethos-u.constants"); ICHECK(params) << "microNPU params should be present"; auto primfunc_to_artifact_pf = tvm::runtime::Registry::Get("relay.ext.ethos-u.primfunc_to_artifact"); diff --git a/src/relay/backend/contrib/ethosu/preprocess.cc b/src/relay/backend/contrib/ethosu/preprocess.cc index 23a873b2d392..d87447f863e2 100644 --- a/src/relay/backend/contrib/ethosu/preprocess.cc +++ b/src/relay/backend/contrib/ethosu/preprocess.cc @@ -97,7 +97,7 @@ class ExternalFuncIOHandler : public ExprRewriter { Expr CreateSplitReshapedTensors(const Expr& input, const Array& original_args) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; Array rets; int total_size = 0; @@ -132,7 +132,7 @@ class ExternalFuncIOHandler : public ExprRewriter { if (func->params.size() > 1) { Array> shapes; Array flatten_tensor_sizes; - Array split_indices; + Array split_indices; auto func_name = gv->name_hint; int total_size = 0; diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index b45987f6be33..de9c81a2706e 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -38,6 +38,6 @@ TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU) .set_attr(attr::kRelayToTIR, relay::contrib::example_target_hooks::RelayToTIR()) .set_attr("TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime) - .add_attr_option("example_attribute", Integer(0)); + .add_attr_option("example_attribute", Integer(0)); } // namespace tvm diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index f4babad50a3e..1dd5e3a4d772 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -177,12 +177,12 @@ class CollectFromCompositeFunctionBody : public ExprVisitor { std::vector indices_or_sections; std::vector mode; std::vector axis = {std::to_string(split_attr->axis)}; - if (const auto* sections = split_attr->indices_or_sections.as()) { + if (const auto* sections = split_attr->indices_or_sections.as()) { mode.emplace_back("sections"); indices_or_sections.emplace_back(std::to_string(sections->value)); } else { mode.emplace_back("indices"); - auto indices = Downcast>(split_attr->indices_or_sections); + auto indices = Downcast>(split_attr->indices_or_sections); for (const auto& i : indices) { indices_or_sections.emplace_back(std::to_string(i->value)); } diff --git a/src/relay/backend/contrib/tensorrt/target.cc b/src/relay/backend/contrib/tensorrt/target.cc index 0277787a8c12..a62dc25e329c 100644 --- a/src/relay/backend/contrib/tensorrt/target.cc +++ b/src/relay/backend/contrib/tensorrt/target.cc @@ -38,30 +38,30 @@ namespace tensorrt { * - Runtime: src/runtime/contrib/tensorrt/... */ TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)) .set_attr("RelayToTIR", CompileForTensorRT()) // A array of three integers given the major, minor, and patch numbers for the supported // TensorRT compiler version. If empty will be auto-detected from linked library. Default empty. - .add_attr_option>("tensorrt_version", Array()) + .add_attr_option>("tensorrt_version", Array()) // If true, the first tensor dimension for most operators is allowed to be Any and // TensorRT will assume it represents a batch dimension only known at inference time. // Fewer Relay operators are supported in implicit batch mode. Default true. - .add_attr_option("use_implicit_batch", Bool(true)) + .add_attr_option("use_implicit_batch", runtime::Bool(true)) // If true, excludes sub-graphs which do not have multiply-accumulate operations, even though // TensorRT supports them. ad. This is a simple heuristic to optimize the partitioning between // TensorRT and TVM. Not required if using Collage for partitioning. Defalut false. - .add_attr_option("remove_no_mac_subgraphs", Bool(false)) + .add_attr_option("remove_no_mac_subgraphs", runtime::Bool(false)) // How many bytes of workspace size to allow each subgraph to use for TensorRT engine creation. // Default 1G. - .add_attr_option("max_workspace_size", Integer(1 << 30)) + .add_attr_option("max_workspace_size", runtime::Int(1 << 30)) // If true, allows TensorRT to automatically convert float32 operations to float16. Must also be // enabled if any float16 operations are in the model. Note that TensorRT may still choose a // higher-precision kernel if it results in overall lower runtime, or if no low-precision // implementation exists. Default false. - .add_attr_option("use_fp16", Bool(false)) + .add_attr_option("use_fp16", runtime::Bool(false)) // If true, allows TensorRT to automatically convert float32 operations to uint8 // (aka quantized). Default false. - .add_attr_option("use_uint8", Bool(false)); + .add_attr_option("use_uint8", runtime::Bool(false)); } // namespace tensorrt } // namespace contrib diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc index 244f243749c1..0499c0bba198 100644 --- a/src/relay/backend/contrib/uma/targets.cc +++ b/src/relay/backend/contrib/uma/targets.cc @@ -58,7 +58,7 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") .add_attr_option("model") .add_attr_option>("libs") .add_attr_option("host") - .add_attr_option("from_device") + .add_attr_option("from_device") .set_attr( attr::kRelayToTIR, relay::contrib::uma::RelayToTIR(target_name)) .set_attr("TIRToRuntime", relay::contrib::uma::TIRToRuntime); @@ -75,8 +75,9 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") } if (default_value->IsInstance()) { target_kind.add_attr_option(option_name, Downcast(default_value)); - } else if (default_value->IsInstance()) { - target_kind.add_attr_option(option_name, Downcast(default_value)); + } else if (default_value->IsInstance()) { + target_kind.add_attr_option(option_name, + Downcast(default_value)); } else { LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. " << "Given attribute option type: " << attr_option.second->GetTypeKey(); diff --git a/src/relay/backend/executor.cc b/src/relay/backend/executor.cc index 1d6caecb87ba..66feac4699e6 100644 --- a/src/relay/backend/executor.cc +++ b/src/relay/backend/executor.cc @@ -89,13 +89,13 @@ ExecutorRegEntry& ExecutorRegEntry::RegisterOrGet(const String& name) { /********** Register Executors and options **********/ TVM_REGISTER_EXECUTOR("aot") - .add_attr_option("link-params", Bool(true)) - .add_attr_option("unpacked-api") + .add_attr_option("link-params", runtime::Bool(true)) + .add_attr_option("unpacked-api") .add_attr_option("interface-api") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constant-byte-alignment"); + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constant-byte-alignment"); -TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", Bool(false)); +TVM_REGISTER_EXECUTOR("graph").add_attr_option("link-params", runtime::Bool(false)); /********** Registry **********/ diff --git a/src/relay/backend/runtime.cc b/src/relay/backend/runtime.cc index 923c9b2d5f65..0534298ea44d 100644 --- a/src/relay/backend/runtime.cc +++ b/src/relay/backend/runtime.cc @@ -88,9 +88,9 @@ RuntimeRegEntry& RuntimeRegEntry::RegisterOrGet(const String& name) { /********** Register Runtimes and options **********/ -TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCrt).add_attr_option("system-lib"); -TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); +TVM_REGISTER_RUNTIME(kTvmRuntimeCpp).add_attr_option("system-lib"); /********** Registry **********/ diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 0c0ff7290115..3e86e1c8eaf9 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -73,6 +73,42 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp } bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + // Unwrapping arrays may find user-provided FFI types in the + // attributes (e.g. Defining pad_value as ((0,0), (0,0)) will result + // in runtime::Int. These need to be converted to compile-time IR + // types when encountered. + if (lhs->IsInstance() || + lhs->IsInstance() || + lhs->IsInstance()) { + TVMRetValue lhs_convert; + lhs_convert = lhs; + PrimExpr lhs_expr = lhs_convert; + return MatchRetValue(lhs_expr, rhs); + } + + // StructuralEqual doesn't check for conversions between FFI types + // and IR types, but the pattern-matcher should. Therefore, + // explicitly recurse into the array. + if (auto opt_lhs_array = lhs.as>()) { + if (Optional> opt_rhs_array = rhs) { + Array lhs_array = opt_lhs_array.value(); + Array rhs_array = opt_rhs_array.value(); + if (lhs_array.size() != rhs_array.size()) { + return false; + } + for (size_t i = 0; i < lhs_array.size(); i++) { + TVMRetValue rhs_item; + rhs_item = rhs_array[i]; + if (!MatchRetValue(lhs_array[i], rhs_item)) { + return false; + } + } + return true; + } else { + return false; + } + } + switch (rhs.type_code()) { case kDLInt: if (auto* val = lhs.as()) { diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h index 50d8531c7dd0..222aba4bd25b 100644 --- a/src/relay/op/make_op.h +++ b/src/relay/op/make_op.h @@ -79,7 +79,7 @@ Expr MakeReshape(Expr data, Array newshape, bool allowzero = false); Expr MakeReshapeLike(Expr lhs, Expr rhs, int lhs_begin, Integer lhs_end, int rhs_begin, Integer rhs_end); -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis); +Expr MakeSplit(Expr data, Variant> indices_or_sections, int axis); Expr MakeSqueeze(Expr data, Array axis); diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index fde6daa4d851..96f833d80505 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2984,10 +2984,10 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, Layout ret = Layout::Undef(); size_t size = 0; - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { size = sections->value; } else { - size = Downcast>(param->indices_or_sections).size() + 1; + size = Downcast>(param->indices_or_sections).size() + 1; } // If new_in_layouts are defined, this code tries to modify the layout. @@ -2998,13 +2998,12 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs, param->axis = new_index; int factor = new_in_layouts[0].FactorOf(sp_dim); if (factor > 1) { - if (!param->indices_or_sections.as()) { - auto ios = Downcast>(param->indices_or_sections); - Array new_ios; + if (!param->indices_or_sections.as()) { + auto ios = Downcast>(param->indices_or_sections); + Array new_ios; for (const auto& v : ios) { - const IntImmNode* vint = v.as(); - new_ios.push_back(vint->value / factor); - if (vint->value % factor) { + new_ios.push_back(runtime::Int(v->value / factor)); + if (v->value % factor) { divisible = false; } } @@ -3041,7 +3040,7 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, ICHECK_LT(axis, data->shape.size()) << "axis should be within the input dimension range."; ICHECK_GE(axis, 0) << "axis should be within the input dimension range."; - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { if (!data->shape[axis].as()) { ICHECK(reporter->Assert(indexmod(data->shape[axis], sections->value) == tir::make_zero(DataType::Int(64)))) @@ -3061,8 +3060,8 @@ bool SplitRel(const Array& types, int num_inputs, const Attrs& attrs, reporter->Assign(types[1], TupleType(Array(fields))); } else { Array indices; - for (auto i : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + for (auto index : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), index->value)); } auto begin = IndexExpr(tir::make_zero(DataType::Int(32))); std::vector fields; @@ -3097,19 +3096,20 @@ Array SplitCompute(const Attrs& attrs, const Array& inpu const auto param = attrs.as(); ICHECK(param != nullptr); - if (const IntImmNode* sections = param->indices_or_sections.as()) { + if (const auto* sections = param->indices_or_sections.as()) { int64_t num_sections = sections->value; return Array{topi::split_sections(inputs[0], num_sections, param->axis)}; } else { Array indices; - for (auto i : Downcast>(param->indices_or_sections)) { - indices.push_back(IntImm(DataType::Int(32), i.as()->value)); + for (auto index : Downcast>(param->indices_or_sections)) { + indices.push_back(IntImm(DataType::Int(32), index->value)); } return Array{topi::split(inputs[0], indices, param->axis)}; } } -Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { +Expr MakeSplit(Expr data, Variant> indices_or_sections, + int axis) { auto attrs = make_object(); attrs->axis = axis; attrs->indices_or_sections = std::move(indices_or_sections); @@ -3117,17 +3117,7 @@ Expr MakeSplit(Expr data, ObjectRef indices_or_sections, int axis) { return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make.split").set_body([](const TVMArgs& args, TVMRetValue* rv) { - if (args.type_codes[1] == kDLInt) { - // Note: we change it from Int(64) to Int(32) for now as - // combine_parallel_dense will transform the graph with Int(32). - // More invetigation is needs to check which one we should use. - *rv = - MakeSplit(args[0], tir::make_const(DataType::Int(32), static_cast(args[1])), args[2]); - } else { - *rv = MakeSplit(args[0], args[1], args[2]); - } -}); +TVM_REGISTER_GLOBAL("relay.op._make.split").set_body_typed(MakeSplit); RELAY_REGISTER_OP("split") .describe(R"code(Splits an array along a particular axis into multiple sub-arrays. @@ -4157,11 +4147,13 @@ bool ScanopRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } -Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Bool exclusive) { +Expr MakeCumsum(Expr data, Integer axis, DataType dtype, Optional exclusive) { auto attrs = make_object(); attrs->dtype = dtype; attrs->axis = axis; - attrs->exclusive = exclusive; + if (exclusive.defined()) { + attrs->exclusive = exclusive.value(); + } static const Op& op = Op::Get("cumsum"); return Call(op, {data}, Attrs(attrs), {}); } diff --git a/src/relay/transforms/combine_parallel_op_batch.cc b/src/relay/transforms/combine_parallel_op_batch.cc index a41e1e0d6674..74827f166b51 100644 --- a/src/relay/transforms/combine_parallel_op_batch.cc +++ b/src/relay/transforms/combine_parallel_op_batch.cc @@ -159,7 +159,7 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, ExprSubstMap* subst_map) { int index = 0; - auto split = MakeSplit(data, Integer(branches.size()), 0); + auto split = MakeSplit(data, runtime::Int(branches.size()), 0); for (const auto& branch : branches) { auto split_data = TupleGetItem(split, index++); auto squeezed_data = MakeSqueeze(split_data, {0}); diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 34f986b251a2..df28506c6217 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -266,7 +266,7 @@ class ConstantFolder : public MixedModeMutator { // always use graph executor with no link-params dict.Set(tvm::attr::kExecutor, - relay::Executor::Create("graph", {{"link-params", Bool(false)}})); + relay::Executor::Create("graph", {{"link-params", runtime::Bool(false)}})); Expr result = ObjectToExpr(Eval(expr, module_->type_definitions, module_->Imports(), eval_cpu_dev_, eval_cpu_target_, dict)); VLOG(1) << "Evaluated to constant:" << std::endl << PrettyPrint(result); diff --git a/src/relay/transforms/higher_order_gradient.cc b/src/relay/transforms/higher_order_gradient.cc index edf1e4c99f4d..da7a8f6420cd 100644 --- a/src/relay/transforms/higher_order_gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -36,8 +36,6 @@ namespace tvm { namespace relay { -using namespace tvm::runtime; - /*! What is automatic differentiation(AD) and why is it important? * By AD, we roughly mean, given a term which denotes some mathematical function, * derive a term which denotes the derivative of that mathematical function. diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc index 5026b1bcba79..1112755b76a0 100644 --- a/src/relay/transforms/to_mixed_precision.cc +++ b/src/relay/transforms/to_mixed_precision.cc @@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map, // Return array is of type : [MixedTypeConversionCategory (int), String, String] // The fields are : [ConversionCategory, accumulation_datatype, output_datatype] // Call is a call node, DataType is the mixed precision type -using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc( +using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc>( const Call& call_node, const std::string& target_dtype_str)>; /*! \brief This class transforms the given relay module into a version where @@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator { if (attr_map.count(op)) { // Calculate the conversion category and dtypes from registered attribute. FTVMMixedPrecisionConversionType func = attr_map[op]; - Array op_descriptor = + Array> op_descriptor = func(GetRef(pre_call_node), DLDataType2String(mixed_precision_type_)); ICHECK(op_descriptor.size() == 3) << "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size() diff --git a/src/runtime/boxed_primitive.cc b/src/runtime/boxed_primitive.cc new file mode 100644 index 000000000000..9ab83a7b471c --- /dev/null +++ b/src/runtime/boxed_primitive.cc @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/runtime/boxed_primitive.cc + * \brief Implementations of ObjectRef wrapper. + */ + +#include +#include + +namespace tvm { +namespace runtime { + +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); +TVM_REGISTER_OBJECT_TYPE(BoxNode); + +/* \brief Allow explicit construction of Box + * + * Convert a `bool` to `Box`. For use in FFI handling, to + * provide an umambiguous representation between `bool(true)` and + * `int(1)`. Will be automatically unboxed in the case where a + * `Box` is provided to a PackedFunc that requires `int` input, + * mimicking C++'s default conversions. + * + * This is only needed for Box, as Box and Box + * can be converted in C++ as part of `TVMArgValue::operator + * ObjectRef()` without ambiguity, postponing conversions until + * required. + */ +TVM_REGISTER_GLOBAL("runtime.BoxBool").set_body_typed([](bool value) { return Box(value); }); + +/* \brief Return the underlying boolean object. + * + * Used while unboxing a boolean return value during FFI handling. + * The return type is intentionally `int` and not `bool`, to avoid + * recursive unwrapping of boolean values. + * + * This is only needed for Box, as Box and Box + * can be unambiguously unboxed as part of + * `TVMRetValue::operator=(ObjectRef)`. + */ +TVM_REGISTER_GLOBAL("runtime.UnBoxBool").set_body_typed([](Box obj) -> int { + return obj->value; +}); + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index 57979b160ea7..04d36ad8bcab 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -361,14 +361,18 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r TVMAPISetLastError("ModuleGetFunction expects second argument to be a string"); return kTvmErrorFunctionCallWrongArgType; } - if (type_codes[2] != kDLInt) { + + if (type_codes[2] == kDLInt) { + query_imports = args[2].v_int64 != 0; + } else if (type_codes[2] == kTVMArgBool) { + query_imports = args[2].v_bool; + } else { TVMAPISetLastError("ModuleGetFunction expects third argument to be an integer"); return kTvmErrorFunctionCallWrongArgType; } mod = (TVMModuleHandle)args[0].v_handle; name = args[1].v_str; - query_imports = args[2].v_int64 != 0; to_return = TVMModGetFunction(mod, name, query_imports, &ret_value->v_handle); if (to_return == 0) { diff --git a/src/runtime/disco/bcast_session.cc b/src/runtime/disco/bcast_session.cc index 493bc3fb1dc9..f7204e372f6d 100644 --- a/src/runtime/disco/bcast_session.cc +++ b/src/runtime/disco/bcast_session.cc @@ -102,10 +102,10 @@ DRef BcastSessionObj::CallWithPacked(const TVMArgs& args) { int cnt = 0; for (int i = 3; i < num_args; ++i) { int type_code = type_codes[i]; - if (type_code != kDLInt && type_code != kDLUInt && type_code != kDLFloat && - type_code != kTVMDataType && type_code != kDLDevice && type_code != kTVMOpaqueHandle && - type_code != kTVMStr && type_code != kTVMNullptr && type_code != kTVMBytes && - type_code != kTVMObjectHandle) { + if (type_code != kDLInt && type_code != kDLUInt && type_code != kTVMArgBool && + type_code != kDLFloat && type_code != kTVMDataType && type_code != kDLDevice && + type_code != kTVMOpaqueHandle && type_code != kTVMStr && type_code != kTVMNullptr && + type_code != kTVMBytes && type_code != kTVMObjectHandle) { os << "\n Argument #" << i - 3 << " has unsupported type code: " << type_code << " (" << ArgTypeCode2Str(type_code) << ")"; cnt += 1; diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index d08dadb02bb9..485ebdb449da 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -325,6 +325,10 @@ struct RPCReference { channel->template Write(value.v_int64); break; } + case kTVMArgBool: { + channel->template Write(value.v_bool); + break; + } case kTVMDataType: { channel->Write(value.v_type); // padding @@ -432,6 +436,10 @@ struct RPCReference { channel->template Read(&(value.v_int64)); break; } + case kTVMArgBool: { + channel->template Read(&(value.v_bool)); + break; + } case kTVMDataType: { channel->Read(&(value.v_type)); int32_t padding = 0; diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc index 2af31f1d4021..af1cf9d20335 100644 --- a/src/runtime/relax_vm/builtin.cc +++ b/src/runtime/relax_vm/builtin.cc @@ -279,7 +279,11 @@ TVM_REGISTER_GLOBAL("vm.builtin.check_shape_info").set_body_typed(CheckShapeInfo * \param err_ctx Additional context if error occurs. */ void CheckPrimValueInfo(TVMArgValue arg, DataType dtype, Optional err_ctx) { - if (dtype.is_bool()) { + if (arg.IsObjectRef()) { + ObjectRef obj = arg.AsObjectRef(); + LOG(FATAL) << "TypeError: " << err_ctx.value_or("") << ", expected dtype " << dtype + << ", but received ObjectRef of type " << obj->GetTypeKey(); + } else if (dtype.is_bool()) { arg.operator bool(); } else if (dtype.is_int()) { arg.operator int64_t(); @@ -426,7 +430,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.to_device") * \return Bool */ bool ReadIfCond(TVMArgValue cond) { - if (cond.type_code() == kDLInt) return cond.operator bool(); + if (cond.type_code() == kDLInt || cond.type_code() == kTVMArgBool) { + return cond.operator bool(); + } NDArray arr = cond.operator tvm::runtime::NDArray(); if (arr->device.device_type != kDLCPU) { arr = arr.CopyTo(DLDevice{kDLCPU, 0}); diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 54194e7e2a41..61bdec680a29 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -323,12 +323,33 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } else if (const auto* float_imm = value.as()) { // TODO(yelite): Make float number printing roundtrippable - output_.precision(17); if (std::isinf(float_imm->value) || std::isnan(float_imm->value)) { output_ << '"' << float_imm->value << '"'; + } else if (std::nearbyint(float_imm->value) == float_imm->value) { + // Special case for floating-point values which would be + // formatted using %g, are not displayed in scientific + // notation, and whose fractional part is zero. + // + // By default, using `operator<<(std::ostream&, double)` + // delegates to the %g printf formatter. This strips off any + // trailing zeros, and also strips the decimal point if no + // trailing zeros are found. When parsed in python, due to the + // missing decimal point, this would incorrectly convert a float + // to an integer. Providing the `std::showpoint` modifier + // instead delegates to the %#g printf formatter. On its own, + // this resolves the round-trip errors, but also prevents the + // trailing zeros from being stripped off. + std::showpoint(output_); + std::fixed(output_); + output_.precision(1); + output_ << float_imm->value; } else { + std::defaultfloat(output_); + std::noshowpoint(output_); + output_.precision(17); output_ << float_imm->value; } + } else if (const auto* string_obj = value.as()) { output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { diff --git a/src/script/printer/ir/misc.cc b/src/script/printer/ir/misc.cc index ef68b89b5bf4..686f486da6eb 100644 --- a/src/script/printer/ir/misc.cc +++ b/src/script/printer/ir/misc.cc @@ -30,6 +30,21 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return LiteralDoc::Str(s, p); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Bool obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Boolean(obj->value, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Int obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Int(obj->value, p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("", [](runtime::Float obj, ObjectPath p, IRDocsifier d) -> Doc { + return LiteralDoc::Float(obj->value, p); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch>( // "", [](Array array, ObjectPath p, IRDocsifier d) -> Doc { diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 6f9a8cbf8918..35a9f35db491 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -75,7 +75,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "relax", [](tvm::IntImm n, ObjectPath n_p, IRDocsifier d) -> Doc { // // TODO(@junrushao): support non-int64 cases - return LiteralDoc::Int(n->value, n_p); + if (n->dtype.is_bool()) { + return LiteralDoc::Boolean(n->value, n_p); + } else { + return LiteralDoc::Int(n->value, n_p); + } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/support/array.h b/src/support/array.h index 0ca57a2410c5..0d4c8134787b 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -164,12 +164,14 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -177,12 +179,14 @@ struct AsVectorImpl { template struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { + inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -191,11 +195,13 @@ struct AsVectorImpl { template struct AsVectorImpl { inline std::vector operator()(const Array& array) const { + TVMRetValue ret_value; + ret_value = array; + Array as_int_vec = ret_value; + std::vector results; - for (const TSrcObjectRef& x : array) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); + for (const auto& value : as_int_vec) { + results.push_back(value->value); } return results; } @@ -221,8 +227,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (int x : vec) { - result.push_back(Integer(x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } @@ -233,8 +241,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (int64_t x : vec) { - result.push_back(Integer(x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } @@ -245,8 +255,10 @@ struct AsArrayImpl { inline Array operator()(const std::vector& vec) const { Array result; result.reserve(vec.size()); - for (double x : vec) { - result.push_back(FloatImm(tvm::DataType::Float(64), x)); + for (auto x : vec) { + TVMRetValue ret_value; + ret_value = x; + result.push_back(ret_value); } return result; } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index aec57a1eb20d..928cdfcab80b 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -189,6 +189,58 @@ TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Varian TVM_REGISTER_GLOBAL("testing.AcceptsVariant") .set_body_typed([](Variant arg) -> String { return arg->GetTypeKey(); }); +TVM_REGISTER_GLOBAL("testing.AcceptsBool").set_body_typed([](bool arg) -> bool { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsInt").set_body_typed([](int arg) -> int { return arg; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRef").set_body_typed([](ObjectRef arg) -> ObjectRef { + return arg; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsObjectRefArray") + .set_body_typed([](Array arg) -> ObjectRef { return arg[0]; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsValue") + .set_body_typed([](Map map, ObjectRef key) -> ObjectRef { + return map[key]; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapReturnsMap") + .set_body_typed([](Map map) -> ObjectRef { return map; }); + +TVM_REGISTER_GLOBAL("testing.AcceptsPrimExpr").set_body_typed([](PrimExpr expr) -> ObjectRef { + return expr; +}); + +TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfPrimExpr") + .set_body_typed([](Array arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance()) + << "Array contained " << item->GetTypeKey() << " when it should contain PrimExpr"; + } + return arr; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsArrayOfVariant") + .set_body_typed([](Array> arr) -> ObjectRef { + for (ObjectRef item : arr) { + CHECK(item->IsInstance() || item->IsInstance()) + << "Array contained " << item->GetTypeKey() + << " when it should contain either PrimExpr or PackedFunc"; + } + return arr; + }); + +TVM_REGISTER_GLOBAL("testing.AcceptsMapOfPrimExpr") + .set_body_typed([](Map map) -> ObjectRef { + for (const auto& kv : map) { + ObjectRef value = kv.second; + CHECK(value->IsInstance()) + << "Map contained " << value->GetTypeKey() << " when it should contain PrimExpr"; + } + return map; + }); + /** * Simple event logger that can be used for testing purposes */ diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 481ba39cc7b1..21899a12c4b0 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -347,18 +347,26 @@ CodeGenLLVM::TypedPointer CodeGenCPU::CreateStructRefPtr(DataType t, llvm::Value } case builtin::kTVMValueContent: { ICHECK_EQ(t.lanes(), 1); - ICHECK(t.is_handle() || t.bits() == 64); - if (t.is_int()) { + if (t.is_bool()) { + // The stride between adjacent entries is still + // `sizeof(TVMValue)==64`, even if the enum currently holds a + // boolean. + buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); + buf = builder_->CreateInBoundsGEP(t_int64_, buf, index); + buf = builder_->CreatePointerCast(buf, DTypeToLLVMType(t)->getPointerTo()); + return TypedPointer(t_int8_, buf); + } else if (t.is_int() && t.bits() == 64) { buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo()); return TypedPointer(t_int64_, builder_->CreateInBoundsGEP(t_int64_, buf, index)); - } else if (t.is_float()) { + } else if (t.is_float() && t.bits() == 64) { buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo()); return TypedPointer(t_float64_, builder_->CreateInBoundsGEP(t_float64_, buf, index)); - } else { - ICHECK(t.is_handle()); + } else if (t.is_handle()) { buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo()); buf = builder_->CreateInBoundsGEP(t_tvm_value_, buf, index); return TypedPointer(t_void_p_, builder_->CreatePointerCast(buf, t_void_p_->getPointerTo())); + } else { + LOG(DEBUG) << "DataType " << t << " cannot be stored into a TVMValue"; } } default: @@ -1366,9 +1374,16 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind); if (kind == builtin::kArrAddr) { return builder_->CreatePointerCast(ref.addr, t_void_p_); - } else { - return builder_->CreateLoad(ref.type, ref.addr); } + + llvm::Value* struct_value = builder_->CreateLoad(ref.type, ref.addr); + + if (op->dtype == DataType::Bool()) { + struct_value = CreateCast(DataType::Int(8), op->dtype, struct_value); + } + + return struct_value; + } else if (op->op.same_as(builtin::tvm_struct_set())) { ICHECK_EQ(op->args.size(), 4U); int kind = op->args[2].as()->value; diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index dd5a3fb681ee..0406dcf951bb 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -294,10 +294,10 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) target_options_.MCOptions.ABIName = Downcast(target.Get("mabi")); } - auto maybe_level = Downcast(target.Get("opt-level")); + auto maybe_level = target.Get("opt-level").as(); #if TVM_LLVM_VERSION <= 170 if (maybe_level.defined()) { - int level = maybe_level->value; + int level = maybe_level.value()->value; if (level <= 0) { opt_level_ = llvm::CodeGenOpt::None; } else if (level == 1) { @@ -313,7 +313,7 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) } #else if (maybe_level.defined()) { - int level = maybe_level->value; + int level = maybe_level.value()->value; if (level <= 0) { opt_level_ = llvm::CodeGenOptLevel::None; } else if (level == 1) { @@ -333,8 +333,12 @@ LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const TargetJSON& target) // Fast math options - auto GetBoolFlag = [&target](llvm::StringRef flag) -> bool { - return Downcast(target.Get(flag.str()).value_or(Bool(false))); + auto GetBoolFlag = [&target](llvm::StringRef name) -> bool { + if (auto flag = target.Get(name.str())) { + return Downcast(flag); + } else { + return false; + } }; if (GetBoolFlag("fast-math")) { #if TVM_LLVM_VERSION >= 60 diff --git a/src/target/tag.cc b/src/target/tag.cc index 9eca3072df0e..d45bf61a38f1 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -76,61 +76,61 @@ TVM_REGISTER_TARGET_TAG("raspberry-pi/4b-aarch64") {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}, + {"num-cores", runtime::Int(4)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a72")}, {"mattr", Array{"+neon"}}, - {"num-cores", Integer(4)}}}}); + {"num-cores", runtime::Int(4)}}}}); #if TVM_LLVM_VERSION >= 110 TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-xavier") .set_config({{"kind", String("cuda")}, {"arch", String("sm_72")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::Int(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-orin-nano") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("carmel")}, - {"num-cores", Integer(6)}}}}); + {"num-cores", runtime::Int(6)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-32gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", Integer(8)}}}}); + {"num-cores", runtime::Int(8)}}}}); TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") .set_config({{"kind", String("cuda")}, {"arch", String("sm_87")}, - {"max_shared_memory_per_block", Integer(49152)}, - {"max_threads_per_block", Integer(1024)}, - {"thread_warp_size", Integer(32)}, - {"registers_per_block", Integer(65536)}, + {"max_shared_memory_per_block", runtime::Int(49152)}, + {"max_threads_per_block", runtime::Int(1024)}, + {"thread_warp_size", runtime::Int(32)}, + {"registers_per_block", runtime::Int(65536)}, {"host", Map{{"kind", String("llvm")}, {"mtriple", String("aarch64-linux-gnu")}, {"mcpu", String("cortex-a78")}, - {"num-cores", Integer(12)}}}}); + {"num-cores", runtime::Int(12)}}}}); #endif // TVM_LLVM_VERSION >= 110 #endif // TVM_LLVM_HAS_AARCH64_TARGET @@ -139,10 +139,10 @@ TVM_REGISTER_TARGET_TAG("nvidia/jetson-agx-orin-64gb") {"kind", String("cuda")}, \ {"keys", Array{"cuda", "gpu"}}, \ {"arch", String(Arch)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"max_threads_per_block", Integer(1024)}, \ - {"thread_warp_size", Integer(32)}, \ - {"registers_per_block", Integer(RegPerBlock)}, \ + {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ + {"max_threads_per_block", runtime::Int(1024)}, \ + {"thread_warp_size", runtime::Int(32)}, \ + {"registers_per_block", runtime::Int(RegPerBlock)}, \ }) // Naming convention for CUDA tags see https://developer.nvidia.com/cuda-gpus @@ -158,9 +158,9 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2075", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(41943040)); + .with_config("l2_cache_size_bytes", runtime::Int(41943040)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90a", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(52428800)); + .with_config("l2_cache_size_bytes", runtime::Int(52428800)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); @@ -263,7 +263,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/nvs-5400m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-5200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvs-4200m", "sm_21", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-4090", "sm_89", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(75497472)); + .with_config("l2_cache_size_bytes", runtime::Int(75497472)); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090-ti", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3090", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/geforce-rtx-3080-ti", "sm_86", 49152, 65536); @@ -416,7 +416,7 @@ TVM_REGISTER_CUDA_TAG("nvidia/tegra-x1", "sm_53", 49152, 32768); TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", String("llvm")}, \ {"keys", Array{"x86", "cpu"}}, \ {"mcpu", String(Arch)}, \ - {"num-cores", Integer(Cores)}}); + {"num-cores", runtime::Int(Cores)}}); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.large", 1, "skylake-avx512"); TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.xlarge", 2, "skylake-avx512"); @@ -432,9 +432,9 @@ TVM_REGISTER_TAG_AWS_C5("aws/cpu/c5.24xlarge", 48, "cascadelake"); #define TVM_REGISTER_METAL_GPU_TAG(Name, ThreadsPerBlock, SharedMem, WarpSize) \ TVM_REGISTER_TARGET_TAG(Name).set_config( \ {{"kind", String("metal")}, \ - {"max_threads_per_block", Integer(ThreadsPerBlock)}, \ - {"max_shared_memory_per_block", Integer(SharedMem)}, \ - {"thread_warp_size", Integer(WarpSize)}, \ + {"max_threads_per_block", runtime::Int(ThreadsPerBlock)}, \ + {"max_shared_memory_per_block", runtime::Int(SharedMem)}, \ + {"thread_warp_size", runtime::Int(WarpSize)}, \ {"host", Map{{"kind", String("llvm")}, \ {"mtriple", String("arm64-apple-macos")}, \ {"mcpu", String("apple-latest")}}}}); diff --git a/src/target/target.cc b/src/target/target.cc index cd2e3714e422..a8337b58ae9b 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -359,24 +359,31 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { std::string interp_str = Interpret(str); - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing integer + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex() || + info.type_index == runtime::Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + // Parsing integer or boolean std::istringstream is(interp_str); int v; if (!(is >> v)) { std::string lower(interp_str.size(), '\x0'); std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); - // Bool is a subclass of IntImm, so allow textual boolean values. + // Mimic C++ automatic conversions, allowing bool to be used for + // integer parameters. if (lower == "true") { v = 1; } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); + throw Error(": Cannot parse integer from string: " + interp_str); } } - return Integer(v); + + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return runtime::Int(v); + } else { + return runtime::Bool(v); + } } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing string, strip leading/trailing spaces, and enclosing quotes if any auto start = interp_str.find_first_not_of(' '); @@ -410,13 +417,13 @@ ObjectRef TargetInternal::ParseType(const std::string& str, ObjectRef TargetInternal::ParseType(const ObjectRef& obj, const TargetKindNode::ValueTypeInfo& info) { - if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + if (info.type_index == runtime::Int::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - return GetRef(ObjTypeCheck(obj, "Integer")); - } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + return GetRef(ObjTypeCheck(obj, "runtime.BoxInt")); + } else if (info.type_index == String::ContainerType::RuntimeTypeIndex()) { // Parsing string return GetRef(ObjTypeCheck(obj, "String")); - } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { + } else if (info.type_index == Target::ContainerType::RuntimeTypeIndex()) { // Parsing target if (auto opt = obj.as()) { return opt.value(); @@ -483,7 +490,11 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { - if (const auto* p = obj.as()) { + if (const auto* p = obj.as()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as()) { + return std::to_string(p->value); + } else if (const auto* p = obj.as()) { return std::to_string(p->value); } if (auto tvm_str = obj.as()) { @@ -494,7 +505,7 @@ std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { } return u; } - LOG(FATAL) << "Cannot stringify this object"; + LOG(FATAL) << "Cannot stringify object of type " << obj->GetTypeKey(); } std::string TargetInternal::StringifyArray(const ArrayNode& array) { @@ -953,7 +964,7 @@ ObjectPtr TargetInternal::FromConfig(Map config) { // If requested, query attributes from the device. User-specified // parameters take precedence over queried parameters. if (attrs.count("from_device")) { - int device_id = Downcast(attrs.at("from_device")).IntValue(); + int device_id = Downcast(attrs.at("from_device"))->value; attrs.erase("from_device"); auto device_params = QueryDevice(device_id, target.get()); @@ -1006,38 +1017,13 @@ std::unordered_map TargetInternal::QueryDevice(int device_id, for (const auto& kv : target->kind->key2vtype_) { const String& key = kv.first; - const TargetKindNode::ValueTypeInfo& type_info = kv.second; TVMRetValue ret; api->GetTargetProperty(device, key, &ret); - switch (ret.type_code()) { - case kTVMNullptr: - // Nothing returned for this parameter, move on to the next one. - continue; - - case kTVMArgInt: - if (type_info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Integer(static_cast(ret)); - } else if (type_info.type_index == Bool::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - output[key] = Bool(static_cast(ret)); - } else { - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received integer from device api"; - } - break; - - case kTVMStr: - ICHECK_EQ(type_info.type_index, String::ContainerType::_GetOrAllocRuntimeTypeIndex()) - << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received string from device api"; - output[key] = String(ret.operator std::string()); - break; - - default: - LOG(FATAL) << "Expected " << type_info.type_key << " parameter for attribute '" << key - << "', but received TVMArgTypeCode(" << ret.type_code() << ") from device api"; - break; + // Delegate conversion from TVMRetValue to the FFI's default conversions. + if (Optional opt = ret) { + output[key] = opt.value(); } } diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 708d3ccd7621..fced74c3a559 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -243,7 +243,7 @@ TargetJSON UpdateROCmAttrs(TargetJSON target) { * \return The updated attributes */ TargetJSON TestTargetParser(TargetJSON target) { - Map features = {{"is_test", Bool(true)}}; + Map features = {{"is_test", runtime::Bool(true)}}; target.Set("features", features); return target; } @@ -256,16 +256,16 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .add_attr_option("mtriple") .add_attr_option("mfloat-abi") .add_attr_option("mabi") - .add_attr_option("num-cores") + .add_attr_option("num-cores") // Fast math flags, see https://llvm.org/docs/LangRef.html#fast-math-flags - .add_attr_option("fast-math") // implies all the below - .add_attr_option("fast-math-nnan") - .add_attr_option("fast-math-ninf") - .add_attr_option("fast-math-nsz") - .add_attr_option("fast-math-arcp") - .add_attr_option("fast-math-contract") - .add_attr_option("fast-math-reassoc") - .add_attr_option("opt-level") + .add_attr_option("fast-math") // implies all the below + .add_attr_option("fast-math-nnan") + .add_attr_option("fast-math-ninf") + .add_attr_option("fast-math-nsz") + .add_attr_option("fast-math-arcp") + .add_attr_option("fast-math-contract") + .add_attr_option("fast-math-reassoc") + .add_attr_option("opt-level") // LLVM command line flags, see below .add_attr_option>("cl-opt") // LLVM JIT engine mcjit/orcjit @@ -273,7 +273,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) .set_default_keys({"cpu"}) // Force the external codegen kind attribute to be registered, even if no external // codegen targets are enabled by the TVM build. - .set_attr(tvm::attr::kIsExternalCodegen, Bool(false)) + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(false)) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); // Note regarding the "cl-opt" attribute: @@ -301,28 +301,29 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("mcpu") .add_attr_option("march") - .add_attr_option("workspace-byte-alignment") - .add_attr_option("constants-byte-alignment") + .add_attr_option("workspace-byte-alignment") + .add_attr_option("constants-byte-alignment") .set_default_keys({"cpu"}) .set_target_parser(tvm::target::parsers::cpu::ParseTarget); TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", Integer(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", runtime::Int(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", + runtime::Int(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") .add_attr_option("mtriple") - .add_attr_option("max_num_threads", Integer(1024)) - .add_attr_option("thread_warp_size", Integer(32)) + .add_attr_option("max_num_threads", runtime::Int(1024)) + .add_attr_option("thread_warp_size", runtime::Int(32)) .set_default_keys({"cuda", "gpu"}) .set_target_parser(UpdateNVPTXAttrs); @@ -332,24 +333,24 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option>("mattr") // TODO(masahi): Support querying from a target device // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(65536)) - .add_attr_option("thread_warp_size", Integer(64)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(65536)) + .add_attr_option("thread_warp_size", runtime::Int(64)) .set_default_keys({"rocm", "gpu"}) .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(16384)) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("texture_spatial_limit", Integer(16384)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(16384)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("thread_warp_size", runtime::Int(1)) + .add_attr_option("texture_spatial_limit", runtime::Int(16384)) // Faced that Qualcomm OpenCL runtime crashed without any error message in // the case when the number of kernel arguments was pretty big. OpenCL doesn't // specify any limitations on the number of kernel arguments. max_function_args // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", Integer(128)) + .add_attr_option("max_function_args", runtime::Int(128)) .set_default_keys({"opencl", "gpu"}); // The metal has some limitations on the number of input parameters. This is why attribute @@ -358,55 +359,55 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) // https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc // See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("max_shared_memory_per_block", Integer(32768)) - .add_attr_option("thread_warp_size", Integer(16)) - .add_attr_option("max_function_args", Integer(31)) + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("max_shared_memory_per_block", runtime::Int(32768)) + .add_attr_option("thread_warp_size", runtime::Int(16)) + .add_attr_option("max_function_args", runtime::Int(31)) .set_default_keys({"metal", "gpu"}); TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option>("mattr") // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", Bool(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", Bool(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", runtime::Bool(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", runtime::Bool(true)) + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") // Physical device limits - .add_attr_option("max_num_threads", Integer(256)) - .add_attr_option("max_threads_per_block", Integer(256)) - .add_attr_option("thread_warp_size", Integer(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_num_threads", runtime::Int(256)) + .add_attr_option("max_threads_per_block", runtime::Int(256)) + .add_attr_option("thread_warp_size", runtime::Int(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") // Other device properties .add_attr_option("device_type") .add_attr_option("device_name") .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") // Tags .set_default_keys({"vulkan", "gpu"}); TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - .add_attr_option("max_num_threads", Integer(256)) + .add_attr_option("max_num_threads", runtime::Int(256)) .set_default_keys({"webgpu", "gpu"}); TVM_REGISTER_TARGET_KIND("sdaccel", kDLOpenCL) // line break @@ -423,8 +424,8 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option("mcpu") .add_attr_option("mtriple") .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") .set_default_keys({"hexagon", "cpu"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU) // line break diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 5797d2295bab..fb839c28da96 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -56,10 +56,25 @@ TVM_REGISTER_NODE_TYPE(ComputeOpNode); /// Verify if ComputeOp is valid with respect to Reduce operations. static void VerifyComputeOp(const ComputeOpNode* op); -inline bool ReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { - return (a->combiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && StructuralEqual()(a->condition, b->condition) && - ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); +static inline void AssertReduceEqual(const tir::ReduceNode* a, const tir::ReduceNode* b) { + const char* shared_text = + "When a TE compute node produces multiple outputs, " + "each of which is a reduction, " + "each reduction must be structurally identical, " + "except for the ReduceNode::value_index. "; + + StructuralEqual eq; + + ICHECK(a->combiner.same_as(b->combiner)) << shared_text << "However, the reduction operation " + << a->combiner << " does not match " << b->combiner; + ICHECK(a->source.same_as(b->source)) + << shared_text << "However, the input " << a->source << " does not match " << b->source; + ICHECK(eq(a->axis, b->axis)) << shared_text << "However, the reduction axis " << a->axis + << " does not match " << b->axis; + ICHECK(eq(a->condition, b->condition)) << shared_text << "However, the predicate " << a->condition + << " does not match " << b->condition; + ICHECK(eq(a->init, b->init)) << shared_text << "However, the initial value " << a->init + << " does not match " << b->init; } int ComputeOpNode::num_outputs() const { return body.size(); } @@ -529,8 +544,7 @@ class ComputeVerifier final : protected tir::ExprVisitor { << "with being Reduce operation or not."; if (reduce && reduce_) { - ICHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should " - << "have the same attribute except value_index"; + AssertReduceEqual(reduce, reduce_); } level_ = 0; diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 2eb0693685a6..b5a87d9446d8 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -355,11 +355,12 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in Array seq_stmt; if (compute_op->body[0]->IsInstance()) { auto f_reducer_equal = [](const ReduceNode* a, const ReduceNode* b) -> bool { - return a->combiner.same_as(b->combiner) && // - a->source.same_as(b->source) && // - a->axis.same_as(b->axis) && // - a->condition.same_as(b->condition) && // - ((a->init.empty() && b->init.empty()) || a->init.same_as(b->init)); + StructuralEqual eq; + return eq(a->combiner, b->combiner) && // + eq(a->source, b->source) && // + eq(a->axis, b->axis) && // + eq(a->condition, b->condition) && // + eq(a->init, b->init); }; PrimExpr expr_body = compute_op->body[0]; @@ -370,7 +371,9 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in const tir::ReduceNode* reduce_ = compute_op->body[k].as(); ICHECK(reduce_); ICHECK(f_reducer_equal(reduce_, reduce)) - << "The Reduce inputs of ComputeOp should have the same attribute except value_index"; + << "The Reduce inputs of ComputeOp should have the same attribute except value_index, " + << "but the first argument has body " << GetRef(reduce_) << ", while the " << k + << "-th argument has body " << GetRef(reduce); tensors.push_back(compute_op.output(k)); } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index 4f5df7ad3024..774a0f8f1f89 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -63,7 +63,17 @@ Tensor placeholder(Array shape, DataType dtype, std::string name) { } TVM_REGISTER_GLOBAL("te.Placeholder") - .set_body_typed([](Array shape, DataType dtype, std::string name) { + .set_body_typed([](Variant> shape_arg, DataType dtype, + std::string name) { + auto shape = [&]() -> Array { + if (auto arg_expr = shape_arg.as()) { + return {arg_expr.value()}; + } else if (auto arg_array = shape_arg.as>()) { + return arg_array.value(); + } else { + LOG(FATAL) << "Variant did not contain either allowed type"; + } + }(); return placeholder(shape, dtype, name); }); diff --git a/src/te/schedule/schedule_dataflow_rewrite.cc b/src/te/schedule/schedule_dataflow_rewrite.cc index c38c5a5c800b..1ad8914e48cc 100644 --- a/src/te/schedule/schedule_dataflow_rewrite.cc +++ b/src/te/schedule/schedule_dataflow_rewrite.cc @@ -124,9 +124,10 @@ void ReplaceDataFlow(const Array& stages, std::unordered_mapcombiner.same_as(b->combiner)) && (a->source.same_as(b->source)) && - (a->axis.same_as(b->axis)) && (a->condition.same_as(b->condition)) && - ((a->init.empty() && b->init.empty()) || (a->init.same_as(b->init))); + StructuralEqual struct_equal; + return struct_equal(a->combiner, b->combiner) && struct_equal(a->source, b->source) && + struct_equal(a->axis, b->axis) && struct_equal(a->condition, b->condition) && + struct_equal(a->init, b->init); } Tensor Schedule::cache_read(const Tensor& tensor, const std::string& scope, diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 3a41c5ac5a25..70e82a605369 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -134,7 +134,7 @@ bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); if (target.defined() && target->kind->name == "hexagon") { - auto value = Downcast(target->attrs.at("vtcm-capacity"))->value; + auto value = target->GetAttr("vtcm-capacity").value()->value; if (value > 0) return value; } return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 1506082003fd..c38237a664f7 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -35,6 +35,18 @@ namespace tvm { namespace tir { +/* \brief Convert an object to a PrimExpr + * + * All conversions to a PrimExpr are performed as part of the FFI, + * when calling a function that accepts a PrimExpr as an argument. If + * a function must normalize to a PrimExpr (e.g. before accessing the + * `expr.dtype` field), this function allows the FFI conversions to be + * explicitly invoked. + */ +TVM_REGISTER_GLOBAL("tir.convert").set_body_typed([](Variant> expr) { + return expr; +}); + #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ Name::Name(PrimExpr a, PrimExpr b, Span span) { \ using T = Name::ContainerType; \ @@ -546,7 +558,9 @@ Call::Call(DataType dtype, RelayExpr op, Array args, Span span) { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, RelayExpr op, Array args, Span span) { + .set_body_typed([](DataType type, RelayExpr op, + Array> args, + Span span) { Array prim_expr_args; for (const auto& it : args) { ICHECK(it->IsInstance() || it->IsInstance() || @@ -707,9 +721,11 @@ Reduce::Reduce(CommReducer combiner, Array source, Array axis if (!init.empty()) { ICHECK_EQ(init.size(), source.size()) << "Number of inits should match number of exprs"; for (size_t i = 0; i < init.size(); i++) { + ICHECK(init[i].defined()) << "Init value must be defined"; ICHECK(init[i]->IsInstance() || init[i]->IsInstance() || init[i]->IsInstance()) - << "init can only be a IntImm, FloatImm or ProducerLoad"; + << "init can only be a IntImm, FloatImm or ProducerLoad, " + << "but received " << init[i] << " of type " << init[i]->GetTypeKey(); } } n->dtype = source[value_index].dtype(); diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 14dd0eadb65c..2c94b9d8646b 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -27,6 +27,8 @@ #include #include +#include "utils.h" + namespace tvm { namespace tir { namespace { @@ -79,6 +81,11 @@ PrimFunc::PrimFunc(Array params, Stmt body, Type ret_type, if (!ret_type.defined()) { ret_type = VoidType(); } + + if (attrs.defined()) { + attrs = Downcast(NormalizeAttributeObject(attrs)); + } + auto n = make_object(); n->params = std::move(params); n->body = std::move(body); diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index b30d0caf6af3..78fb9365cc71 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx /**************** Implementation ****************/ -PrimFunc Specialize(PrimFunc func, const Map& param_map) { +PrimFunc Specialize(PrimFunc func, const Map>& param_map) { VarMap var_map; for (const auto& kv : param_map) { const Var& param = kv.first; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 5df76450ff1e..9c8f580b5413 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -27,6 +27,7 @@ #include #include "buffer_common.h" +#include "utils.h" namespace tvm { namespace tir { @@ -61,6 +62,15 @@ TVM_REGISTER_NODE_TYPE(LetStmtNode); // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + // The nodes are not required to be a TIR type, and may legally + // contain any ObjectRef. However, normalizing to an IR type if + // possible prevents spurious discrepancies in StructuralEqual(). + if (auto opt = node.as()) { + node = Bool(opt.value()); + } else if (auto opt = node.as()) { + node = Integer(opt.value()); + } + auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -109,13 +119,21 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt") // For For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, Optional thread_binding, Map annotations, Span span) { + ICHECK(loop_var.defined()); ICHECK(min.defined()); ICHECK(extent.defined()); - ICHECK(min.dtype().is_scalar()); - ICHECK(extent.dtype().is_scalar()); - ICHECK(loop_var.dtype().is_scalar()); ICHECK(body.defined()); + auto require_scalar_int_dtype = [&](PrimExpr expr, const char* field_name) { + auto dtype = expr.dtype(); + CHECK(dtype.is_scalar() && (dtype.is_int() || dtype.is_uint())) + << "TIR For nodes require a scalar integer as the " << field_name << ", but received " + << expr << " with dtype " << dtype; + }; + require_scalar_int_dtype(loop_var, "loop_var"); + require_scalar_int_dtype(min, "min"); + require_scalar_int_dtype(extent, "extent"); + // When extent or min is an IntImm but has narrower dtype than loop_var, we directly promote them // without raising errors. auto try_promote_imm_dtype = [&](const PrimExpr& e) { @@ -136,6 +154,8 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, ICHECK(loop_var.dtype() == min.dtype()) << loop_var.dtype() << " vs " << min.dtype(); ICHECK(loop_var.dtype() == extent.dtype()) << loop_var.dtype() << " vs " << extent.dtype(); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->loop_var = std::move(loop_var); node->min = std::move(min); @@ -234,6 +254,8 @@ Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, Prim ICHECK(condition.defined()); ICHECK(condition.dtype().is_bool()); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -288,6 +310,8 @@ AllocateConst::AllocateConst(Var buffer_var, DataType dtype, Array ext ICHECK(body.defined()); ICHECK(data_or_idx.defined()); + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->buffer_var = std::move(buffer_var); node->dtype = dtype; @@ -652,6 +676,8 @@ Block::Block(Array iter_vars, Array reads, Array init, Array alloc_buffers, Array match_buffers, Map annotations, Span span) { + annotations = Downcast>(NormalizeAttributeObject(annotations)); + ObjectPtr node = make_object(); node->iter_vars = std::move(iter_vars); node->reads = std::move(reads); diff --git a/src/tir/ir/utils.cc b/src/tir/ir/utils.cc new file mode 100644 index 000000000000..0e3dc1237894 --- /dev/null +++ b/src/tir/ir/utils.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tir/ir/utils.cc + * \brief Utilities for manipulating TIR + */ +#include "utils.h" + +#include + +namespace tvm { +namespace tir { + +ObjectRef NormalizeAttributeObject(ObjectRef obj) { + if (const auto* runtime_int = obj.as()) { + return Integer(runtime_int->value); + } else if (const auto* runtime_bool = obj.as()) { + return Bool(runtime_bool->value); + } else if (const auto* runtime_float = obj.as()) { + return FloatImm(DataType::Float(32), runtime_float->value); + } else if (auto opt_array = obj.as>()) { + return opt_array.value().Map(NormalizeAttributeObject); + } else if (auto opt_map = obj.as>()) { + Map new_map; + bool is_same = true; + + for (const auto& [key, obj] : opt_map.value()) { + ObjectRef new_obj = NormalizeAttributeObject(obj); + is_same = is_same && obj.same_as(new_obj); + new_map.Set(key, new_obj); + } + + if (is_same) { + return obj; + } else { + return new_map; + } + } else if (auto dict_attrs = obj.as()) { + auto new_attrs = Downcast>(NormalizeAttributeObject(dict_attrs->dict)); + if (new_attrs.same_as(dict_attrs->dict)) { + return GetRef(dict_attrs); + } else { + return DictAttrs(new_attrs); + } + } else { + return obj; + } +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/ir/utils.h b/src/tir/ir/utils.h new file mode 100644 index 000000000000..b1f7a722899f --- /dev/null +++ b/src/tir/ir/utils.h @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/ir/utils.h + * \brief Utilities for manipulating TIR + */ +#ifndef TVM_TIR_IR_UTILS_H_ +#define TVM_TIR_IR_UTILS_H_ + +#include + +namespace tvm { +namespace tir { + +/* \brief Normalize an ObjectRef held + * + * Where possible, the IR should be normalized contain IR types. For + * example, holding a `tir::IntImm` instead of a `runtime::Int`. In + * attributes, this is not always possible, as attributes may refer to + * non-IR objects. + * + * This function normalizes any `runtime::Int`, `runtime::Bool`, + * `runtime::Float`, or containers of those types to the corresponding + * IR type. + * + * \param obj The attribute object to be normalized + * + * \returns The normalized attribute + */ +ObjectRef NormalizeAttributeObject(ObjectRef obj); + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_IR_UTILS_H_ diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index c79a148e4b6e..dad4ea98d614 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -229,9 +229,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } PrimExpr ret(PrimExpr value, Span span) { + CHECK(value.defined()); return tir::Call(value.dtype(), tir::builtin::ret(), {value}, span); } +TVM_REGISTER_GLOBAL("tir.ret").set_body_typed(ret); + // maximum and min limits PrimExpr max_value(const DataType& dtype, Span span) { using namespace tir; @@ -1048,12 +1051,15 @@ TVM_TIR_REGISTER_OP("TVMBackendFreeWorkspace") // expose basic functions to node namespace TVM_REGISTER_GLOBAL("node._const").set_body([](TVMArgs args, TVMRetValue* ret) { - if (args[0].type_code() == kDLInt) { - *ret = tir::make_const(args[1], args[0].operator int64_t(), args[2]); - } else if (args[0].type_code() == kDLFloat) { - *ret = tir::make_const(args[1], args[0].operator double(), args[2]); + if (auto opt = args[0].TryAsInt()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); + } else if (auto opt = args[0].TryAsBool()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); + } else if (auto opt = args[0].TryAsFloat()) { + *ret = tir::make_const(args[1], opt.value(), args[2]); } else { - LOG(FATAL) << "only accept int or float"; // FIXME + LOG(FATAL) << "First argument to tvm.tir.const must be int, float, or bool, " + << "but instead received argument with type code " << args[0].type_code(); // FIXME } }); diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index cda501cd992e..73b5ff3fafd4 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -233,9 +233,9 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { return support::LinearCongruentialEngine(&rand_state_).ForkSeed(); } -ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); @@ -914,6 +914,14 @@ ObjectRef ConcreteScheduleNode::CheckAndGetAnnotationValue(const ObjectRef& ann_ if (ann_val.as()) { return ann_val; } + if (auto* runtime_int = ann_val.as()) { + return IntImm(DataType::Int(32), runtime_int->value); + } else if (auto* runtime_float = ann_val.as()) { + return FloatImm(DataType::Float(32), runtime_float->value); + } else if (auto* runtime_bool = ann_val.as()) { + return Bool(runtime_bool->value); + } + if (const auto* expr = ann_val.as()) { ICHECK(!ann_val->IsInstance()) << "TypeError: runtime::String is expected, but gets StringImm"; diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 4eccff10a2c7..092bcf0c79f9 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -87,8 +87,9 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) override; + ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) override; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) override; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 122c5ff0d9fe..9209e6578687 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -439,6 +439,11 @@ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os } else if (const auto* float_imm = obj.as()) { os.precision(17); os << float_imm->value; + } else if (const auto* runtime_int = obj.as()) { + os << runtime_int->value; + } else if (const auto* runtime_float = obj.as()) { + os.precision(17); + os << runtime_float->value; } else if (const auto* array = obj.as()) { os << '['; bool is_first = true; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index fe1c1850dcd5..fd1349e4a3ec 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -55,8 +55,9 @@ std::vector SampleWithoutReplacement( * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision); + const Array& candidates, + const Array& probs, + Optional* decision); /*! * \brief Create a sampling function that does multinomial sampling. * \param rand_state The random state. diff --git a/src/tir/schedule/primitive/annotate.cc b/src/tir/schedule/primitive/annotate.cc index 92c3423bcbbb..4c7b208e964f 100644 --- a/src/tir/schedule/primitive/annotate.cc +++ b/src/tir/schedule/primitive/annotate.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include "../../ir/utils.h" #include "../utils.h" namespace tvm { @@ -97,6 +98,8 @@ struct AnnotateTraits : public UnpackedInstTraits { static void UnpackedApplyToSchedule(Schedule sch, ObjectRef block_or_loop_rv, ObjectRef ann_val, String ann_key) { + ann_val = NormalizeAttributeObject(ann_val); + if (auto block = block_or_loop_rv.as()) { return sch->Annotate(block.value(), ann_key, ann_val); } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 2a2f17355ca6..8e16f50b8b95 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -163,19 +163,18 @@ std::vector SampleWithoutReplacement( } int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, - const Array& candidates, const Array& probs, - Optional* decision) { + const Array& candidates, const Array& probs, + Optional* decision) { CHECK(candidates.size() == probs.size()) << "ValueError: number of candidates does not match number of probabilities."; int32_t i = -1; int32_t n = candidates.size(); if (decision->defined()) { - const auto* int_imm = decision->as(); - i = int_imm->value; + i = decision->value()->value; CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - std::vector weights = support::AsVector(probs); + std::vector weights = support::AsVector(probs); std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); @@ -183,8 +182,8 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } - *decision = Integer(i); // decision is guaranteed not to be nullptr. - return candidates[i].IntValue(); + *decision = runtime::Int(i); // decision is guaranteed not to be nullptr. + return candidates[i]->value; } std::function MakeMultinomialSampler( @@ -461,24 +460,11 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // - Optional decision) { - Array probs_float = probs.Map([](const ObjectRef& prob) { - const auto* prob_float = prob.as(); - if (prob_float != nullptr) { - return GetRef(prob_float); - } - const auto* prob_int = prob.as(); - if (prob_int != nullptr) { - return FloatImm(DataType::Float(32), static_cast(prob_int->value)); - } - LOG(FATAL) - << "SampleCategorical does not accept probability with type other than float or int."; - throw; - }); - return sch->SampleCategorical(candidates, probs_float, decision); + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + return sch->SampleCategorical(candidates, probs, decision); } static String UnpackedAsPython(Array outputs, // diff --git a/src/tir/schedule/trace.cc b/src/tir/schedule/trace.cc index 4b10df7e9728..6e243bf19198 100644 --- a/src/tir/schedule/trace.cc +++ b/src/tir/schedule/trace.cc @@ -112,7 +112,9 @@ Array TranslateInputRVs( } else if (const auto* str_obj = input.as()) { // Case 2. string => "content" results.push_back(String('"' + std::string(str_obj->data) + '"')); - } else if (input->IsInstance() || input->IsInstance()) { + } else if (input->IsInstance() || input->IsInstance() || + input->IsInstance() || + input->IsInstance()) { // Case 3. integer or floating-point number results.push_back(input); } else if (input->IsInstance()) { @@ -149,7 +151,9 @@ Array TranslateInputRVs(const Array& inputs, results.reserve(inputs.size()); for (const ObjectRef& input : inputs) { // Case 3. integer or floating-point number - if (input->IsInstance() || input->IsInstance()) { + if (input->IsInstance() || input->IsInstance() || + input->IsInstance() || + input->IsInstance()) { results.push_back(input); continue; } @@ -388,9 +392,9 @@ void Trace::ApplyJSONToSchedule(ObjectRef json, Schedule sch) { try { const ArrayNode* arr = decision_entry.as(); ICHECK(arr && arr->size() == 2); - const IntImmNode* arr0 = arr->at(0).as(); + auto arr0 = arr->at(0).as(); ICHECK(arr0); - index = arr0->value; + index = arr0.value(); decision = arr->at(1); } catch (const tvm::Error& e) { LOG(FATAL) << "ValueError: Each entry of a json decision should be a tuple [index, " diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 16c4350aaee6..1611109d7735 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -53,9 +53,9 @@ Schedule TracedScheduleNode::Copy() { /******** Schedule: Sampling ********/ -ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, - const Array& probs, - Optional decision) { +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { ExprRV result = CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 686d84ebc6fe..78629e84f039 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,8 +47,9 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ - ExprRV SampleCategorical(const Array& candidates, const Array& probs, - Optional decision = NullOpt) final; + ExprRV SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision = NullOpt) final; Array SamplePerfectTile(const LoopRV& loop_rv, int n, int max_innermost_factor, Optional> decision = NullOpt) final; Array SamplePartitionedTile(const LoopRV& loop_rv, int n, int partition_pos, diff --git a/src/tir/transforms/inline_private_functions.cc b/src/tir/transforms/inline_private_functions.cc index cc33ba9f86c2..14672f568549 100644 --- a/src/tir/transforms/inline_private_functions.cc +++ b/src/tir/transforms/inline_private_functions.cc @@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator { << "Inlining of PrimFuncs with buffer arguments is not yet supported, " << "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; - Map param_map; + Map> param_map; for (size_t i = 0; i < callee->params.size(); i++) { param_map.Set(callee->params[i], args[i]); } diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 423b0ca92237..2948773321dd 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -155,6 +155,7 @@ inline DataType APIType(DataType t) { ICHECK(!t.is_void()) << "Cannot pass void type through packed API."; if (t.is_handle()) return t; ICHECK_EQ(t.lanes(), 1) << "Cannot pass vector type through packed API."; + if (t.is_bool()) return DataType::Bool(); if (t.is_uint() || t.is_int()) return DataType::Int(64); ICHECK(t.is_float()); return DataType::Float(64); diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 1a3888a7cd48..1cde4f2ebe7d 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -511,6 +511,8 @@ class BuiltinLower : public StmtExprMutator { arg_tcode = kTVMStr; } else if (IsArrayHandle(arg)) { arg_tcode = kTVMDLTensorHandle; + } else if (arg.dtype().is_bool()) { + arg_tcode = kTVMArgBool; } // opaque handle need to set the kind properly if (arg_tcode == kTVMOpaqueHandle) { diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d327cdfa8393..9f2f1295fece 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -263,15 +263,15 @@ PrimFunc MakePackedAPI(PrimFunc func) { // --------------------------- // local function definitions // load i-th argument as type t - auto f_arg_value = [&](DataType t, int i) { + auto f_arg_value = [&](DataType arg_type, int i) { Array call_args{v_packed_args, IntImm(DataType::Int(32), i), IntImm(DataType::Int(32), builtin::kTVMValueContent)}; // load 64 bit version - DataType api_type = APIType(t); + DataType api_type = APIType(arg_type); PrimExpr res = Call(api_type, builtin::tvm_struct_get(), call_args); // cast to the target version. - if (api_type != t) { - res = Cast(t, res); + if (api_type != arg_type) { + res = Cast(arg_type, res); } return res; }; @@ -319,10 +319,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { continue; } - var_def.emplace_back(f_arg_value(param.dtype(), i), param); - if (func_ptr->buffer_map.count(param)) { - buffer_def.emplace_back(param, func_ptr->buffer_map[param]); - } + PrimExpr arg_value; // type code checks Var tcode(param->name_hint + ".code", DataType::Int(32)); @@ -335,15 +332,45 @@ PrimFunc MakePackedAPI(PrimFunc func) { seq_init.emplace_back(AssertStmt(tcode == kTVMOpaqueHandle || tcode == kTVMNDArrayHandle || tcode == kTVMDLTensorHandle || tcode == kTVMNullptr, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = f_arg_value(param.dtype(), i); + } else if (t.is_bool()) { + std::ostringstream msg; + msg << name_hint << ": Expect arg[" << i << "] to be boolean"; + seq_init.emplace_back( + AssertStmt(tcode == kTVMArgBool || tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = Call(t, builtin::if_then_else(), + { + tcode == kTVMArgBool, + f_arg_value(DataType::Bool(), i), + cast(DataType::Bool(), f_arg_value(DataType::Int(64), i)), + }); + } else if (t.is_int() || t.is_uint()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be int"; - seq_init.emplace_back(AssertStmt(tcode == kDLInt, tvm::tir::StringImm(msg.str()), nop)); + seq_init.emplace_back( + AssertStmt(tcode == kDLInt || tcode == kTVMArgBool, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = Call(t, builtin::if_then_else(), + { + tcode == kTVMArgInt, + f_arg_value(t, i), + cast(t, f_arg_value(DataType::Bool(), i)), + }); } else { ICHECK(t.is_float()); std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be float"; seq_init.emplace_back(AssertStmt(tcode == kDLFloat, tvm::tir::StringImm(msg.str()), nop)); + + arg_value = f_arg_value(param.dtype(), i); + } + + var_def.emplace_back(arg_value, param); + if (func_ptr->buffer_map.count(param)) { + buffer_def.emplace_back(param, func_ptr->buffer_map[param]); } } diff --git a/tests/cpp/relay/backend/runtime_test.cc b/tests/cpp/relay/backend/runtime_test.cc index 53ea7e39ed59..adabb9b9b6cf 100644 --- a/tests/cpp/relay/backend/runtime_test.cc +++ b/tests/cpp/relay/backend/runtime_test.cc @@ -26,13 +26,13 @@ namespace tvm { namespace relay { TVM_REGISTER_RUNTIME("TestRuntime") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") .add_attr_option("another_option") - .add_attr_option("defaulty_the_default_option", Bool(false)); + .add_attr_option("defaulty_the_default_option", runtime::Bool(false)); TEST(Runtime, Create) { - Map attrs = {{"my_bool", Bool(true)}}; + Map attrs = {{"my_bool", runtime::Bool(true)}}; Runtime my_runtime = Runtime::Create("TestRuntime", attrs); ASSERT_EQ(my_runtime->GetAttr("my_bool"), true); ASSERT_EQ(my_runtime->GetAttr>("your_names").defined(), false); @@ -40,7 +40,7 @@ TEST(Runtime, Create) { } TEST(Runtime, UnknownAttr) { - Map attrs = {{"woofles", Bool(true)}}; + Map attrs = {{"woofles", runtime::Bool(true)}}; ASSERT_THROW(Runtime::Create("TestRuntime", attrs), Error); } @@ -64,7 +64,7 @@ TEST(RuntimeRegistry, ListRuntimeOptions) { Map attrs = Runtime::ListRuntimeOptions("TestRuntime"); ICHECK_EQ(attrs.empty(), false); - ICHECK_EQ(attrs["my_bool"], "IntImm"); + ICHECK_EQ(attrs["my_bool"], "runtime.BoxBool"); ICHECK_EQ(attrs["your_names"], "Array"); ICHECK_EQ(attrs["another_option"], "runtime.String"); } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2db4b572bf60..0a2b8206d322 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -32,15 +32,15 @@ using namespace tvm; TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .set_attr("Attr1", "Value1") - .add_attr_option("my_bool") + .add_attr_option("my_bool") .add_attr_option>("your_names") - .add_attr_option>("her_maps"); + .add_attr_option>("her_maps"); TargetJSON TestTargetParser(TargetJSON target) { String mcpu = Downcast(target.at("mcpu")); target.Set("mcpu", String("super_") + mcpu); target.Set("keys", Array({"super"})); - target.Set("features", Map{{"test", Bool(true)}}); + target.Set("features", Map{{"test", runtime::Bool(true)}}); return target; } @@ -76,14 +76,14 @@ TEST(TargetKind, GetAttrMap) { TEST(TargetCreation, NestedConfig) { Map config = { - {"my_bool", Bool(true)}, + {"my_bool", runtime::Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -91,13 +91,14 @@ TEST(TargetCreation, NestedConfig) { ICHECK_EQ(target->kind, TargetKind::Get("TestTargetKind").value()); ICHECK_EQ(target->tag, ""); ICHECK(target->keys.empty()); - Bool my_bool = target->GetAttr("my_bool").value(); + runtime::Bool my_bool = target->GetAttr("my_bool").value(); ICHECK_EQ(my_bool.operator bool(), true); Array your_names = target->GetAttr>("your_names").value(); ICHECK_EQ(your_names.size(), 2U); ICHECK_EQ(your_names[0], "junru"); ICHECK_EQ(your_names[1], "jian"); - Map her_maps = target->GetAttr>("her_maps").value(); + Map her_maps = + target->GetAttr>("her_maps").value(); ICHECK_EQ(her_maps.size(), 2U); ICHECK_EQ(her_maps["a"], 1); ICHECK_EQ(her_maps["b"], 2); @@ -105,15 +106,15 @@ TEST(TargetCreation, NestedConfig) { TEST(TargetCreationFail, UnrecognizedConfigOption) { Map config = { - {"my_bool", Bool(true)}, + {"my_bool", runtime::Bool(true)}, {"your_names", Array{"junru", "jian"}}, {"kind", String("TestTargetKind")}, {"bad", ObjectRef(nullptr)}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -133,9 +134,9 @@ TEST(TargetCreationFail, TypeMismatch) { {"kind", String("TestTargetKind")}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -150,13 +151,13 @@ TEST(TargetCreationFail, TypeMismatch) { TEST(TargetCreationFail, TargetKindNotFound) { Map config = { - {"my_bool", Bool("true")}, + {"my_bool", runtime::Bool("true")}, {"your_names", Array{"junru", "jian"}}, { "her_maps", - Map{ - {"a", 1}, - {"b", 2}, + Map{ + {"a", runtime::Int(1)}, + {"b", runtime::Int(2)}, }, }, }; @@ -178,15 +179,16 @@ TEST(TargetCreation, TargetParser) { TEST(TargetCreation, TargetFeatures) { Target test_target_with_parser("TestTargetParser -mcpu=woof"); - ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); + ASSERT_EQ(test_target_with_parser->GetFeature("test").value(), true); Target test_target_no_parser("TestTargetKind"); - ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); - ASSERT_EQ(test_target_no_parser->GetFeature("test", Bool(true)).value(), true); + ASSERT_EQ(test_target_no_parser->GetFeature("test"), nullptr); + ASSERT_EQ(test_target_no_parser->GetFeature("test", runtime::Bool(true)).value(), + true); } TEST(TargetCreation, TargetFeaturesBeforeParser) { - Map features = {{"test", Bool(true)}}; + Map features = {{"test", runtime::Bool(true)}}; Map config = { {"kind", String("TestTargetParser")}, {"mcpu", String("woof")}, @@ -469,13 +471,13 @@ TEST(TargetCreation, DetectSystemTriple) { #endif TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_1", kDLCUDA) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_2", kDLMetal) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)); + .set_attr(tvm::attr::kIsExternalCodegen, runtime::Bool(true)); TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU) .set_attr(tvm::attr::kRelayToTIR, diff --git a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py index bbfb8bd2db12..f5b1651e115a 100644 --- a/tests/python/all-platform-minimal-test/test_runtime_packed_func.py +++ b/tests/python/all-platform-minimal-test/test_runtime_packed_func.py @@ -15,10 +15,14 @@ # specific language governing permissions and limitations # under the License. """Test packed function FFI.""" +import gc + +import numpy as np + import tvm from tvm import te import tvm.testing -import numpy as np +from tvm.script import tir as T def test_get_global(): @@ -37,7 +41,7 @@ def my_packed_func(*args): def test_get_callback_with_node(): - x = tvm.runtime.convert(10) + x = T.int32(10) def test(y): assert y.handle != x.handle @@ -66,7 +70,7 @@ def add(x): myf = tvm.runtime.convert(addy) f = myf(10) - assert f(11).value == 21 + assert f(11) == 21 def test_convert(): @@ -113,6 +117,14 @@ def test_device_func(dev): def test_rvalue_ref(): def callback(x, expected_count): + # The use count of TVM objects is decremented as part of + # `ObjectRef.__del__`, which runs when the Python object is + # destructed. However, Python object destruction is not + # deterministic, and even CPython's reference-counting is + # considered an implementation detail. Therefore, to ensure + # correct results from this test, `gc.collect()` must be + # explicitly called. + gc.collect() assert expected_count == tvm.testing.object_use_count(x) return x diff --git a/tests/python/arith/test_arith_canonical_simplify.py b/tests/python/arith/test_arith_canonical_simplify.py index afd716cde389..42f5b0ccd0b8 100644 --- a/tests/python/arith/test_arith_canonical_simplify.py +++ b/tests/python/arith/test_arith_canonical_simplify.py @@ -16,16 +16,27 @@ # under the License. import tvm import tvm.testing -from tvm import te +from tvm import te, tir +from tvm.script import tir as T class CanonicalChecker: def __init__(self): self.analyzer = tvm.arith.Analyzer() + def _convert(self, expr): + # TODO(Lunderberg): Make utility functions `tir.convert` and + # `relax.convert` that convert to their respective IR types. + # Implementation should be in C++, and should only consist of + # conversions that are applied automatically through FFI. + if isinstance(expr, int): + return T.int32(expr) + else: + return expr + def verify(self, data, expected): res = self.analyzer.canonical_simplify(data) - expected = tvm.runtime.convert(expected) + expected = self._convert(expected) assert tvm.ir.structural_equal(res, expected), "\ndata={}\nres={}\nexpected={}".format( data, res, expected ) @@ -377,13 +388,13 @@ def test_simplify_normalize_min_value_expr(): x = te.var("x", "int32") ck.verify(te.min_value("int32") - x == 0, x == te.min_value("int32")) - ck.verify(te.min_value("int32") + x == 0, False) + ck.verify(te.min_value("int32") + x == 0, tir.const(False)) ck.verify(0 == te.min_value("int32") - x, x == te.min_value("int32")) - ck.verify(0 == te.min_value("int32") + x, False) + ck.verify(0 == te.min_value("int32") + x, tir.const(False)) ck.verify(-x + te.min_value("int32") == 0, x == te.min_value("int32")) - ck.verify(x + te.min_value("int32") == 0, False) + ck.verify(x + te.min_value("int32") == 0, tir.const(False)) ck.verify(0 == -x + te.min_value("int32"), x == te.min_value("int32")) - ck.verify(0 == x + te.min_value("int32"), False) + ck.verify(0 == x + te.min_value("int32"), tir.const(False)) def test_proddiv_simplify(): diff --git a/tests/python/arith/test_arith_iter_affine_map.py b/tests/python/arith/test_arith_iter_affine_map.py index 3a10ec05efeb..f0e6f05adfad 100644 --- a/tests/python/arith/test_arith_iter_affine_map.py +++ b/tests/python/arith/test_arith_iter_affine_map.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm.tir import floordiv, floormod +from tvm.script import tir as T def ifuse(inputs, pred_extent=None): @@ -537,7 +538,7 @@ def test_subspace_division(): tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) tvm.ir.assert_structural_equal(res[0][1], x + c) tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) - tvm.ir.assert_structural_equal(res[1][1], True) + tvm.ir.assert_structural_equal(res[1][1], T.bool(True)) # compound 1 i0 = create_iter("i0", 4) @@ -553,7 +554,7 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) @@ -569,7 +570,7 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -587,11 +588,11 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) tvm.ir.assert_structural_equal(res[1][1], i3[0]) tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) - tvm.ir.assert_structural_equal(res[2][1], True) + tvm.ir.assert_structural_equal(res[2][1], T.bool(True)) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])).indices assert len(res1) == 2 @@ -606,9 +607,9 @@ def test_subspace_division(): assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], i0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) - tvm.ir.assert_structural_equal(res[2][0], True) + tvm.ir.assert_structural_equal(res[2][0], T.bool(True)) tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])).indices @@ -642,10 +643,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])).indices @@ -661,9 +662,9 @@ def test_subspace_division(): assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], j0[0]) tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(floormod(l0[0] * 6 + l1[0], 6), 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) res1 = tvm.arith.detect_iter_map( @@ -690,10 +691,10 @@ def test_subspace_division(): res = convert_division(res) assert len(res) == 4 tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) - tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][0], T.int32(0)) tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) @@ -735,8 +736,8 @@ def test_subspace_divide_trivial_iters(): res = convert_division(res) assert len(res) == 3 tvm.ir.assert_structural_equal(res[0][0], x) - tvm.ir.assert_structural_equal(res[0][1], 0) - tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[0][1], T.int32(0)) + tvm.ir.assert_structural_equal(res[1][0], T.int32(0)) tvm.ir.assert_structural_equal(res[1][1], y) diff --git a/tests/python/arith/test_arith_narrow_predicate_expression.py b/tests/python/arith/test_arith_narrow_predicate_expression.py index d38fe70f6b5c..0aa353c60041 100644 --- a/tests/python/arith/test_arith_narrow_predicate_expression.py +++ b/tests/python/arith/test_arith_narrow_predicate_expression.py @@ -20,6 +20,7 @@ from tvm import tir from tvm.runtime import convert +from tvm.script import tir as T i = tir.Var("i", "int32") @@ -42,18 +43,18 @@ [i < n, i < 0], [i <= n, i <= 0], [i >= n, i >= 7], - [n > i, convert(0) > i], - [n < i, convert(7) < i], - [n <= i, convert(7) <= i], - [n >= i, convert(0) >= i], - [i == n, tir.all(i <= 0, convert(7) <= i)], - [n == i, tir.all(convert(7) <= i, i <= 0)], - [i != n, tir.any(i < 0, convert(7) < i)], - [n != i, tir.any(convert(7) < i, i < 0)], + [n > i, T.int32(0) > i], + [n < i, T.int32(7) < i], + [n <= i, T.int32(7) <= i], + [n >= i, T.int32(0) >= i], + [i == n, tir.all(i <= 0, T.int32(7) <= i)], + [n == i, tir.all(T.int32(7) <= i, i <= 0)], + [i != n, tir.any(i < 0, T.int32(7) < i)], + [n != i, tir.any(T.int32(7) < i, i < 0)], [i // 4 > n, i // 4 > 7], - [n < i // 4, convert(7) < i // 4], + [n < i // 4, T.int32(7) < i // 4], [(i + n) // 4 > 0, tir.Add(i, 0) // 4 > 0], - [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, convert(0) <= tir.Add(i, 0) // 4)], + [(i + n) // 4 == 0, tir.all(tir.Add(i, 7) // 4 <= 0, T.int32(0) <= tir.Add(i, 0) // 4)], [i + n < 10, i + 7 < 10], [i - n < 10, tir.Sub(i, 0) < 10], [tir.Not(i < n), tir.Not(i < 7)], diff --git a/tests/python/arith/test_arith_rewrite_simplify.py b/tests/python/arith/test_arith_rewrite_simplify.py index 90f0aeef47d7..7fc1862192d6 100644 --- a/tests/python/arith/test_arith_rewrite_simplify.py +++ b/tests/python/arith/test_arith_rewrite_simplify.py @@ -27,6 +27,8 @@ from tvm.tir import truncdiv as tdiv from tvm.tir import truncmod as tmod +from tvm.script import tir as T + class TestCase: def __init__(self, before, expected, preconditions=None): @@ -35,10 +37,21 @@ def __init__(self, before, expected, preconditions=None): if isinstance(expected, tir.expr.EqualOp): expected = expected.asobject() - self.before = before - self.expected = expected + self.before = self._convert(before) + self.expected = self._convert(expected) self.preconditions = preconditions + @staticmethod + def _convert(expr): + if isinstance(expr, tir.expr.EqualOp): + return expr.asobject() + elif isinstance(expr, int): + return T.int32(expr) + elif isinstance(expr, float): + return T.float32(expr) + else: + return expr + @property def constraint(self): if self.preconditions is None: @@ -1008,8 +1021,8 @@ class TestComparisons(BaseCompare): TestCase(tir.all(fld(x, 8) == -3, flm(x, 8) == 4), x == -20), TestCase(tir.all(flm(x, 8) == 4, fld(x, 8) == -3), x == -20), # Rewrite based on definition of integer division - TestCase(tir.all(tvm.runtime.convert(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), - TestCase(tir.all(x - y * 5 < 5, tvm.runtime.convert(0) <= x - y * 5), y == fld(x, 5)), + TestCase(tir.all(T.int32(0) <= x - y * 5, x - y * 5 < 5), y == fld(x, 5)), + TestCase(tir.all(x - y * 5 < 5, T.int32(0) <= x - y * 5), y == fld(x, 5)), # Narrow upper bound using floormod TestCase(tir.all(x < 20, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), TestCase(tir.all(x < 18, flm(x, 5) < 2), tir.all(x < 17, flm(x, 5) < 2)), @@ -1025,36 +1038,36 @@ class TestComparisons(BaseCompare): # Merge a known floordiv and an upper bound of floormod into a value range TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) < 7), - tir.all(tvm.runtime.convert(50) <= x, x < 57), + tir.all(T.int32(50) <= x, x < 57), ), TestCase( tir.all(fld(x, 10) == 5, flm(x, 10) <= 7), - tir.all(tvm.runtime.convert(50) <= x, x <= 57), + tir.all(T.int32(50) <= x, x <= 57), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) < 7), - tir.all(tvm.runtime.convert(-50) <= x, x < -43), + tir.all(T.int32(-50) <= x, x < -43), ), TestCase( tir.all(fld(x, 10) == -5, flm(x, 10) <= 7), - tir.all(tvm.runtime.convert(-50) <= x, x <= -43), + tir.all(T.int32(-50) <= x, x <= -43), ), # Merge a known floordiv and an lower bound of floormod into a value range TestCase( - tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) < flm(x, 10)), - tir.all(tvm.runtime.convert(57) < x, x < 60), + tir.all(fld(x, 10) == 5, T.int32(7) < flm(x, 10)), + tir.all(T.int32(57) < x, x < 60), ), TestCase( - tir.all(fld(x, 10) == 5, tvm.runtime.convert(7) <= flm(x, 10)), - tir.all(tvm.runtime.convert(57) <= x, x < 60), + tir.all(fld(x, 10) == 5, T.int32(7) <= flm(x, 10)), + tir.all(T.int32(57) <= x, x < 60), ), TestCase( - tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) < flm(x, 10)), - tir.all(tvm.runtime.convert(-43) < x, x < -40), + tir.all(fld(x, 10) == -5, T.int32(7) < flm(x, 10)), + tir.all(T.int32(-43) < x, x < -40), ), TestCase( - tir.all(fld(x, 10) == -5, tvm.runtime.convert(7) <= flm(x, 10)), - tir.all(tvm.runtime.convert(-43) <= x, x < -40), + tir.all(fld(x, 10) == -5, T.int32(7) <= flm(x, 10)), + tir.all(T.int32(-43) <= x, x < -40), ), TestCase(tvm.te.min(x, 11) < 10, x < 10), TestCase(tvm.te.min(x, 8) < 10, tvm.tir.const(1, "bool")), @@ -1224,14 +1237,16 @@ class TestIfThenElse(BaseCompare): class TestCLZ(BaseCompare): test_case = tvm.testing.parameter( - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), 32), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), 31), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), 30), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), 24), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), 64), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), 63), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), 62), - TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), 56), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 0), T.int32(32)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 1), T.int32(31)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 2), T.int32(30)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", 128), T.int32(24)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 0)), T.int32(64)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 1)), T.int32(63)), + TestCase(tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 2)), T.int32(62)), + TestCase( + tvm.tir.call_intrin("int32", "tir.clz", tvm.tir.IntImm("int64", 128)), T.int32(56) + ), ) diff --git a/tests/python/arith/test_arith_solve_linear_equations.py b/tests/python/arith/test_arith_solve_linear_equations.py index 24eb860c55f6..3195a4ae514f 100644 --- a/tests/python/arith/test_arith_solve_linear_equations.py +++ b/tests/python/arith/test_arith_solve_linear_equations.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing +from tvm.script import tir as T def test_solution_consistency(): @@ -109,8 +110,8 @@ def test_unique_solution(): [x, y], ) assert list(solution.dst.variables) == [] - assert ir.structural_equal(solution.src_to_dst[x], 15) - assert ir.structural_equal(solution.src_to_dst[y], 5) + assert ir.structural_equal(solution.src_to_dst[x], T.int32(15)) + assert ir.structural_equal(solution.src_to_dst[y], T.int32(5)) def test_low_rank(): @@ -128,7 +129,7 @@ def test_low_rank(): [n0] = solution.dst.variables assert ir.structural_equal(solution.src_to_dst[x], n0 + 10) assert ir.structural_equal(solution.src_to_dst[y], -n0) - assert ir.structural_equal(solution.src_to_dst[z], 5) + assert ir.structural_equal(solution.src_to_dst[z], T.int32(5)) def test_infer_range(): @@ -149,12 +150,12 @@ def test_infer_range(): assert ir.structural_equal(solution.src_to_dst[x], n0) assert ir.structural_equal(solution.src_to_dst[y], -n0) # inferred from y's range - assert ir.structural_equal(solution.dst.ranges[n0].min, -9) - assert ir.structural_equal(solution.dst.ranges[n0].extent, 10) + assert ir.structural_equal(solution.dst.ranges[n0].min, T.int32(-9)) + assert ir.structural_equal(solution.dst.ranges[n0].extent, T.int32(10)) # additional inequality is added into the system for x [ineq] = solution.dst.relations assert isinstance(ineq, tvm.tir.LE) - assert ir.structural_equal(ineq.a, -5) + assert ir.structural_equal(ineq.a, T.int32(-5)) assert ir.structural_equal(ineq.b, n0) @@ -172,7 +173,7 @@ def test_ill_formed(): ) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir.structural_equal(rel, False) + ir.assert_structural_equal(rel, tir.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/arith/test_arith_solve_linear_inequality.py b/tests/python/arith/test_arith_solve_linear_inequality.py index 5285da12e75d..664258ae7cf1 100644 --- a/tests/python/arith/test_arith_solve_linear_inequality.py +++ b/tests/python/arith/test_arith_solve_linear_inequality.py @@ -19,6 +19,7 @@ import pytest import tvm from tvm import te, arith, ir, tir, testing +from tvm.script import tir as T @pytest.mark.skip(reason="See https://github.com/apache/tvm/issues/11458") @@ -113,10 +114,10 @@ def test_dual_variable(): [x_new, y_new] = solution.dst.variables [rel] = solution.dst.relations assert ir.structural_equal(rel, (y_new * 2) + x_new <= 10) - assert ir.structural_equal(solution.dst.ranges[x_new].min, 0) - assert ir.structural_equal(solution.dst.ranges[x_new].extent, 11) - assert ir.structural_equal(solution.dst.ranges[y_new].min, 0) - assert ir.structural_equal(solution.dst.ranges[y_new].extent, 6) + assert ir.structural_equal(solution.dst.ranges[x_new].min, T.int32(0)) + assert ir.structural_equal(solution.dst.ranges[x_new].extent, T.int32(11)) + assert ir.structural_equal(solution.dst.ranges[y_new].min, T.int32(0)) + assert ir.structural_equal(solution.dst.ranges[y_new].extent, T.int32(6)) assert ir.structural_equal(solution.src_to_dst[x], x_new + (y_new + 10)) assert ir.structural_equal(solution.src_to_dst[y], y_new) assert ir.structural_equal(solution.dst_to_src[x_new], x - y - 10) @@ -185,7 +186,7 @@ def test_no_solution(): solution = arith.solve_linear_inequalities(problem, [x], vranges, deskew_range=True) assert list(solution.dst.variables) == [] [rel] = solution.dst.relations - assert ir.structural_equal(rel, False) + ir.assert_structural_equal(rel, tir.const(False)) assert len(solution.src_to_dst) == 0 assert len(solution.dst_to_src) == 0 diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py index 112c521d06d4..112d1151febd 100644 --- a/tests/python/codegen/test_target_codegen_cuda.py +++ b/tests/python/codegen/test_target_codegen_cuda.py @@ -769,7 +769,7 @@ def check_cuda(dtype, n, l, padding, lanes): (n // lanes, l + 2 * padding, lanes), lambda i, j, k: tvm.te.if_then_else( tvm.te.any(j < padding, j >= l + padding), - tvm.runtime.convert(0).astype(dtype), + tvm.tir.const(0, dtype), A[i * lanes + k, j - padding], ), name="B", diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py index f50d63878e4f..d9a6fd6e62d1 100644 --- a/tests/python/codegen/test_target_codegen_llvm.py +++ b/tests/python/codegen/test_target_codegen_llvm.py @@ -1138,5 +1138,46 @@ def func(): tvm.build(func) +def test_int_parameter(): + """Boolean may be passed to functions accepting int""" + + @T.prim_func + def func(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg > 0: + return 10 + else: + return 20 + + built = tvm.build(func) + output = built(True) + assert output == 10 + + output = built(False) + assert output == 20 + + +def test_bool_parameter(): + """Integers may be passed to functions accepting bool""" + + @T.prim_func + def func(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm")}) + if arg: + return 10 + else: + return 20 + + built = tvm.build(func) + output = built(1) + assert output == 10 + + output = built(2) + assert output == 10 + + output = built(0) + assert output == 20 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/ir/test_container_structural_equal.py b/tests/python/ir/test_container_structural_equal.py index 61511c609ca4..238a77b4ef4b 100644 --- a/tests/python/ir/test_container_structural_equal.py +++ b/tests/python/ir/test_container_structural_equal.py @@ -56,20 +56,20 @@ def get_first_mismatch_ensure_symmetry(a, b): ( [1, 2, 3], [1, 4, 3], - ObjectPath.root().array_index(1).attr("value"), - ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1), ), ( [1, 2, 3], [10, 2, 30], - ObjectPath.root().array_index(0).attr("value"), - ObjectPath.root().array_index(0).attr("value"), + ObjectPath.root().array_index(0), + ObjectPath.root().array_index(0), ), ( [1, 3, 4], [1, 2, 3, 4], - ObjectPath.root().array_index(1).attr("value"), - ObjectPath.root().array_index(1).attr("value"), + ObjectPath.root().array_index(1), + ObjectPath.root().array_index(1), ), ( [1, 2, 3], @@ -121,14 +121,28 @@ def test_shape_tuple_structural_equal_to_self(contents): assert get_first_mismatch_ensure_symmetry(a, b) is None +@pytest.mark.parametrize( + "contents", + [ + {}, + {"a": 1, "b": 2}, + {"a": True, "b": False}, + ], +) +def test_string_map_structural_equal_to_self(contents): + a = tvm.runtime.convert({**contents}) + b = tvm.runtime.convert({**contents}) + assert get_first_mismatch_ensure_symmetry(a, b) is None + + @pytest.mark.parametrize( "a, b, expected_a_path, expected_b_path", [ ( dict(a=3, b=4), dict(a=3, b=5), - ObjectPath.root().map_value("b").attr("value"), - ObjectPath.root().map_value("b").attr("value"), + ObjectPath.root().map_value("b"), + ObjectPath.root().map_value("b"), ), ( dict(a=3, b=4), diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index aa482dd65cd7..1e3249197851 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -23,16 +23,19 @@ def test_array(): a = tvm.runtime.convert([1, 2, 3]) assert len(a) == 3 - assert a[-1].value == 3 + assert a[-1] == 3 a_slice = a[-3:-1] - assert (a_slice[0].value, a_slice[1].value) == (1, 2) + assert (a_slice[0], a_slice[1]) == (1, 2) def test_array_save_load_json(): - a = tvm.runtime.convert([1, 2, 3]) + a = tvm.runtime.convert([1, 2, 3.5, True]) json_str = tvm.ir.save_json(a) a_loaded = tvm.ir.load_json(json_str) - assert a_loaded[1].value == 2 + assert a_loaded[1] == 2 + assert a_loaded[2] == 3.5 + assert a_loaded[3] == True + assert isinstance(a_loaded[3], bool) def test_dir_array(): @@ -66,7 +69,7 @@ def test_str_map(): assert "a" in amap assert len(amap) == 2 dd = dict(amap.items()) - assert amap["a"].value == 2 + assert amap["a"] == 2 assert "a" in dd assert "b" in dd @@ -78,7 +81,7 @@ def test_map_save_load_json(): json_str = tvm.ir.save_json(amap) amap = tvm.ir.load_json(json_str) assert len(amap) == 2 - dd = {kv[0].name: kv[1].value for kv in amap.items()} + dd = {kv[0].name: kv[1] for kv in amap.items()} assert dd == {"a": 2, "b": 3} diff --git a/tests/python/ir/test_ir_type.py b/tests/python/ir/test_ir_type.py index 2355aa19adec..b70406c1bb7a 100644 --- a/tests/python/ir/test_ir_type.py +++ b/tests/python/ir/test_ir_type.py @@ -16,6 +16,7 @@ # under the License. """Test type nodes in the IR""" import tvm +from tvm.script import tir as T def check_json_roundtrip(node): @@ -38,11 +39,9 @@ def test_tensor_type_bad_constructor(): def test_tensor_type(): - shape = tvm.runtime.convert([1, 2, 3]) - dtype = "float32" - tt = tvm.ir.TensorType(shape, dtype) - assert tt.dtype == dtype - assert tt.shape == shape + tt = tvm.ir.TensorType([1, 2, 3], "float32") + assert tt.dtype == "float32" + assert list(tt.shape) == [T.int32(1), T.int32(2), T.int32(3)] assert tt.span == None str(tt) check_json_roundtrip(tt) diff --git a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py index f1709c449d16..b0ddbe93601e 100644 --- a/tests/python/relax/distributed/test_distributed_tvmscript_printer.py +++ b/tests/python/relax/distributed/test_distributed_tvmscript_printer.py @@ -40,7 +40,7 @@ def test_constant(): ) assert ( constant.__str__() - == """R.dist.const(1, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" + == """R.dist.const(1.0, R.DTensor((), "float32", R.device_mesh((2, 2), R.Range(0, 4)), "R, R"))""" ) @@ -144,7 +144,7 @@ def tir_func(x: T.Buffer((T.int64(128), T.int64(128)), "float32"), y: T.Buffer(( vi, vj = T.axis.remap("SS", [i, j]) T.reads(x[vi, vj]) T.writes(y[vi, vj]) - y[vi, vj] = x[vi, vj] + T.float32(1) + y[vi, vj] = x[vi, vj] + T.float32(1.0) @R.function def foo(x: R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R")) -> R.DTensor((128, 128), "float32", "mesh[0]", "S[0], R"): diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py index 97ad9f5dd034..64d5c7381171 100644 --- a/tests/python/relax/test_ast_printer.py +++ b/tests/python/relax/test_ast_printer.py @@ -404,7 +404,7 @@ def f( "op": 'ExternFunc(global_symbol="contrib.tensor_array_stack")', "args": '[Var(name_hint="x"), Var(name_hint="y")]', "sinfo_args": "[ObjectStructInfo()]", - "attrs": '{"test_attr": 1}', + "attrs": '{"test_attr": True}', }, extern_call_text, ) diff --git a/tests/python/relax/test_backend_dispatch_sort_scan.py b/tests/python/relax/test_backend_dispatch_sort_scan.py index 2ab5afaabf24..1efbd690f034 100644 --- a/tests/python/relax/test_backend_dispatch_sort_scan.py +++ b/tests/python/relax/test_backend_dispatch_sort_scan.py @@ -63,6 +63,13 @@ def foo(x: R.Tensor((2, 3), "float32", "llvm")): def test_dispatch_scanop_cuda(): + """R.cumsum and R.cumprod may be lowered with TOPI for GPU + + For the purpose of testing, this test case intentionally uses the + `exclusive=True` argument to prevent the `R.cumsum` from being + lowered to the packed func `"gpu_2d_continuous_cumsum"`. + """ + @I.ir_module class Before: I.module_global_infos({"vdevice": [I.vdevice("cuda", 0)]}) @@ -70,7 +77,7 @@ class Before: @R.function def main(x: R.Tensor(("m", 3), "float32", "cuda")): with R.dataflow(): - lv0 = R.cumsum(x, axis=1) + lv0 = R.cumsum(x, axis=1, exclusive=True) lv1 = R.cumprod(lv0, axis=1) gv = lv1 R.output(gv) @@ -89,6 +96,7 @@ def main(x: R.Tensor(("m", 3), "float32", "cuda")): topi.cuda.cumsum, x, axis=1, + exclusive=True, ) out = bb.emit_te( topi.cuda.cumprod, diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 7b64eb1dee39..e93547d83e3c 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -395,7 +395,7 @@ def test_call_tir_with_grad(): """ v0: R.Tensor((54, 96), dtype="float32") x = T.int64() -R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": T.float32(1), "x": x}) +R.call_tir_with_grad(tir_func, (v0,), out_sinfo=R.Tensor((54, 96), dtype="float32"), te_grad_name="grad_func", te_grad_kwargs={"k": 1.0, "x": x}) """, ) @@ -758,7 +758,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": True}) R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -770,7 +770,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - R.func_attr({"relax.force_pure": 1}) + R.func_attr({"relax.force_pure": True}) R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index ab40e181a35a..30fd06d4f14d 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -566,7 +566,7 @@ def main(shape: R.Prim(value="n")): assert func(2) == 4 - with pytest.raises(tvm.TVMError): + with pytest.raises(TypeError): func(ShapeTuple([2])) diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 9a4817f5fd8a..60f096585dfe 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -118,9 +118,10 @@ class Expected: @T.prim_func def __vmtir__ife(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): T.func_attr({"global_symbol": "__vmtir__ife"}) - if T.cast( - T.tvm_call_packed("vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))), + if T.Call( "bool", + tvm.ir.Op.get("tir.tvm_call_packed"), + ["vm.builtin.read_if_cond", T.anylist_getitem(r, T.int32(0))], ): T.anylist_setitem_call_packed( r, diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 4031790fc383..b79713e05ed3 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -18,6 +18,7 @@ import numpy as np import tvm +from tvm.script import tir as T from tvm import relay from tvm.relay.build_module import bind_params_by_name from tvm.relay.dataflow_pattern import * @@ -115,7 +116,7 @@ def test_DataTypePattern(): def test_ShapePattern(): - shape = [10, 10] + shape = [T.int32(10), T.int32(10)] pattern = has_shape(shape) assert isinstance(pattern, ShapePattern) tvm.ir.assert_structural_equal(pattern.shape, shape) diff --git a/tests/python/relay/test_executor.py b/tests/python/relay/test_executor.py index d703ef1f3d9a..04662f21ae9e 100644 --- a/tests/python/relay/test_executor.py +++ b/tests/python/relay/test_executor.py @@ -57,7 +57,7 @@ def test_create_executor_attr_type_incorrect(): with pytest.raises( TVMError, match='Attribute "interface-api" should have type "runtime.String"' - ' but instead found "IntImm"', + ' but instead found "runtime.BoxBool"', ): Executor("aot", {"interface-api": True}) diff --git a/tests/python/relay/test_runtime.py b/tests/python/relay/test_runtime.py index ea15dd0d3c88..db8252f3a3c4 100644 --- a/tests/python/relay/test_runtime.py +++ b/tests/python/relay/test_runtime.py @@ -51,7 +51,7 @@ def test_create_runtime_attr_not_found(): def test_create_runtime_attr_type_incorrect(): with pytest.raises( TVMError, - match='Attribute "system-lib" should have type "IntImm"' + match='Attribute "system-lib" should have type "runtime.BoxBool"' ' but instead found "runtime.String"', ): Runtime("crt", {"system-lib": "woof"}) @@ -65,7 +65,7 @@ def test_list_runtimes(): def test_list_runtime_options(runtime): aot_options = Runtime.list_registered_options(runtime) assert "system-lib" in aot_options - assert aot_options["system-lib"] == "IntImm" + assert aot_options["system-lib"] == "runtime.BoxBool" def test_list_runtime_options_not_found(): diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index f18994d52ce9..7d0cd51d3298 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -18,12 +18,13 @@ for expressions. """ import pytest +import numpy as np + import tvm -from tvm import IRModule, parser, relay, te -from tvm.relay import analysis, op, transform +from tvm import IRModule, relay +from tvm.relay import op, transform from tvm.relay.op import op as _op - -import numpy as np +from tvm.script import tir as T def infer_mod(mod, annotate_spans=True): @@ -554,40 +555,32 @@ def test_repeat_register(): assert "Operator custom_log3 is registered before" in str(cm.execption) -def test_argreduce_infer_return_type(): +@pytest.mark.parametrize("relay_op", [relay.op.argmax, relay.op.argmin]) +@pytest.mark.parametrize( + "shape_dtype", + [ + ("int32", T.int32), + ("int64", T.int64), + ], + ids=["int32", "int64"], +) +def test_argreduce_infer_return_type(relay_op, shape_dtype): x_shape = (1, 1) broadcast_shape = [1, 1] - shape_dtypes = [("int32", lambda x: np.int32(x)), ("int64", lambda x: np.int64(x))] - - # Testing with argmax - for (sdtype, conv) in shape_dtypes: - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmax = relay.op.argmax(broadcast_to, axis=[1]) - - f = relay.Function([x], argmax) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) - - # Testing with argmin - for (sdtype, conv) in shape_dtypes: - x = relay.var("data", relay.TensorType(x_shape, "float32")) - broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) - argmin = relay.op.argmin(broadcast_to, axis=[1]) - - f = relay.Function([x], argmin) - assert_has_type( - f, - relay.FuncType( - [relay.TensorType(broadcast_shape, "float32")], - relay.TensorType([conv(1)], dtype=sdtype), - ), - ) + (sdtype, conv) = shape_dtype + + x = relay.var("data", relay.TensorType(x_shape, "float32")) + broadcast_to = relay.op.broadcast_to(x, relay.const(broadcast_shape, dtype=sdtype)) + argmax = relay_op(broadcast_to, axis=[1]) + + f = relay.Function([x], argmax) + assert_has_type( + f, + relay.FuncType( + [relay.TensorType(broadcast_shape, "float32")], + relay.TensorType([conv(1)], dtype=sdtype), + ), + ) if __name__ == "__main__": diff --git a/tests/python/runtime/test_runtime_container.py b/tests/python/runtime/test_runtime_container.py index 7538075ae7f8..e0d216b33e9a 100644 --- a/tests/python/runtime/test_runtime_container.py +++ b/tests/python/runtime/test_runtime_container.py @@ -15,12 +15,13 @@ # specific language governing permissions and limitations # under the License. -import numpy as np +import pickle import random + +import numpy as np + import tvm import tvm.testing -import pickle -from tvm import te from tvm import nd, relay from tvm.runtime import container as _container @@ -96,8 +97,123 @@ def test_shape_tuple(): assert stuple == z +def test_bool_argument(): + """Boolean objects are currently stored as int""" + func = tvm.get_global_func("testing.AcceptsBool") + + assert isinstance(func(True), bool) + assert isinstance(func(1), bool) + assert isinstance(func(0), bool) + + +def test_int_argument(): + func = tvm.get_global_func("testing.AcceptsInt") + + assert isinstance(func(True), int) + assert isinstance(func(1), int) + assert isinstance(func(0), int) + + +def test_object_ref_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRef") + + assert isinstance(func(True), bool) + assert isinstance(func(1), int) + assert isinstance(func(3.5), float) + assert func(3.5) == 3.5 + + +def test_object_ref_array_argument(): + func = tvm.get_global_func("testing.AcceptsObjectRefArray") + + assert isinstance(func([True, 17, "hello"]), bool) + assert isinstance(func([True]), bool) + assert isinstance(func([17]), int) + assert isinstance(func(["hello"]), str) + + +def test_map_argument_returns_value(): + func = tvm.get_global_func("testing.AcceptsMapReturnsValue") + + res = func({"a": 1, "b": 2}, "a") + assert isinstance(res, int) + assert res == 1 + + res = func({"a": True, "b": False}, "a") + assert isinstance(res, bool) + assert res == True + + +def test_map_argument_returns_map(): + func = tvm.get_global_func("testing.AcceptsMapReturnsMap") + + res = func({"a": 1, "b": 2}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, int) + + res = func({"a": False, "b": True}) + for key, value in res.items(): + assert isinstance(key, str) + assert isinstance(value, bool) + + +def test_conversion_of_arg(): + """Arguments may be converted + + The calling side of the FFI converts to types that are available + at runtime. However, there may be additional type conversions + required, that must be performed on the callee-side of the FFI. + """ + + func = tvm.get_global_func("testing.AcceptsPrimExpr") + + res = func(1) + assert isinstance(res, tvm.tir.IntImm) + assert res.dtype == "int32" + + res = func(True) + assert isinstance(res, tvm.tir.IntImm) + assert res.dtype == "bool" + + +def test_conversion_of_array_elements(): + """Elements of an array may require conversion from FFI to param type + + Like `test_conversion_of_arg`, but conversions must be applied + recursively to array elements. Here, the Python-side of the FFI + converts the array `[1,2]` to `Array{runtime::Int(1), + runtime::Int(2)}`, and the C++ side of the FFI converts to + `Array{IntImm(1), IntImm(2)}`. + """ + + func = tvm.get_global_func("testing.AcceptsArrayOfPrimExpr") + + res = func([1, False]) + assert isinstance(res[0], tvm.tir.IntImm) + assert res[0].dtype == "int32" + assert isinstance(res[1], tvm.tir.IntImm) + assert res[1].dtype == "bool" + + +def test_conversion_of_map_values(): + """Elements of a map may require conversion from FFI to param type + + Like `test_conversion_of_arg`, but conversions must be applied + recursively to map elements. Here, the Python-side of the FFI + converts the map `{'a':1, 'b':2}` to `Map{{"a", runtime::Int(1)}, + {"b", runtime::Int(2)}}`, and the C++ side of the FFI converts to + `Map{{"a", IntImm(1)}, {"b", IntImm(2)}}`. + """ + + func = tvm.get_global_func("testing.AcceptsMapOfPrimExpr") + + res = func({"a": 1, "b": False}) + assert isinstance(res["a"], tvm.tir.IntImm) + assert res["a"].dtype == "int32" + assert isinstance(res["b"], tvm.tir.IntImm) + assert res["b"].dtype == "bool" + + if __name__ == "__main__": - test_string() - test_adt_constructor() - test_tuple_object() - test_shape_tuple() + tvm.testing.main() diff --git a/tests/python/te/test_te_schedule_tensorize.py b/tests/python/te/test_te_schedule_tensorize.py index 79aecb78902a..419d3edb5c3d 100644 --- a/tests/python/te/test_te_schedule_tensorize.py +++ b/tests/python/te/test_te_schedule_tensorize.py @@ -16,6 +16,7 @@ # under the License. import tvm from tvm import te +from tvm.script import tir as T def intrin_vadd(xo, m, n): @@ -100,6 +101,7 @@ def add(m): def check(m, factor): x, y, z = add(m) + factor = T.int32(factor) s = te.create_schedule(z.op) xo, xi = s[z].split(z.op.axis[0], factor=factor) vadd = intrin_vadd(xo, m, factor) @@ -133,7 +135,7 @@ def check_cache_write(m, factor): finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[z_global], dom_map) # outer loop var will be rebased, so min value is the new loop var and extent is 1 - tvm.ir.assert_structural_equal(out_dom[xo].extent, 1) + tvm.ir.assert_structural_equal(out_dom[xo].extent, T.int32(1)) assert isinstance(out_dom[xo].min, tvm.tir.Var) assert xo.var.name == out_dom[xo].min.name @@ -183,7 +185,7 @@ def check(factor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -207,7 +209,7 @@ def check_rfactor(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -230,7 +232,7 @@ def check_rfactor_no_reset(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -254,7 +256,7 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): dom_map = tvm.te.schedule.InferBound(s) finfer = tvm.get_global_func("test.op.InferTensorizeRegion") out_dom, in_dom = finfer(s[C], dom_map) - tvm.ir.assert_structural_equal(out_dom[x].extent, 1) + tvm.ir.assert_structural_equal(out_dom[x].extent, T.int32(1)) tvm.ir.assert_structural_equal(out_dom[y].extent, factor) tvm.ir.assert_structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") @@ -264,10 +266,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) - check(16) - check_rfactor(16, 16) - check_rfactor_no_reset(16, 16) - check_rfactor_no_reset_multi_reduction(16, 16) + check(T.int32(16)) + check_rfactor(T.int32(16), T.int32(16)) + check_rfactor_no_reset(T.int32(16), T.int32(16)) + check_rfactor_no_reset_multi_reduction(T.int32(16), T.int32(16)) # This tests whether algorithm and intrinsics expressions are simplified diff --git a/tests/python/te/test_te_tag.py b/tests/python/te/test_te_tag.py index 6e88a12614cf..a4b76e7d6736 100644 --- a/tests/python/te/test_te_tag.py +++ b/tests/python/te/test_te_tag.py @@ -57,12 +57,12 @@ def test_with(): assert C.op.tag == "gemm" assert "hello" in C.op.attrs assert "xx" not in C.op.attrs - assert C.op.attrs["hello"].value == 1 + assert C.op.attrs["hello"] == 1 CC = tvm.ir.load_json(tvm.ir.save_json(C)) - assert CC.op.attrs["hello"].value == 1 - assert CC.op.attrs["arr"][0].value == 10 - # str format happened to be json compatible - assert json.loads(str(CC.op.attrs))["arr"][1] == 12 + assert CC.op.attrs["hello"] == 1 + assert len(CC.op.attrs["arr"]) == 2 + assert CC.op.attrs["arr"][0] == 10 + assert CC.op.attrs["arr"][1] == 12 def test_decorator(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index e94a4f09ec56..0e610cc1659b 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py @@ -122,7 +122,7 @@ def test_lower_build_tir_func(): def test_lower_build_tir_module(): func = matmul.with_attr("global_symbol", "main") - func = func.with_attr("tir.noalias", True) + func = func.with_attr("tir.noalias", T.bool(True)) ir_mod = IRModule({"main": func}) # check lowering with the CSE pass disabled as otherwise it would do some commoning with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]): diff --git a/tests/python/tir-base/test_tir_buffer.py b/tests/python/tir-base/test_tir_buffer.py index b4b773197b14..d706e65d8186 100644 --- a/tests/python/tir-base/test_tir_buffer.py +++ b/tests/python/tir-base/test_tir_buffer.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest + import tvm import tvm.testing from tvm import te from tvm.tir import Buffer +from tvm.script import tir as T + import numpy as np +import pytest def test_buffer(): @@ -78,9 +81,9 @@ def test_buffer_access_ptr_extent(): # Test extent from input params aptr = Ab.access_ptr("rw", extent=200) - tvm.ir.assert_structural_equal(aptr.args[3], 200) + tvm.ir.assert_structural_equal(aptr.args[3], T.int32(200)) aptr = Ab.access_ptr("rw", offset=100, extent=100) - tvm.ir.assert_structural_equal(aptr.args[3], 100) + tvm.ir.assert_structural_equal(aptr.args[3], T.int32(100)) def test_buffer_vload(): @@ -88,7 +91,7 @@ def test_buffer_vload(): n = te.size_var("n") Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - tvm.ir.assert_structural_equal(load.indices, [2, 3]) + tvm.ir.assert_structural_equal(load.indices, [T.int32(2), T.int32(3)]) def test_buffer_offset_of(): @@ -259,7 +262,7 @@ def test_buffer_flatten(): buf = tvm.tir.decl_buffer([16, 32]) flat = buf.get_flattened_buffer() assert buf.data.same_as(flat.data) - tvm.ir.assert_structural_equal(flat.shape, [16 * 32]) + tvm.ir.assert_structural_equal(flat.shape, [T.int32(16 * 32)]) def test_buffer_flatten_preserves_identity(): @@ -273,8 +276,8 @@ def test_buffer_flatten_uses_axis_separators(): """Flattening to N-d physical buffers uses the axis separators""" buf = tvm.tir.decl_buffer([4, 16, 32], axis_separators=[2]) flat = buf.get_flattened_buffer() - tvm.ir.assert_structural_equal(flat.axis_separators, [1]) - tvm.ir.assert_structural_equal(flat.shape, [4 * 16, 32]) + tvm.ir.assert_structural_equal(flat.axis_separators, [T.int32(1)]) + tvm.ir.assert_structural_equal(flat.shape, [T.int32(4 * 16), T.int32(32)]) def test_invalid_axis_separators_raises_exception(): diff --git a/tests/python/tir-base/test_tir_index_map.py b/tests/python/tir-base/test_tir_index_map.py index e893ed897d65..3ddbd2f69f59 100644 --- a/tests/python/tir-base/test_tir_index_map.py +++ b/tests/python/tir-base/test_tir_index_map.py @@ -22,6 +22,7 @@ from tvm.ir import assert_structural_equal from tvm.runtime import const from tvm.tir import IndexMap, IntImm, floordiv, floormod +from tvm.script import tir as T def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: @@ -37,28 +38,22 @@ def assert_equal_index_map(map1: IndexMap, map2: IndexMap) -> None: def test_index_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_indices([0]), [0, 0]) - assert_structural_equal(index_map.map_indices([3]), [0, 3]) - assert_structural_equal(index_map.map_indices([4]), [1, 0]) - assert_structural_equal(index_map.map_indices([42]), [10, 2]) - assert_structural_equal( - index_map.map_indices([const(42, "int64")]), [const(10, "int64"), const(2, "int64")] - ) + assert_structural_equal(index_map.map_indices([0]), [T.int32(0), T.int32(0)]) + assert_structural_equal(index_map.map_indices([3]), [T.int32(0), T.int32(3)]) + assert_structural_equal(index_map.map_indices([4]), [T.int32(1), T.int32(0)]) + assert_structural_equal(index_map.map_indices([42]), [T.int32(10), T.int32(2)]) + assert_structural_equal(index_map.map_indices([T.int64(42)]), [T.int64(10), T.int64(2)]) def test_shape_mapping(): index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") - assert_structural_equal(index_map.map_shape([4]), [1, 4]) - assert_structural_equal(index_map.map_shape([16]), [4, 4]) + assert_structural_equal(index_map.map_shape([4]), [T.int32(1), T.int32(4)]) + assert_structural_equal(index_map.map_shape([16]), [T.int32(4), T.int32(4)]) - assert_structural_equal(index_map.map_shape([14]), [4, 4]) - assert_structural_equal( - index_map.map_shape([const(16, "int64")]), [const(4, "int64"), const(4, "int64")] - ) - assert_structural_equal( - index_map.map_shape([const(14, "int64")]), [const(4, "int64"), const(4, "int64")] - ) + assert_structural_equal(index_map.map_shape([14]), [T.int32(4), T.int32(4)]) + assert_structural_equal(index_map.map_shape([T.int64(16)]), [T.int64(4), T.int64(4)]) + assert_structural_equal(index_map.map_shape([T.int64(14)]), [T.int64(4), T.int64(4)]) def test_inverse(): @@ -82,28 +77,28 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[16], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.runtime.convert(False), ), "right_padding": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[15], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), ), "left_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[15], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.And(i == 0, j < 1), ), "left_and_right_padding": dict( forward=lambda i: [(i + 1) // 4, (i + 1) % 4], inverse=lambda i, j: [4 * i + j - 1], pre_shape=[14], - post_shape=[4, 4], + post_shape=[T.int32(4), T.int32(4)], padding=lambda i, j: tvm.tir.Or( tvm.tir.And(i == 0, j < 1), tvm.tir.And(i == 3, tvm.runtime.convert(3) == j), @@ -113,7 +108,7 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [4 * i + j], pre_shape=[dynamic_N], - post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, 4], + post_shape=[(dynamic_N - dynamic_N % (-4)) // 4, T.int32(4)], padding=lambda i, j: tvm.tir.And( dynamic_N % (-4) != 0, tvm.tir.And(i == dynamic_N // 4, j >= dynamic_N % 4), @@ -127,10 +122,10 @@ def test_nonbijective_inverse_gives_error(): ], pre_shape=[14, 31], post_shape=[ - 4, # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 - 5, # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 - 4, # Range of iter%4 - 8, # Range of iter%8 + T.int32(4), # ceildiv(left_pad + i.extent, 4) = ceildiv(1 + 14, 4) = 4 + T.int32(5), # ceildiv(left_pad + j.extent, 8) = ceildiv(5 + 31, 8) = 5 + T.int32(4), # Range of iter%4 + T.int32(8), # Range of iter%8 ], padding=lambda i_outer, j_outer, i_inner, j_inner: tvm.tir.Or( tvm.tir.Or( @@ -147,35 +142,35 @@ def test_nonbijective_inverse_gives_error(): forward=lambda i: [i // 32, (i // 4) % 8, i % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[4, 8, 4], + post_shape=[T.int32(4), T.int32(8), T.int32(4)], padding=lambda i, j, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_right_padding_transpose": dict( forward=lambda i: [(i // 4) % 8, i // 32, i % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k], pre_shape=[116], - post_shape=[8, 4, 4], + post_shape=[T.int32(8), T.int32(4), T.int32(4)], padding=lambda j, i, k: tvm.tir.And(i == 3, 4 * j + k >= 20), ), "multiple_left_padding": dict( forward=lambda i: [(i + 5) // 32, ((i + 5) // 4) % 8, (i + 5) % 4], inverse=lambda i, j, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[4, 8, 4], + post_shape=[T.int32(4), T.int32(8), T.int32(4)], padding=lambda i, j, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "multiple_left_padding_with_transpose": dict( forward=lambda i: [((i + 5) // 4) % 8, (i + 5) // 32, (i + 5) % 4], inverse=lambda j, i, k: [32 * i + 4 * j + k - 5], pre_shape=[123], - post_shape=[8, 4, 4], + post_shape=[T.int32(8), T.int32(4), T.int32(4)], padding=lambda j, i, k: tvm.tir.And(i == 0, j * 4 + k < 5), ), "outer_loop_extent_one": dict( forward=lambda i: [i // 4, i % 4], inverse=lambda i, j: [i * 4 + j], pre_shape=[3], - post_shape=[1, 4], + post_shape=[T.int32(1), T.int32(4)], padding=lambda i, j: tvm.runtime.convert(3) == j, ), } diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index eeedae1f127c..29efd95280be 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -32,7 +32,7 @@ def test_te_const(): assert isinstance(x, tvm.tir.IntImm) -def test_scalar_dtype_inference(): +def test_tir_const_dtype_inference(): for data in [ True, bool(1), @@ -49,28 +49,11 @@ def test_scalar_dtype_inference(): np.float64(1), ]: assert tvm.tir.const(data).dtype == str(np.array(data).dtype) + + assert tvm.tir.const(True).dtype == "bool" assert tvm.tir.const(1).dtype == "int32" assert tvm.tir.const(1.0).dtype == "float32" - for data in [ - True, - bool(1), - np.uint8(1), - np.uint16(1), - np.uint32(1), - np.uint64(1), - np.int8(1), - np.int16(1), - np.int32(1), - np.int64(1), - np.float16(1), - np.float32(1), - np.float64(1), - ]: - assert tvm.runtime.convert(data).dtype == str(np.array(data).dtype) - assert tvm.runtime.convert(1).dtype == "int32" - assert tvm.runtime.convert(1.0).dtype == "float32" - def test_make(): x = tvm.tir.const(1, "int32") @@ -133,7 +116,7 @@ def test_attr(): assert stmt.node == y a = tvm.runtime.convert(1) - assert a.value == 1 + assert a == 1 try: a.no_field assert False @@ -350,7 +333,7 @@ def test_prim_func(): assert len(func.buffer_map) == 1 f2 = func.with_attr({"calling_conv": 1, "tir.noalias": True}) - assert f2.attrs["calling_conv"].value == 1 + assert f2.attrs["calling_conv"] == 1 assert not func.attrs diff --git a/tests/python/tir-schedule/test_tir_schedule_sampling.py b/tests/python/tir-schedule/test_tir_schedule_sampling.py index c2f3f89e6e12..8ae576e9b922 100644 --- a/tests/python/tir-schedule/test_tir_schedule_sampling.py +++ b/tests/python/tir-schedule/test_tir_schedule_sampling.py @@ -146,7 +146,7 @@ def test_sample_categorical_serialize(): decisions.append(rv) new_sch = verify_trace_roundtrip(sch, mod=elementwise) for i, new_inst in enumerate(new_sch.trace.insts): - assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst]] def test_sample_perfect_tile_power_of_two(): diff --git a/tests/python/tir-schedule/test_tir_schedule_state.py b/tests/python/tir-schedule/test_tir_schedule_state.py index 74880e5a42d9..c023b9dbc59d 100644 --- a/tests/python/tir-schedule/test_tir_schedule_state.py +++ b/tests/python/tir-schedule/test_tir_schedule_state.py @@ -155,10 +155,10 @@ def test_replace_direct_write0(): old_hash = s.mod["main"].__hash__() sref = s.get_sref(s.mod["main"].body.block.body[1]) s.replace(sref, target) - # There is no other reference so the AST node can be written directly - assert old_hash == s.mod["main"].__hash__() # Check the replaced part is equal to the target tvm.ir.assert_structural_equal(s.mod["main"].body.block.body[1], target) + # There is no other reference so the AST node can be written directly + assert old_hash == s.mod["main"].__hash__() # The target reuse the stmt of the sref, so the sref won't be None assert sref.stmt is not None diff --git a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py index d5d5e0634ef6..cb7151f875e3 100644 --- a/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py +++ b/tests/python/tir-transform/test_tir_transform_compact_buffer_region.py @@ -1029,38 +1029,45 @@ class TestTileAwareCompaction(BaseCompactTest): # it is not an opaque block case intentionally is_lower_order_free = False - @T.prim_func - def before( - A: T.Buffer((128, 128), "float32"), - B: T.Buffer((128, 128), "float32"), - C: T.Buffer((128, 128), "float32"), - ): - for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): - A_local = T.decl_buffer((26, 128), scope="local") - B_local = T.decl_buffer((128, 26), scope="local") - C_local = T.decl_buffer((26, 26), scope="local") - for ax0, ax1 in T.grid(26, 128): - if i_0 * 26 + ax0 < 128: - A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] - for ax0, ax1 in T.grid(128, 26): - if j_0 * 26 + ax1 < 128: - B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] - for i_1, j_1, k in T.grid(26, 26, 128): - if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: - if k == 0: - C_local[i_1, j_1] = T.float32(0) - C_local[i_1, j_1] = C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] - for ax0, ax1 in T.grid(26, 26): - if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: - C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] - - # Get partitioned workload to compact - before_mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) - with tvm.transform.PassContext(config={"tir.LoopPartition": {"partition_const_loop": True}}): - before_mod = tvm.tir.transform.LowerOpaqueBlock()(before_mod) - before_mod = tvm.tir.transform.LoopPartition()(before_mod) - before = before_mod["main"] + @property + def before(self): + @T.prim_func + def main( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + C: T.Buffer((128, 128), "float32"), + ): + for i_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + for j_0 in range(5, annotations={"pragma_loop_partition_hint": 1}): + A_local = T.decl_buffer((26, 128), scope="local") + B_local = T.decl_buffer((128, 26), scope="local") + C_local = T.decl_buffer((26, 26), scope="local") + for ax0, ax1 in T.grid(26, 128): + if i_0 * 26 + ax0 < 128: + A_local[ax0, ax1] = A[i_0 * 26 + ax0, ax1] + for ax0, ax1 in T.grid(128, 26): + if j_0 * 26 + ax1 < 128: + B_local[ax0, ax1] = B[ax0, j_0 * 26 + ax1] + for i_1, j_1, k in T.grid(26, 26, 128): + if i_0 * 26 + i_1 < 128 and j_0 * 26 + j_1 < 128: + if k == 0: + C_local[i_1, j_1] = T.float32(0) + C_local[i_1, j_1] = ( + C_local[i_1, j_1] + A_local[i_1, k] * B_local[k, j_1] + ) + for ax0, ax1 in T.grid(26, 26): + if i_0 * 26 + ax0 < 128 and j_0 * 26 + ax1 < 128: + C[i_0 * 26 + ax0, j_0 * 26 + ax1] = C_local[ax0, ax1] + + # Get partitioned workload to compact + mod = tvm.IRModule.from_expr(main) + with tvm.transform.PassContext( + config={"tir.LoopPartition": {"partition_const_loop": True}} + ): + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + mod = tvm.tir.transform.LoopPartition()(mod) + + return mod["main"] @T.prim_func def expected( diff --git a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py index 9f61b5a3920a..3078572bb508 100644 --- a/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py +++ b/tests/python/tir-transform/test_tir_transform_instrument_bound_checkers.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import pytest + import tvm import tvm.testing -from tvm import te +from tvm import te, tir + +import pytest import numpy as np @@ -184,7 +186,7 @@ def collect_branch_stmt(x): if isinstance(x, tvm.tir.IfThenElse): branch_collector.append(x) - n = 21 + n = tir.const(21) A = te.placeholder((n,), name="A") B = te.placeholder((n,), name="B") diff --git a/tests/python/tir-transform/test_tir_transform_make_packed_api.py b/tests/python/tir-transform/test_tir_transform_make_packed_api.py index 23a51a0817df..0b43db56f300 100644 --- a/tests/python/tir-transform/test_tir_transform_make_packed_api.py +++ b/tests/python/tir-transform/test_tir_transform_make_packed_api.py @@ -394,5 +394,144 @@ def func_without_arg( tvm.ir.assert_structural_equal(Expected, After) +def test_int_parameter(): + """Boolean may be passed to functions accepting int + + A PackedFunc produced by compiling an IRModule should support the + same type conversions as the C++ implementation. When a function + accepts an integer argument, the caller may call it with a boolean + value. + + This also provides backwards compatibility for functions that were + defined as accepting an integer, but are called with a boolean + argument. Prior to PackedFunc interface supporting boolean + arguments directly, the argument would be converted from boolean + to integer to be stored in a TVMValue. After adding support for + boolean arguments, this usage should not cause an error. + + """ + + @I.ir_module + class Before: + @T.prim_func + def main(arg: T.int32) -> T.int32: + T.func_attr({"target": T.target("llvm", host="llvm")}) + if arg > 0: + return 10 + else: + return 20 + + @I.ir_module + class Expected: + @T.prim_func + def main( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 1, "main: num_args should be 1" + assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" + assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" + arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) + arg_code: T.int32 = arg_type_ids_1[0] + assert arg_code == 0 or arg_code == 15, "main: Expect arg[0] to be int" + arg: T.int32 = T.if_then_else( + arg_code == 0, + T.Cast("int32", T.tvm_struct_get(args, 0, 12, "int64")), + T.Cast("int32", T.tvm_struct_get(args, 0, 12, "bool")), + ) + with T.attr(0, "compute_scope", "main_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + if arg > 0: + out_ret_value_1[0] = T.Cast("int64", 10) + out_ret_tcode_1[0] = 0 + return 0 + else: + out_ret_value_1[0] = T.Cast("int64", 20) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + +def test_bool_parameter(): + """An integer may be passed to a function acccepting Boolean + + A PackedFunc produced by compiling an IRModule should support the + same type conversions as the C++ implementation. When a function + accepts a boolean argument, the caller may call it with an integer + value. + + """ + + @I.ir_module + class Before: + @T.prim_func + def main(arg: T.bool) -> T.int32: + T.func_attr({"target": T.target("llvm", host="llvm")}) + if arg: + return 10 + else: + return 20 + + @I.ir_module + class Expected: + @T.prim_func + def main( + args: T.handle, + arg_type_ids: T.handle("int32"), + num_args: T.int32, + out_ret_value: T.handle("void"), + out_ret_tcode: T.handle("int32"), + resource_handle: T.handle, + ) -> T.int32: + T.func_attr( + { + "calling_conv": 1, + "target": T.target("llvm"), + } + ) + assert num_args == 1, "main: num_args should be 1" + assert not T.isnullptr(args), "main: TVMValue* arg pointer was NULL" + assert not T.isnullptr(arg_type_ids), "main: int* type_codes was NULL" + arg_type_ids_1 = T.decl_buffer((1,), "int32", data=arg_type_ids) + arg_code: T.int32 = arg_type_ids_1[0] + assert arg_code == 15 or arg_code == 0, "main: Expect arg[0] to be boolean" + arg: T.bool = T.if_then_else( + arg_code == 15, + T.tvm_struct_get(args, 0, 12, "bool"), + T.Cast("bool", T.tvm_struct_get(args, 0, 12, "int64")), + ) + with T.attr(0, "compute_scope", "main_compute_"): + out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,)) + out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,)) + if arg: + out_ret_value_1[0] = T.Cast("int64", 10) + out_ret_tcode_1[0] = 0 + return 0 + else: + out_ret_value_1[0] = T.Cast("int64", 20) + out_ret_tcode_1[0] = 0 + return 0 + return 0 + + After = tvm.tir.transform.MakePackedAPI()(Before) + + tvm.ir.assert_structural_equal(Expected, After) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 4b71eb825414..68149e7d64bb 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -937,8 +937,8 @@ def test_vulkan_smem_reuse(): "kind": "vulkan", "max_num_threads": 256, "max_threads_per_block": 256, - "supports_float32": T.bool(True), - "supports_int32": T.bool(True), + "supports_float32": True, + "supports_int32": True, "tag": "", "thread_warp_size": 1, } diff --git a/tests/python/tvmscript/test_tvmscript_error_report.py b/tests/python/tvmscript/test_tvmscript_error_report.py index 279785fdca51..d8212d38854c 100644 --- a/tests/python/tvmscript/test_tvmscript_error_report.py +++ b/tests/python/tvmscript/test_tvmscript_error_report.py @@ -332,26 +332,35 @@ def convert_slice_to_bufferload() -> None: check_error(convert_slice_to_bufferload, 6) -def test_tvm_exception_catch(): +def test_tvm_exception_catch_from_special_stmt(): def special_stmt_except() -> None: A = T.alloc_buffer("(128, 128)", "float32") # error T.evaluate(1.0) + check_error(special_stmt_except, 2) + + +def test_tvm_exception_catch_from_scope_handler(): def scope_handler_except() -> None: for i in T.serial("1", "1"): # error T.evaluate(1) + check_error(scope_handler_except, 2) + + +def test_tvm_exception_catch_from_bare_intrin(): def intrin_except_unassign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") T.evaluate(A) # error + check_error(intrin_except_unassign, 3) + + +def test_tvm_exception_catch_from_assigned_intrin(): def intrin_except_assign(a: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") A[0, 0] = A[A] # error - check_error(special_stmt_except, 2) - check_error(scope_handler_except, 2) - check_error(intrin_except_unassign, 3) check_error(intrin_except_assign, 3) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index 8364e65a4178..b7ba57fa9387 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -230,7 +230,7 @@ def test_buffer_store(): obj, """ A = T.Buffer((128, 128), "float16") -A[128, 128] = A[128, 128] + T.float16(1) +A[128, 128] = A[128, 128] + T.float16(1.0) """, ) @@ -259,7 +259,7 @@ def test_let_stmt(): _assert_print( obj, """ -with T.LetStmt(T.float32(10)) as v: +with T.LetStmt(T.float32(10.0)) as v: T.evaluate(0) """, ) @@ -672,7 +672,7 @@ def test_call(): _assert_print( obj, """ -T.atan(T.float32(1)) +T.atan(T.float32(1.0)) """, ) @@ -682,7 +682,7 @@ def test_comm_reducer(): _assert_print( obj, """ -T.comm_reducer(lambda x, y: x + y, [T.float32(0)]) +T.comm_reducer(lambda x, y: x + y, [T.float32(0.0)]) """, ) @@ -712,7 +712,7 @@ def test_float_imm(): _assert_print( obj, """ -T.float16(1) +T.float16(1.0) """, ) @@ -942,7 +942,7 @@ def func(): @T.prim_func def func(): - T.evaluate(T.{dtype}(0)) + T.evaluate(T.{dtype}(0.0)) """ func = get_func(dtype) _assert_print(func, expected_output) diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index f81a80de6d61..b44ff5ad7241 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2689,14 +2689,14 @@ def test_match_buffer_region(): outer_block = root.body.body.body.block assert len(outer_block.match_buffers) == 1 buffer_C = outer_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_C.shape, [16, 1, 4]) + tvm.ir.assert_structural_equal(buffer_C.shape, [T.int32(16), T.int32(1), T.int32(4)]) assert isinstance(outer_block.body, tir.stmt.For) assert isinstance(outer_block.body.body, tir.stmt.BlockRealize) inner_block = outer_block.body.body.block assert len(inner_block.match_buffers) == 1 buffer_D = inner_block.match_buffers[0].buffer - tvm.ir.assert_structural_equal(buffer_D.shape, [4, 1, 4]) + tvm.ir.assert_structural_equal(buffer_D.shape, [T.int32(4), T.int32(1), T.int32(4)]) def block_elements(): @@ -3981,6 +3981,32 @@ def func() -> T.int32: return func +def func_attr_with_list(): + @T.prim_func + def func( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + D: T.Buffer((128, 128), "float32"), + ) -> None: + T.func_attr( + {"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [T.int32(1)]} + ) + C = T.alloc_buffer([128, 128], dtype="float32") + for i0, i1, i2 in T.grid(128, 128, 128): + with T.block("C"): + x, y, k = T.axis.remap("SSR", [i0, i1, i2]) + with T.init(): + C[x, y] = T.float32(0) + C[x, y] = C[x, y] + A[x, k] * B[y, k] + for i0, i1 in T.grid(128, 128): + with T.block("D"): + T.block_attr({"layout_free_placeholders": [C]}) + x, y = T.axis.remap("SS", [i0, i1]) + D[x, y] = C[x, y] + T.float32(1) + + return func + + def op_of_literal(): op_list = [ (T.exp, 0), @@ -4198,6 +4224,7 @@ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")): return_zero, return_zero_private, return_zero_private_with_attr, + func_attr_with_list, *op_of_literal(), *relax_match_cast_struct_info_proxy(), relax_symbolic_size_var, diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 9bc9800c1cb8..ae83a9d66392 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -19,6 +19,7 @@ import tvm from tvm import te from tvm.topi import utils +from tvm.script import tir as T from .environment import get_env @@ -1046,19 +1047,19 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(dst_coeff) > 1 assert len(extents) != 0 tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 + analyzer.simplify(idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) ) tvm.ir.assert_structural_equal( - analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0 + analyzer.simplify(idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), T.int32(0) ) - tvm.ir.assert_structural_equal(src_coeff[-2], 1) - tvm.ir.assert_structural_equal(dst_coeff[-2], 1) + tvm.ir.assert_structural_equal(src_coeff[-2], T.int32(1)) + tvm.ir.assert_structural_equal(dst_coeff[-2], T.int32(1)) if env.BATCH > 1: assert len(src_coeff) > 2 assert len(dst_coeff) > 2 assert len(extents) > 1 - tvm.ir.assert_structural_equal(src_coeff[-3], env.BLOCK_OUT) - tvm.ir.assert_structural_equal(dst_coeff[-3], env.BLOCK_OUT) + tvm.ir.assert_structural_equal(src_coeff[-3], T.int32(env.BLOCK_OUT)) + tvm.ir.assert_structural_equal(dst_coeff[-3], T.int32(env.BLOCK_OUT)) # Apply tensorization of the loop coefficients src_offset = src_coeff[-1]