diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
index 2fb11f5a6f83b..e7d13c47b2bd4 100644
--- a/include/tvm/relax/backend.h
+++ b/include/tvm/relax/backend.h
@@ -35,7 +35,7 @@ namespace transform {
  *
  * \return The Pass.
  */
-TVM_DLL Pass VMBuiltinLower();
+TVM_DLL Pass LowerRuntimeBuiltin();
 
 /*!
  * \brief Lower the shape expression in relax to VM shape heap and TIR functions.
diff --git a/include/tvm/relax/op_attr_types.h b/include/tvm/relax/op_attr_types.h
index b44c4582d82d3..c644e208f916a 100644
--- a/include/tvm/relax/op_attr_types.h
+++ b/include/tvm/relax/op_attr_types.h
@@ -79,6 +79,15 @@
 using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call call)>;
 
+/*! \brief The function type of a function to lower a runtime builtin.
+ *
+ * A builtin function may be lowered to its runtime form by `LowerRuntimeBuiltin`.
+ *
+ * \param bb The BlockBuilder context.
+ * \param call The call to be lowered.
+ */
+using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;
+
 /*!
  * \brief Gradient for a specific op.
  *
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 14b2b84b0d366..0072981be513e 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI {
     return device_type != kDLCPU && device_type != kDLMicroDev;
   }
 
+  /*! \brief Whether pointer arithmetic on a device data pointer yields a usable pointer. */
+  static bool SupportsPointerArithmetics(int device_type) {
+    return device_type != kDLVulkan;
+  }
+
  protected:
   /*!
    * \brief copy data from one place to another
diff --git a/python/tvm/relax/op/memory/__init__.py b/python/tvm/relax/op/memory/__init__.py
index 2ae1b676e035c..1191550085de8 100644
--- a/python/tvm/relax/op/memory/__init__.py
+++ b/python/tvm/relax/op/memory/__init__.py
@@ -17,4 +17,4 @@
 """Relax memory primitives."""
 from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
-from .view import view, ensure_aligned
+from .view import view, ensure_zero_offset
diff --git a/python/tvm/relax/op/memory/view.py b/python/tvm/relax/op/memory/view.py
index 233d07f6c9b71..95adc782092f3 100644
--- a/python/tvm/relax/op/memory/view.py
+++ b/python/tvm/relax/op/memory/view.py
@@ -94,7 +94,7 @@ def _normalize(expr, relax_cls):
     return _ffi_api.view(data, shape, dtype, relative_byte_offset)  # type: ignore
 
 
-def ensure_aligned(data: Expr) -> Expr:
+def ensure_zero_offset(data: Expr) -> Expr:
     """
     Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
 
@@ -106,6 +106,6 @@ def ensure_aligned(data: Expr) -> Expr:
     Returns
     -------
     result : relax.Expr
-        The aligned tensor
+        The tensor with elem_offset == 0
     """
-    return _ffi_api.ensure_aligned(data)  # type: ignore
+    return _ffi_api.ensure_zero_offset(data)  # type: ignore
diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py
index d068f800d0e9b..38242ff4d2d3d 100644
--- a/python/tvm/relax/pipeline.py
+++ b/python/tvm/relax/pipeline.py
@@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
             transform.RewriteCUDAGraph(),
             transform.LowerAllocTensor(),
             transform.KillAfterLastUse(),
-            transform.VMBuiltinLower(),
+            transform.LowerRuntimeBuiltin(),
             transform.ComputePrimValue(),
             transform.VMShapeLower(),
             transform.AttachGlobalSymbol(),
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index 5e76fff6bd1e6..eef6d331375c4 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -16,6 +16,15 @@
 # under the License.
 """Relax transformations. """
 
+# Import to register the legalization functions.
+from . import legalize_ops, tuning_api
+from .attach_external_modules import AttachExternModules
+from .fast_math import FastMathTransform
+from .ipc_allreduce_rewrite import IPCAllReduceRewrite
+from .lazy_transform_params import LazyTransformParams
+from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
+from .optimize_layout_transform import OptimizeLayoutTransform
+from .remove_redundant_reshape import RemoveRedundantReshape
 from .transform import (
     AdjustMatmulOrder,
     AllocateWorkspace,
@@ -55,6 +64,7 @@
     LegalizeOps,
     LiftTransformParams,
     LowerAllocTensor,
+    LowerRuntimeBuiltin,
     MergeCompositeFunctions,
     MetaScheduleApplyDatabase,
     MetaScheduleTuneIRMod,
@@ -64,8 +74,8 @@
     PatternCheckContext,
     RealizeVDevice,
     RemovePurityChecking,
-    RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    RemoveUnusedParameters,
     ReorderPermuteDimsAfterConcat,
     ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
@@ -78,19 +88,7 @@
     TopologicalSort,
     UpdateParamStructInfo,
     UpdateVDevice,
-    VMBuiltinLower,
     VMShapeLower,
     dataflowblock_pass,
     function_pass,
 )
-
-from .ipc_allreduce_rewrite import IPCAllReduceRewrite
-from .lazy_transform_params import LazyTransformParams
-from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
-from .optimize_layout_transform import OptimizeLayoutTransform
-from .remove_redundant_reshape import RemoveRedundantReshape
-from .fast_math import FastMathTransform
-from .attach_external_modules import AttachExternModules
-
-# Import to register the legalization functions.
-from . import legalize_ops, tuning_api
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 3528b4429e6fc..e017bc113b2cb 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -586,14 +586,14 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
     return _ffi_api.ComputePrimValue()  # type: ignore
 
 
-def VMBuiltinLower() -> tvm.ir.transform.Pass:
+def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
    """Lowering generic intrinsics to VM intrinsics.

    Returns
    -------
    ret: tvm.ir.transform.Pass
    """
-    return _ffi_api.VMBuiltinLower()  # type: ignore
+    return _ffi_api.LowerRuntimeBuiltin()  # type: ignore
 
 
 def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/lower_runtime_builtin.cc
similarity index 86%
rename from src/relax/backend/vm/vm_builtin_lower.cc
rename to src/relax/backend/vm/lower_runtime_builtin.cc
index 961aa9b600f8b..7fff6c95329dc 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/lower_runtime_builtin.cc
@@ -17,13 +17,14 @@
  * under the License.
  */
 /*!
- * \file src/relax/backend/vm/vm_builtin_lower.cc
+ * \file src/relax/backend/vm/lower_runtime_builtin.cc
 * \brief Lowers most builtin functions and packed calls.
 */
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/attrs/memory.h>
 #include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/relax/op_attr_types.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/type.h>
 #include <tvm/runtime/data_type.h>
 
@@ -33,11 +34,12 @@ namespace relax {
 
 // This pass lowers most ops to VM specific builtins.
 // TODO(relax-team): revisit after PrimValue.
-class VMBuiltinLowerMutator : public ExprMutator {
+class LowerRuntimeBuiltinMutator : public ExprMutator {
  public:
  using ExprMutator::VisitExpr_;
 
  Expr VisitExpr_(const CallNode* call_node) final {
+    static const auto& lower_builtin_fmap = Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
    // post-order mutation
    Call call = Downcast<Call>(VisitExprPostOrder_(call_node));
 
@@ -47,10 +49,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
      return Reshape(call);
    } else if (call->op == shape_of_op_) {
      return ShapeOf(call);
-    } else if (call->op == view_op_) {
-      return View(call);
-    } else if (call->op == ensure_aligned_op_) {
-      return EnsureAligned(call);
    } else if (call->op == to_vdevice_op_) {
      return ToDevice(call);
    } else if (call->op == make_closure_op_) {
@@ -68,9 +66,13 @@
      return MakeMemAllocTensor(call);
    } else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
      return MakeMemKillObject(call);
-    } else {
-      return call;
+    } else if (const auto* op_node = call->op.as<OpNode>()) {
+      Op op = GetRef<Op>(op_node);
+      if (lower_builtin_fmap.count(op)) {
+        return lower_builtin_fmap[op](builder_, call);
+      }
    }
+    return call;
  }
 
  Expr MakeMemAllocStorage(const Call& call) {
@@ -128,19 +130,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
    }
  }
 
-  Expr View(const Call& view_node) {
-    StructInfoDeriveFunc infer_sinfo_env_func;
-    infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
-    auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);
-    ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
-    return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo});
-  }
-
-  Expr EnsureAligned(const Call& call_node) {
-    ICHECK(call_node->args.size() == 1);
-    return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)});
-  }
-
  Expr ShapeOf(const Call& call_node) {
    ICHECK(call_node->args.size() == 1);
    ICHECK(call_node->struct_info_.defined());
@@ -205,8 +194,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
  const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
  const Op& reshape_op_ = Op::Get("relax.reshape");
  const Op& shape_of_op_ = Op::Get("relax.shape_of");
-  const Op& view_op_ = Op::Get("relax.memory.view");
-  const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned");
  const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
  const Op& make_closure_op_ = Op::Get("relax.make_closure");
  const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
@@ -227,20 +214,20 @@ class VMBuiltinLowerMutator : public ExprMutator {
  const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
  const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
  const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
-  const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"};
+
 };
 
-Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
+Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }
 
 namespace transform {
 
-Pass VMBuiltinLower() {
+Pass LowerRuntimeBuiltin() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); };
-  return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(LowerRuntimeBuiltin(f)); };
+  return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
 }
 
-TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
+TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);
 
 }  // namespace transform
 }  // namespace relax
diff --git a/src/relax/op/memory/view.cc b/src/relax/op/memory/view.cc
index d43cc01838aea..b582748e64b5f 100644
--- a/src/relax/op/memory/view.cc
+++ b/src/relax/op/memory/view.cc
@@ -58,9 +58,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
    if (auto opt = sinfo.as<TensorStructInfo>()) {
      return opt.value();
    } else {
-      LOG(FATAL) << "TypeError: "
-                 << "Operator " << call->op << " expects first argument to be a tensor, "
-                 << "but received " << arg_data << " with type " << sinfo;
+      LOG(FATAL) << "TypeError: " << "Operator " << call->op
+                 << " expects first argument to be a tensor, " << "but received " << arg_data
+                 << " with type " << sinfo;
    }
  }();
  auto view_shape_sinfo = [&]() -> const ShapeStructInfoNode* {
@@ -73,10 +73,10 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
      // The `R.view` operation returns a different shape.
      return ptr;
    } else {
-      LOG(FATAL) << "TypeError: "
-                 << "Operator " << call->op << " expects second argument to be a ShapeExpr, "
-                 << "or a void-type (empty relax tuple), "
-                 << "but received " << arg_shape << " with type " << sinfo;
+      LOG(FATAL) << "TypeError: " << "Operator " << call->op
+                 << " expects second argument to be a ShapeExpr, "
+                 << "or a void-type (empty relax tuple), " << "but received " << arg_shape
+                 << " with type " << sinfo;
    }
  }();
 
@@ -111,10 +111,9 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
      // being changed into.
      return DataType::Void();
    } else {
-      LOG(FATAL) << "TypeError: "
-                 << "Operator " << call->op
-                 << " expects the dtype argument to be a relax::DataTypeImm, "
-                 << "but received " << arg_dtype << " with type " << sinfo;
+      LOG(FATAL) << "TypeError: " << "Operator " << call->op
+                 << " expects the dtype argument to be a relax::DataTypeImm, " << "but received "
+                 << arg_dtype << " with type " << sinfo;
    }
  }();
 
@@ -126,8 +125,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
      return IntImm(DataType::Int(64), 0);
    } else if (auto prim_sinfo = sinfo.as<PrimStructInfoNode>()) {
      CHECK_EQ(prim_sinfo->dtype, DataType::Int(64))
-          << "TypeError: "
-          << "Operator " << call->op
+          << "TypeError: " << "Operator " << call->op
          << " expects the relative_byte_offset to be a 64-bit integer, but received "
          << arg_relative_byte_offset << ", which has type " << sinfo;
      if (prim_sinfo->value.defined()) {
@@ -139,9 +137,8 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
        return NullOpt;
      }
    } else {
-      LOG(FATAL) << "TypeError: "
-                 << "Operator " << call->op << " expects the relative_byte_offset argument "
-                 << "to be a Relax PrimValue. "
+      LOG(FATAL) << "TypeError: " << "Operator " << call->op
+                 << " expects the relative_byte_offset argument " << "to be a Relax PrimValue. "
                 << "However, expression " << call << " provides relative_byte_offset of "
                 << arg_relative_byte_offset << ", which has type " << sinfo;
    }
@@ -246,8 +243,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {
      // view to be larger than the original array.
      CHECK_GE(input_element_size.value()->value, output_element_size.value()->value)
-          << "ValueError: "
-          << "Operator " << call->op
+          << "ValueError: " << "Operator " << call->op
          << " may not produce a view that exceeds the bounds of the original array. "
" << "In expression " << call << " the data type is changed from " << data_sinfo->dtype << " to " << view_dtype.value() << ", increasing the size per element from " @@ -313,9 +309,9 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { CHECK(data_shape.defined()) << "Legalization of " << call->op << " requires that either the output shape be explicitly specified, " - << "or the input shape is known. " - << "However, in expression " << call << ", no output shape is specified, " - << "and the input " << data << " of type " << data->struct_info_ << " has unknown shape."; + << "or the input shape is known. " << "However, in expression " << call + << ", no output shape is specified, " << "and the input " << data << " of type " + << data->struct_info_ << " has unknown shape."; shape = ShapeExpr(data_shape.value()); } @@ -324,9 +320,9 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { CHECK(!data_dtype.is_void()) << "Legalization of " << call->op << " requires that either the output dtype be explicitly specified, " - << "or the input dtype is known. " - << "However, in expression " << call << ", no output dtype is specified, " - << "and the input " << data << " of type " << data->struct_info_ << " has unknown dtype."; + << "or the input dtype is known. " << "However, in expression " << call + << ", no output dtype is specified, " << "and the input " << data << " of type " + << data->struct_info_ << " has unknown dtype."; dtype = relax::DataTypeImm(data_dtype); } @@ -342,6 +338,14 @@ Expr LegalizeView(const BlockBuilder& bb, const Call& call) { return Call(call->op, {data, shape, dtype, relative_byte_offset}); } +Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) { + StructInfoDeriveFunc infer_sinfo_env_func; + infer_sinfo_env_func= EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo"); + auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true); + ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo); + return Call(runtime_view_func, call->args, call->attrs, {runtime_view_sinfo}); +} + TVM_REGISTER_OP("relax.memory.view") .set_num_inputs(4) .add_argument("x", "Tensor", "The input tensor.") @@ -352,30 +356,37 @@ TVM_REGISTER_OP("relax.memory.view") .set_attr("RequiresArgumentShapes", Bool(false)) .set_attr("FInferStructInfo", InferStructInfoView) .set_attr("FLegalize", LegalizeView) - .set_attr("FPurity", Bool(true)); + .set_attr("FPurity", Bool(true)) + .set_attr("FLowerBuiltin", LowerBuiltinView); -Expr ensure_aligned(const Expr& x) { - static const Op& op = Op::Get("relax.memory.ensure_aligned"); +Expr ensure_zero_offset(const Expr& x) { + static const Op& op = Op::Get("relax.memory.ensure_zero_offset"); return Call(op, {x}); } -TVM_REGISTER_GLOBAL("relax.op.memory.ensure_aligned").set_body_typed(ensure_aligned); +TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset); -StructInfo InferStructInfoEnsureAligned(const Call& call, const BlockBuilder& ctx) { +StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) { if (call->args.size() != 1) { ctx->ReportFatal(Diagnostic::Error(call) - << "Operator " << call->op << " should receive 1 argument, " - << "but received " << call->args); + << "Operator " << call->op << " should receive 1 argument, " << "but received " + << call->args); } return GetStructInfo(call->args[0]); } -TVM_REGISTER_OP("relax.memory.ensure_aligned") +Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) { 
+  const ExternFunc builtin_ensure_zero_offset_{"vm.builtin.ensure_zero_offset"};
+  return Call(builtin_ensure_zero_offset_, call->args, Attrs(), {GetStructInfo(call)});
+}
+
+TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
    .set_num_inputs(1)
    .add_argument("x", "Tensor", "The input tensor.")
    .set_attr<Bool>("RequiresArgumentShapes", Bool(false))
-    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureAligned)
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureZeroOffset)
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 2922de6dcc7ee..74200526b699c 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -286,12 +286,12 @@ class TokenAllocator1D {
  std::vector<StorageToken> full_pool_;
 };
 
-/*! \brief Check if the input op is a memory op that return the same buffer as the input buffer. */
+/*! \brief Check if the input op is a memory op that may return the same buffer. */
 bool IsInplaceMemoryOp(const Expr& op) {
  static const Op& reshape_op = Op::Get("relax.reshape");
  static const Op& view_op = Op::Get("relax.memory.view");
-  static const Op& ensure_aligned_op = Op::Get("relax.memory.ensure_aligned");
-  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_aligned_op);
+  static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset");
+  return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_zero_offset_op);
 }
 
 /*! \brief The base class for the storage allocation visitor. */
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 83b016446548a..1227c5163c314 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -545,17 +545,21 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
  return ShapeTuple(out_shape);
 });
 
-TVM_REGISTER_GLOBAL("vm.builtin.ensure_aligned").set_body_typed([](NDArray data) {
+TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) {
  if (data->byte_offset == 0) {
    return data;
  }
-  DLManagedTensor* dl_tensor = data.ToDLPack();
-  dl_tensor->dl_tensor.data =
-      reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
-  dl_tensor->dl_tensor.byte_offset = 0;
-  // For platforms that does not support pointer arithmetic, we need to copy the data to a new
-  // buffer.
-  return NDArray::FromDLPack(dl_tensor);
+  if (DeviceAPI::SupportsPointerArithmetics(data->device.device_type)) {
+    DLManagedTensor* dl_tensor = data.ToDLPack();
+    dl_tensor->dl_tensor.data =
+        reinterpret_cast<char*>(dl_tensor->dl_tensor.data) + dl_tensor->dl_tensor.byte_offset;
+    dl_tensor->dl_tensor.byte_offset = 0;
+    return NDArray::FromDLPack(dl_tensor);
+  } else {
+    auto new_array = NDArray::Empty(data.Shape(), data->dtype, data->device);
+    new_array.CopyFrom(data);
+    return new_array;
+  }
 });
 
 }  // namespace relax_vm
diff --git a/tests/python/relax/test_op_view.py b/tests/python/relax/test_op_view.py
index 1e21612f9fff4..033aee9882a42 100644
--- a/tests/python/relax/test_op_view.py
+++ b/tests/python/relax/test_op_view.py
@@ -731,7 +731,7 @@ def main(A: R.Tensor([4096], "float32")):
                shape=R.shape([16, 64]),
                relative_byte_offset=32 * 64 * 4,
            )
-            C = R.memory.ensure_aligned(B)
+            C = R.memory.ensure_zero_offset(B)
            return C
 
    built = tvm.relax.build(Module, target=target)
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 3ab468844b013..911a19d43592b 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1461,7 +1461,7 @@ def main():
        cls = Before
        x = R.builtin.alloc_tensor(R.shape([16, 16]), dtype="float32", runtime_device_index=0)
        x1 = R.memory.view(x, [128], "float32", 0)
-        x2 = R.memory.ensure_aligned(x1)
+        x2 = R.memory.ensure_zero_offset(x1)
        y = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
        cls.tir_exp(x2, y)
        z = R.builtin.alloc_tensor(R.shape([128]), dtype="float32", runtime_device_index=0)
@@ -1486,7 +1486,7 @@ def main() -> R.Tensor((128,), dtype="float32"):
        x1: R.Tensor((128,), dtype="float32") = R.memory.view(
            x, R.shape([128]), R.dtype("float32"), R.prim_value(0)
        )
-        x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_aligned(x1)
+        x2: R.Tensor((128,), dtype="float32") = R.memory.ensure_zero_offset(x1)
        storage1: R.Object = R.memory.alloc_storage(
            R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32")
        )
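---
Note for reviewers: the extension point this patch introduces is the `FLowerBuiltin` op attribute, which lets an op carry its own lowering rule instead of being special-cased inside `LowerRuntimeBuiltinMutator` (this is exactly how `relax.memory.view` and `relax.memory.ensure_zero_offset` are handled above). A minimal sketch of how another op could opt in; the names `relax.my_op`, `vm.builtin.my_op`, and `LowerBuiltinMyOp` are hypothetical, for illustration only:

    // Sketch only: registering a lowering rule that LowerRuntimeBuiltin applies.
    #include <tvm/ir/op.h>
    #include <tvm/relax/block_builder.h>
    #include <tvm/relax/expr.h>
    #include <tvm/relax/op_attr_types.h>
    #include <tvm/relax/struct_info.h>

    namespace tvm {
    namespace relax {

    // Replace calls to the op with a call to a runtime-provided packed
    // function, reusing the StructInfo already inferred for the original call.
    Expr LowerBuiltinMyOp(const BlockBuilder& bb, const Call& call) {
      ExternFunc runtime_func("vm.builtin.my_op");  // hypothetical extern
      return Call(runtime_func, call->args, Attrs(), {GetStructInfo(call)});
    }

    TVM_REGISTER_OP("relax.my_op")  // hypothetical op
        .set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinMyOp);

    }  // namespace relax
    }  // namespace tvm

Because the mutator consults `Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin")` after its hard-coded cases, a rule registered this way is picked up without touching the pass itself.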