Skip to content

Commit

Permalink
[Relax] Implement R.ensure_zero_offset and update memory planning for…
Browse files Browse the repository at this point in the history
… R.view (#17145)

Previously, `R.view` was legalized to an extern call to
`runtime.TVMArrayCreateView` during `LegalizeOps`. Such a call to an extern
func can't be properly handled by `StaticBlockPlanMemory`, because that pass
assumes an extern func does not retain its input buffer. An extern func
that returns a view of its input would break the ref count of the
buffer. This PR defers the legalization of `R.view` so that it can be
explicitly handled by memory planning.

A new op `R.ensure_zero_offset` is added, as discussed in #16955
  • Loading branch information
vinx13 authored Aug 6, 2024
1 parent 591cf1e commit 05e2bc3
Show file tree
Hide file tree
Showing 18 changed files with 211 additions and 44 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relax/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace transform {
*
* \return The Pass.
*/
TVM_DLL Pass VMBuiltinLower();
TVM_DLL Pass LowerRuntimeBuiltin();

/*!
* \brief Lower the shape expression in relax to VM shape heap and TIR functions.
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call ca
*/
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

/*! \brief The function type of a function that lowers a runtime builtin.
*
* An operator that registers this function under the "FLowerBuiltin" attribute
* has its calls rewritten by the `LowerRuntimeBuiltin` pass, which replaces each
* call with the expression returned by the registered function.
*
* \param bb The BlockBuilder context.
* \param call The call to be lowered.
* \return The expression that replaces `call`.
*/
using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

/*!
* \brief Gradient for a specific op.
*
Expand Down
5 changes: 5 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ class TVM_DLL DeviceAPI {
return device_type != kDLCPU && device_type != kDLMicroDev;
}

/*!
* \brief Whether pointer arithmetic on a device-owned pointer may be performed on the host.
*
* The base implementation is conservative and returns false; a device API
* opts in by overriding this method.
*
* \return true if the host may offset a data pointer owned by this device, false otherwise.
*/
virtual bool SupportsDevicePointerArithmeticsOnHost() { return false; }

protected:
/*!
* \brief copy data from one place to another
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""Relax memory primitives."""

from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
from .view import view
from .view import view, ensure_zero_offset
17 changes: 17 additions & 0 deletions python/tvm/relax/op/memory/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,20 @@ def _normalize(expr, relax_cls):
relative_byte_offset = _normalize(relative_byte_offset, PrimValue)

return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore


def ensure_zero_offset(data: Expr) -> Expr:
    """
    Ensure the tensor has elem_offset == 0. A copy will be made if necessary.

    Parameters
    ----------
    data : relax.Expr
        The input tensor

    Returns
    -------
    result : relax.Expr
        The tensor with elem_offset == 0
    """
    return _ffi_api.ensure_zero_offset(data) # type: ignore
2 changes: 1 addition & 1 deletion python/tvm/relax/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
transform.RewriteCUDAGraph(),
transform.LowerAllocTensor(),
transform.KillAfterLastUse(),
transform.VMBuiltinLower(),
transform.LowerRuntimeBuiltin(),
transform.ComputePrimValue(),
transform.VMShapeLower(),
transform.AttachGlobalSymbol(),
Expand Down
9 changes: 5 additions & 4 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
LegalizeOps,
LiftTransformParams,
LowerAllocTensor,
LowerRuntimeBuiltin,
MergeCompositeFunctions,
MetaScheduleApplyDatabase,
MetaScheduleTuneIRMod,
Expand All @@ -64,8 +65,8 @@
PatternCheckContext,
RealizeVDevice,
RemovePurityChecking,
RemoveUnusedParameters,
RemoveUnusedOutputs,
RemoveUnusedParameters,
ReorderPermuteDimsAfterConcat,
ReorderTakeAfterMatmul,
RewriteCUDAGraph,
Expand All @@ -84,14 +85,14 @@
function_pass,
)

from .attach_external_modules import AttachExternModules
from .fast_math import FastMathTransform
from .fuse_transpose_matmul import FuseTransposeMatmul
from .ipc_allreduce_rewrite import IPCAllReduceRewrite
from .lazy_transform_params import LazyTransformParams
from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
from .fast_math import FastMathTransform
from .fuse_transpose_matmul import FuseTransposeMatmul
from .attach_external_modules import AttachExternModules

# Import to register the legalization functions.
from . import legalize_ops, tuning_api
17 changes: 16 additions & 1 deletion python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import inspect
import types
import warnings
from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np # type: ignore
Expand Down Expand Up @@ -586,14 +587,28 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
return _ffi_api.ComputePrimValue() # type: ignore


def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
    """Lowering generic intrinsic to VM intrinsics.

    Returns
    -------
    ret: tvm.ir.transform.Pass
        The pass object created by the `LowerRuntimeBuiltin` FFI function.
    """
    return _ffi_api.LowerRuntimeBuiltin() # type: ignore


def VMBuiltinLower() -> tvm.ir.transform.Pass:
    """Lowering generic intrinsic to VM intrinsics.

    Backwards-compatibility wrapper: this pass has been renamed to
    `LowerRuntimeBuiltin`. Calling this function emits a warning and then
    returns the renamed pass.

    Returns
    -------
    ret: tvm.ir.transform.Pass
        The pass object created by the `LowerRuntimeBuiltin` FFI function.
    """
    # The warning must run before the return; a leading
    # `return _ffi_api.VMBuiltinLower()` would make it unreachable.
    warnings.warn(
        "tvm.relax.transform.VMBuiltinLower has been renamed to 'LowerRuntimeBuiltin'. "
        "This wrapper is for backwards compatibility, and will be removed in a later update."
    )
    return _ffi_api.LowerRuntimeBuiltin() # type: ignore


def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* under the License.
*/
/*!
* \file src/relax/backend/vm/vm_builtin_lower.cc
* \file src/relax/backend/vm/lower_runtime_builtin.cc
* \brief Lowers most builtin functions and packed calls.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/op.h>
Expand All @@ -33,11 +34,12 @@ namespace relax {

// This pass lowers most ops to VM specific builtins.
// TODO(relax-team): revisit after PrimValue.
class VMBuiltinLowerMutator : public ExprMutator {
class LowerRuntimeBuiltinMutator : public ExprMutator {
public:
using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call_node) final {
static const auto& lower_builtin_fmap = Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
// post-order mutation
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));

Expand All @@ -64,9 +66,13 @@ class VMBuiltinLowerMutator : public ExprMutator {
return MakeMemAllocTensor(call);
} else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
return MakeMemKillObject(call);
} else {
return call;
} else if (const auto* op_node = call->op.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
if (lower_builtin_fmap.count(op)) {
return lower_builtin_fmap[op](builder_, call);
}
}
return call;
}

Expr MakeMemAllocStorage(const Call& call) {
Expand Down Expand Up @@ -210,17 +216,19 @@ class VMBuiltinLowerMutator : public ExprMutator {
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
};

Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }

namespace transform {

Pass VMBuiltinLower() {
Pass LowerRuntimeBuiltin() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); };
return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(LowerRuntimeBuiltin(f));
};
return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
}

TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);

} // namespace transform
} // namespace relax
Expand Down
35 changes: 32 additions & 3 deletions src/relax/op/memory/view.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ StructInfo InferStructInfoView(const Call& call, const BlockBuilder& ctx) {

TVM_REGISTER_GLOBAL("tvm.relax.struct_info.infer_view_sinfo").set_body_typed(InferStructInfoView);

Expr LegalizeView(const BlockBuilder& bb, const Call& call) {
Expr LowerBuiltinView(const BlockBuilder& bb, const Call& call) {
Expr data = call->args[0];
Expr shape = call->args[1];
Expr dtype = call->args[2];
Expand Down Expand Up @@ -352,8 +352,37 @@ TVM_REGISTER_OP("relax.memory.view")
"The view's byte offset, relative to the input tensor's byte offset.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoView)
.set_attr<FLegalize>("FLegalize", LegalizeView)
.set_attr<Bool>("FPurity", Bool(true));
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinView);

/*!
 * \brief Construct a call to the `relax.memory.ensure_zero_offset` operator.
 * \param x The expression passed as the operator's single argument.
 * \return The call expression.
 */
Expr ensure_zero_offset(const Expr& x) {
  // The operator handle is looked up once and reused across calls.
  static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset");
  return Call(ensure_zero_offset_op, {x});
}

// Expose the constructor through the global function registry.
TVM_REGISTER_GLOBAL("relax.op.memory.ensure_zero_offset").set_body_typed(ensure_zero_offset);

/*!
 * \brief Infer the struct info of a `relax.memory.ensure_zero_offset` call.
 *
 * The result carries the same struct info as the single input argument.
 *
 * \param call The call whose struct info is inferred.
 * \param ctx The BlockBuilder used to report a fatal diagnostic on a wrong argument count.
 * \return The struct info of the call's first argument.
 */
StructInfo InferStructInfoEnsureZeroOffset(const Call& call, const BlockBuilder& ctx) {
  const bool has_single_argument = (call->args.size() == 1);
  if (!has_single_argument) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Operator " << call->op
                                             << " should receive 1 argument, "
                                             << "but received " << call->args);
  }
  return GetStructInfo(call->args[0]);
}

/*!
 * \brief Lower `relax.memory.ensure_zero_offset` to the `vm.builtin.ensure_zero_offset` extern func.
 * \param bb The BlockBuilder context (unused here).
 * \param call The operator call to lower.
 * \return A call to the extern func with the same arguments, with the struct
 *         info of the original call passed as the struct-info argument.
 */
Expr LowerBuiltinEnsureZeroOffset(const BlockBuilder& bb, const Call& call) {
  const ExternFunc vm_builtin{"vm.builtin.ensure_zero_offset"};
  const StructInfo result_sinfo = GetStructInfo(call);
  return Call(vm_builtin, call->args, Attrs(), {result_sinfo});
}

// Register the `relax.memory.ensure_zero_offset` operator: one tensor input,
// struct info forwarded from the input, marked pure, and lowered through the
// "FLowerBuiltin" attribute rather than an "FLegalize" one.
TVM_REGISTER_OP("relax.memory.ensure_zero_offset")
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<Bool>("RequiresArgumentShapes", Bool(false))
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoEnsureZeroOffset)
.set_attr<Bool>("FPurity", Bool(true))
.set_attr<FLowerBuiltin>("FLowerBuiltin", LowerBuiltinEnsureZeroOffset);

} // namespace relax
} // namespace tvm
3 changes: 3 additions & 0 deletions src/relax/op/memory/view.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ namespace relax {
/*! \brief View a tensor with different properties. */
Expr view(Expr x, Optional<Expr> shape, Optional<Expr> dtype, Optional<Expr> relative_byte_offset);

/*!
 * \brief Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
 * \param x The input tensor.
 * \return A call to the `relax.memory.ensure_zero_offset` operator.
 */
Expr ensure_zero_offset(const Expr& x);

} // namespace relax
} // namespace tvm

Expand Down
13 changes: 9 additions & 4 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,13 @@ class TokenAllocator1D {
std::vector<StorageToken> full_pool_;
};

/*! \brief Check if the input op is "relax.reshape". */
bool IsReshape(const Expr& op) { return op.same_as(Op::Get("relax.reshape")); }
/*! \brief Check if the input op is a memory op that may return the same buffer. */
bool IsInplaceMemoryOp(const Expr& op) {
static const Op& reshape_op = Op::Get("relax.reshape");
static const Op& view_op = Op::Get("relax.memory.view");
static const Op& ensure_zero_offset_op = Op::Get("relax.memory.ensure_zero_offset");
return op.same_as(reshape_op) || op.same_as(view_op) || op.same_as(ensure_zero_offset_op);
}

/*! \brief The base class for the storage allocation visitor. */
class StorageAllocatorBaseVisitor : public ExprVisitor {
Expand Down Expand Up @@ -498,7 +503,7 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
// Create a storage token for builtin alloc_tensor.
this->CreateToken(call);
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
// Reuse the input's token for in-place memory ops (reshape, view, ensure_zero_offset).
SetTokens(call, GetTokens(call->args[0]));
return;
Expand Down Expand Up @@ -751,7 +756,7 @@ class StorageAllocator : public StorageAllocatorBaseVisitor {
block_tokens.push_back(new_token.get());
}
return;
} else if (IsReshape(call->op)) {
} else if (IsInplaceMemoryOp(call->op)) {
Tokens tokens = GetTokens(call->args[0]);
ICHECK(!tokens.IsNested());
if (tokens.IsLeaf()) {
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/cpu_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;

// Opt in to host-side pointer arithmetic on pointers owned by this device API
// (the DeviceAPI base class returns false).
bool SupportsDevicePointerArithmeticsOnHost() final { return true; }

static CPUDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(dev, data);
}

// Opt in to host-side pointer arithmetic on pointers owned by this device API
// (the DeviceAPI base class returns false).
bool SupportsDevicePointerArithmeticsOnHost() final { return true; }

static CUDADeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
Expand Down
19 changes: 19 additions & 0 deletions src/runtime/relax_vm/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,25 @@ TVM_REGISTER_GLOBAL("vm.builtin.tensor_to_shape").set_body_typed([](NDArray data
return ShapeTuple(out_shape);
});

// Return an NDArray whose byte_offset is 0.
//  - If the input already has byte_offset == 0 it is returned unchanged.
//  - If the device API allows host-side pointer arithmetic and the offset is a
//    multiple of kAllocAlignment, the array is re-imported through DLPack with
//    the offset folded into the data pointer.
//  - Otherwise a new array is allocated and the data is copied into it.
TVM_REGISTER_GLOBAL("vm.builtin.ensure_zero_offset").set_body_typed([](NDArray data) {
  if (data->byte_offset == 0) {
    return data;
  }

  auto* api = DeviceAPI::Get(data->device);
  const bool can_rebase_pointer = api->SupportsDevicePointerArithmeticsOnHost() &&
                                  data->byte_offset % tvm::runtime::kAllocAlignment == 0;
  if (!can_rebase_pointer) {
    // Fall back to materializing the data into a newly allocated array.
    auto copied = NDArray::Empty(data.Shape(), data->dtype, data->device);
    copied.CopyFrom(data);
    return copied;
  }

  // Fold the byte offset into the data pointer of the exported DLPack tensor,
  // then import it back as an NDArray.
  // NOTE(review): presumably the result shares the input's buffer; confirm
  // against NDArray::ToDLPack / NDArray::FromDLPack.
  DLManagedTensor* managed = data.ToDLPack();
  DLTensor& tensor = managed->dl_tensor;
  tensor.data = reinterpret_cast<char*>(tensor.data) + tensor.byte_offset;
  tensor.byte_offset = 0;
  return NDArray::FromDLPack(managed);
});

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
Expand Down
Loading

0 comments on commit 05e2bc3

Please sign in to comment.