diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 2fb11f5a6f83..e7d13c47b2bd 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -35,7 +35,7 @@ namespace transform {
  *
  * \return The Pass.
  */
-TVM_DLL Pass VMBuiltinLower();
+TVM_DLL Pass LowerRuntimeBuiltin();
 
 /*!
  * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index b44c4582d82d..291bee597c03 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -79,6 +79,15 @@
  */
 using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call call)>;
 
+/*! \brief The function type of a function to lower a runtime builtin.
+ *
+ * An operator with this attribute may be lowered to its runtime builtin form
+ * by the `LowerRuntimeBuiltin` pass.
+ *
+ * \param bb The BlockBuilder context.
+ * \param call The call to be lowered.
+ */
+using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;
+
 /*!
  * \brief Gradient for a specific op.
  *
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 14b2b84b0d36..c33606d98ed3 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI {
     return device_type != kDLCPU && device_type != kDLMicroDev;
   }
 
+  /*!
+   * \brief Whether pointer arithmetic on a device-owned pointer may be performed on the host.
+   */
+  virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; }
+
 protected:
  /*!
   * \brief copy data from one place to another
diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py
index 422c5d2e1f53..1191550085de 100644
--- a/python/tvm/relax/op/memory/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
 """Relax memory primitives."""
 from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
-from .view import view
+from .view import view, ensure_zero_offset
diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py
index 0c3d8a03b2dd..95adc782092f 100644
--- a/python/tvm/relax/op/memory/view.py
+++ b/python/tvm/relax/op/memory/view.py
@@ -92,3 +92,20 @@ def _normalize(expr, relax_cls):
     relative_byte_offset = _normalize(relative_byte_offset, PrimValue)
 
     return _ffi_api.view(data, shape, dtype, relative_byte_offset)  # type: ignore
+
+
+def ensure_zero_offset(data: Expr) -> Expr:
+    """
+    Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor
+
+    Returns
+    -------
+    result : relax.Expr
+        The tensor with elem_offset == 0
+    """
+    return _ffi_api.ensure_zero_offset(data)  # type: ignore
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index d068f800d0e9..38242ff4d2d3 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 transform.RewriteCUDAGraph(),
                 transform.LowerAllocTensor(),
                 transform.KillAfterLastUse(),
-                transform.VMBuiltinLower(),
+                transform.LowerRuntimeBuiltin(),
                 transform.ComputePrimValue(),
                 transform.VMShapeLower(),
                 transform.AttachGlobalSymbol(),
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5789e2fcf235..1ce864651cd9 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -55,6 +55,7 @@
     LegalizeOps,
     LiftTransformParams,
     LowerAllocTensor,
+    LowerRuntimeBuiltin,
     MergeCompositeFunctions,
     MetaScheduleApplyDatabase,
     MetaScheduleTuneIRMod,
@@ -64,8 +65,8 @@
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
-    RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    RemoveUnusedParameters,
     ReorderPermuteDimsAfterConcat,
     ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
@@ -84,14 +85,14 @@
     function_pass,
 )
 
+from .attach_external_modules import AttachExternModules
+from .fast_math import FastMathTransform
+from .fuse_transpose_matmul import FuseTransposeMatmul
 from .ipc_allreduce_rewrite import IPCAllReduceRewrite
 from .lazy_transform_params import LazyTransformParams
 from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
 from .optimize_layout_transform import OptimizeLayoutTransform
 from .remove_redundant_reshape import RemoveRedundantReshape
-from .fast_math import FastMathTransform
-from .fuse_transpose_matmul import FuseTransposeMatmul
-from .attach_external_modules import AttachExternModules
 
 # Import to register the legalization functions.
 from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 3528b4429e6f..2546284625e9 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,6 +19,7 @@
 import functools
 import inspect
 import types
+import warnings
 from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np  # type: ignore
@@ -586,6 +587,16 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
     return _ffi_api.ComputePrimValue()  # type: ignore
 
 
+def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
+    """Lower generic intrinsics to VM intrinsics.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    return _ffi_api.LowerRuntimeBuiltin()  # type: ignore
+
+
 def VMBuiltinLower() -> tvm.ir.transform.Pass:
     """Lowering generic intrinsic to VM intrinsics.
 
     Returns
@@ -593,7 +604,11 @@ def VMBuiltinLower() -> tvm.ir.transform.Pass:
     -------
     ret: tvm.ir.transform.Pass
     """
-    return _ffi_api.VMBuiltinLower()  # type: ignore
+    warnings.warn(
+        "tvm.relax.transform.VMBuiltinLower has been renamed to 'LowerRuntimeBuiltin'. "
+        "This wrapper is for backwards compatibility, and will be removed in a later update."
+    )
+    return _ffi_api.LowerRuntimeBuiltin()  # type: ignore
 
 
 def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
similarity index 90%
rename from src/relax/backend/vm/vm_builtin_lower.cc
rename to src/relax/backend/vm/lower_runtime_builtin.cc
index 887998d004c7..a3867ae92448 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -17,13 +17,14 @@
  * under the License.
  */
 /*!
- * \file src/relax/backend/vm/vm_builtin_lower.cc
+ * \file src/relax/backend/vm/lower_runtime_builtin.cc
  * \brief Lowers most builtin functions and packed calls.
  */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/op.h>
 #include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/type.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/tir/op.h>
@@ -33,11 +34,12 @@ namespace relax {
 
 // This pass lowers most ops to VM specific builtins.
 // TODO(relax-team): revisit after PrimValue.
-class VMBuiltinLowerMutator : public ExprMutator {
+class LowerRuntimeBuiltinMutator : public ExprMutator {
  public:
  using ExprMutator::VisitExpr_;
 
   Expr VisitExpr_(const CallNode* call_node) final {
+    static const auto& lower_builtin_fmap = Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
     // post-order mutation
     Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
 
@@ -64,9 +66,13 @@
       return MakeMemAllocTensor(call);
     } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
       return MakeMemKillObject(call);
-    } else {
-      return call;
+    } else if (const auto* op_node = call->op.as<OpNode>()) {
+      Op op = GetRef<Op>(op_node);
+      if (lower_builtin_fmap.count(op)) {
+        return lower_builtin_fmap[op](builder_, call);
+      }
     }
+    return call;
   }
 
   Expr MakeMemAllocStorage(const Call& call) {
@@ -210,17 +216,19 @@
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
 };
 
-Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
+Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }
 
 namespace transform {
 
-Pass VMBuiltinLower() {
+Pass LowerRuntimeBuiltin() {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); };
-  return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast<Function>(LowerRuntimeBuiltin(f));
+      };
+  return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
 }
 
-TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
+TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);
 
 }  // namespace transform
 }  // namespace relax
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index e7634c7edfce..21a72f6200b0 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -291,7 +291,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
 
 TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);
 
-Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
+Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
   Expr data = call->args[0];
   Expr shape = call->args[1];
   Expr dtype = call->args[2];
@@ -352,8 +352,37 @@
                   "The view's byte offset, relative to the input tensor's byte offset.")
     .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
-    .set_attr<FLegalize>("FLegalize", LegalizeView)
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);
+
+Expr ensure_zero_offset(const Expr& x) {
+  static const Op& op = Op::Get("relax.memory.ensure_zero_offset");
+  return Call(op, {x});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset);
+
+StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 1) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Operator " << call->op << " should receive 1 argument, "
+                     << "but received " << call->args);
+  }
+  return GetStructInfo(call->args[0]);
+}
+
+Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) {
+  const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"};
+  return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetStructInfo(call)});
+}
+
+TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor.")
+    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureZeroOffset)
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/memory/view.h b/src/relax/op/memory/view.h
index bc8002fa5b69..77ec7e9833cc 100644
--- a/src/relax/op/memory/view.h
+++ b/src/relax/op/memory/view.h
@@ -32,6 +32,9 @@ namespace relax {
 /*! \brief View a tensor with different properties. */
 Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset);
 
+/*! \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary. */
+Expr ensure_zero_offset(const Expr& x);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 2b16d8650906..74200526b699 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -286,8 +286,13 @@ class TokenAllocator1D {
  std::vector<StorageToken> full_pool_;
 };
 
-/*! \brief Check if the input op is "relax.reshape". */
-bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
+/*! \brief Check if the input op is a memory op that may return the same buffer. */
+bool IsInplaceMemoryOp(const Expr& op) {
+  static const Op& reshape_op = Op::Get("relax.reshape");
+  static const Op& view_op = Op::Get("relax.memory.view");
+  static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset");
+  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_zero_offset_op);
+}
 
 /*! \brief The base class for the storage allocation visitor. */
 class StorageAllocatorBaseVisitor : public ExprVisitor {
@@ -498,7 +503,7 @@
       // Create a storage token for builtin alloc_tensor.
       this->CreateToken(call);
       return;
-    } else if (IsReshape(call->op)) {
-      // Reuse the input's token for builtin reshape.
+    } else if (IsInplaceMemoryOp(call->op)) {
+      // Reuse the input's token for memory ops that may return the same underlying buffer.
       SetTokens(call, GetTokens(call->args[0]));
       return;
@@ -751,7 +756,7 @@
         block_tokens.push_back(new_token.get());
       }
       return;
-    } else if (IsReshape(call->op)) {
+    } else if (IsInplaceMemoryOp(call->op)) {
       Tokens tokens = GetTokens(call->args[0]);
       ICHECK(!tokens.IsNested());
       if (tokens.IsLeaf()) {
diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc
index 774335f5660b..ccd726a6ece6 100644
--- a/src/runtime/cpu_device_api.cc
+++ b/src/runtime/cpu_device_api.cc
@@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI {
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(Device dev, void* data) final;
 
+  bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
+
   static CPUDeviceAPI* Global() {
     // NOTE: explicitly use new to avoid exit-time destruction of global state
     // Global state will be recycled by OS as the process exits.
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 66357a191541..33908d750d6d 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI {
     CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
   }
 
+  bool SupportsDevicePointerArithmeticsOnHost() final { return true; }
+
   static CUDADeviceAPI* Global() {
     // NOTE: explicitly use new to avoid exit-time destruction of global state
     // Global state will be recycled by OS as the process exits.
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index af1cf9d20335..9fe6fba80f5c 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -551,6 +551,25 @@
   return ShapeTuple(out_shape);
 });
 
+TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) {
+  if (data->byte_offset == 0) {
+    return data;
+  }
+  auto* device_api = DeviceAPI::Get(data->device);
+  if (device_api->SupportsDevicePointerArithmeticsOnHost() &&
+      data->byte_offset % tvm::runtime::kAllocAlignment == 0) {
+    DLManagedTensor* dl_tensor = data.ToDLPack();
+    dl_tensor->dl_tensor.data =
+        reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
+    dl_tensor->dl_tensor.byte_offset = 0;
+    return NDArray::FromDLPack(dl_tensor);
+  } else {
+    auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device);
+    new_array.CopyFrom(data);
+    return new_array;
+  }
+});
+
 }  // namespace relax_vm
 }  // namespace runtime
 }  // namespace tvm
diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py
index 2433821c2abd..0900e1be306b 100644
--- a/tests/python/relax/test_op_view.py
+++ b/tests/python/relax/test_op_view.py
@@ -452,7 +452,9 @@ def inferred_sinfo(A: R.Tensor, relative_byte_offset: R.Prim("int64")):
     tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo)
 
 
-def test_legalize_without_any_changes_is_no_op():
+def test_legalize_is_no_op():
+    """R.memory.view is not legalized until LowerRuntimeBuiltin"""
+
     @I.ir_module
     class Before:
         @R.function
@@ -460,18 +462,13 @@ def main(A: R.Tensor([4096], "float32")):
             B = R.memory.view(A)
             return B
 
-    @I.ir_module
-    class Expected:
-        @R.function
-        def main(A: R.Tensor([4096], "float32")):
-            B = A
-            return B
+    Expected = Before
 
     After = tvm.relax.transform.LegalizeOps()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_shape_change():
+def test_lower_runtime_builtin_shape_change():
     @I.ir_module
     class Before:
         @R.function
@@ -497,11 +494,11 @@ def main(A: R.Tensor([4096], "float32")):
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_view_shape_from_unknown():
+def test_lower_runtime_builtin_view_shape_from_unknown():
     """R.memory.view does not require the input tensor to have a known shape"""
 
     @I.ir_module
@@ -529,11 +526,11 @@ def main(A: R.Tensor(dtype="float32")):
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_dtype_change():
+def test_lower_runtime_builtin_dtype_change():
     @I.ir_module
     class Before:
         @R.function
@@ -559,11 +556,11 @@ def main(A: R.Tensor([4096], "float32")):
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_byte_offset():
+def test_lower_runtime_builtin_byte_offset():
    @I.ir_module
    class Before:
        @R.function
@@ -589,11 +586,11 @@ def main(A: R.Tensor([4096], "float32")):
             )
             return B
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-def test_legalize_view_with_multiple_updated_fields():
+def test_lower_runtime_builtin_view_with_multiple_updated_fields():
     """R.memory.view may update more than one field in the view
 
     In this test case, a 4-kilobyte buffer is provided. The first
@@ -650,7 +647,7 @@ def main(A: R.Tensor([4096], "uint8")):
             )
             return (B, C)
 
-    After = tvm.relax.transform.LegalizeOps()(Before)
+    After = tvm.relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 63f422d4cfbe..f9e632d34897 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -185,7 +185,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
     tvm.ir.assert_structural_equal(mod, Expected)
     mod = relax.transform.LowerAllocTensor()(mod)
     mod = relax.transform.KillAfterLastUse()(mod)
-    mod = relax.transform.VMBuiltinLower()(mod)
+    mod = relax.transform.LowerRuntimeBuiltin()(mod)
     tvm.ir.assert_structural_equal(mod, ExpectedLowered)
 
 
@@ -1449,5 +1449,60 @@
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_view():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main():
+            cls = Before
+            x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0)
+            x1 = R.memory.view(x, [128], "float32", 0)
+            x2 = R.memory.ensure_zero_offset(x1)
+            y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
+            cls.tir_exp(x2, y)
+            z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
+            cls.tir_exp(y, z)
+            return z
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main() -> R.Tensor((128,), dtype="float32"):
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([1024]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            x: R.Tensor((16, 16), dtype="float32") = R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([16, 16]), R.dtype("float32")
+            )
+            x1: R.Tensor((128,), dtype="float32") = R.memory.view(
+                x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
+            )
+            x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_zero_offset(x1)
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
+            )
+            y: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([128]), R.dtype("float32")
+            )
+            cls.tir_exp(x2, y)
+            z: R.Tensor((128,), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([128]), R.dtype("float32"), R.prim_value(0), R.str("global")
+            )
+            cls.tir_exp(y, z)
+            return z
+
+    after = relax.transform.StaticPlanBlockMemory()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_vm_builtin_lower.py b/tests/python/relax/test_vm_builtin_lower.py
index df28db4d46d2..984f9f958ca2 100644
--- a/tests/python/relax/test_vm_builtin_lower.py
+++ b/tests/python/relax/test_vm_builtin_lower.py
@@ -57,7 +57,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
             gv0 = alloc
         return gv0
 
-    After = relax.transform.VMBuiltinLower()(Before)
+    After = relax.transform.LowerRuntimeBuiltin()(Before)
     tvm.ir.assert_structural_equal(Expected, After)
 
 
@@ -79,7 +79,7 @@ def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor:
         return gv0
 
     with pytest.raises(tvm.TVMError):
-        relax.transform.VMBuiltinLower()(Before)
+        relax.transform.LowerRuntimeBuiltin()(Before)
 
 
 if __name__ == "__main__":
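
For reviewers, a minimal usage sketch of the renamed pass may help; the `Module` below is hypothetical, but the op and pass names are the ones introduced in this patch. `R.memory.view` and `R.memory.ensure_zero_offset` survive `LegalizeOps` and are only rewritten into `vm.builtin.*` extern calls by `LowerRuntimeBuiltin`, via their registered `FLowerBuiltin` attributes; the old `VMBuiltinLower` entry point still resolves to the same pass after emitting a warning.

```python
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:  # hypothetical example module
    @R.function
    def main(A: R.Tensor([4096], "float32")):
        # Reinterpret A as a 64x64 view, then force elem_offset == 0.
        B = R.memory.view(A, [64, 64], "float32", 0)
        C = R.memory.ensure_zero_offset(B)
        return C


# New entry point: rewrites the ops above into vm.builtin.* extern calls.
lowered = relax.transform.LowerRuntimeBuiltin()(Module)

# Backwards-compatible alias: warns, then dispatches to the same pass.
also_lowered = relax.transform.VMBuiltinLower()(Module)
```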
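At runtime, `vm.builtin.ensure_zero_offset` is a no-op for arrays whose `byte_offset` is already zero; otherwise it either re-points the data pointer in place (on devices whose `DeviceAPI` reports `SupportsDevicePointerArithmeticsOnHost`, i.e. CPU and CUDA in this patch, and only for `kAllocAlignment`-aligned offsets) or falls back to `NDArray::Empty` plus `CopyFrom`. A sketch of exercising the builtin directly through the global function registry, under the assumption that it is callable outside the VM (the array contents are illustrative):

```python
import numpy as np
import tvm

# The builtin is registered via TVM_REGISTER_GLOBAL, so it is reachable
# from Python through the global function registry.
ensure_zero_offset = tvm.get_global_func("vm.builtin.ensure_zero_offset")

arr = tvm.nd.array(np.arange(16, dtype="float32"))  # fresh array: byte_offset == 0
out = ensure_zero_offset(arr)  # fast path: returned unchanged, no copy

np.testing.assert_array_equal(out.numpy(), arr.numpy())
```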