Skip to content

Commit

Permalink
Rename to ensure_zero_offset and LowerRuntimeBuiltin
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Jul 15, 2024
1 parent 094428d commit c23d027
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 98 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relax/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace transform {
*
* \return The Pass.
*/
TVM_DLL Pass VMBuiltinLower();
TVM_DLL Pass LowerRuntimeBuiltin();

/*!
* \brief Lower the shape expression in relax to VM shape heap and TIR functions.
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/relax/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ using FNormalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, Call ca
*/
using FLegalize = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

/*! \brief The function type of a function to lower the runtime builtin.
 *
 * An op registered with this attribute is rewritten into its
 * runtime-callable form by the `LowerRuntimeBuiltin` pass.
 *
 * \param bb The BlockBuilder context.
 * \param call The call to be lowered.
 * \return The lowered expression.
 */
using FLowerBuiltin = runtime::TypedPackedFunc<Expr(const BlockBuilder& bb, const Call& call)>;

/*!
* \brief Gradient for a specific op.
*
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,10 @@ class TVM_DLL DeviceAPI {
return device_type != kDLCPU && device_type != kDLMicroDev;
}

/*!
 * \brief Whether a device type supports pointer arithmetic on device buffers.
 * \param device_type The DLDeviceType to query.
 * \return True for every backend except Vulkan, which is the only one
 *         excluded here (Vulkan buffers are not addressed via raw pointers).
 */
static bool SupportsPointerArithmetics(int device_type) {
  const bool is_vulkan = (device_type == kDLVulkan);
  return !is_vulkan;
}

protected:
/*!
* \brief copy data from one place to another
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
"""Relax memory primitives."""

from .memory import alloc_storage, alloc_tensor, kill_storage, kill_tensor
from .view import view, ensure_aligned
from .view import view, ensure_zero_offset
6 changes: 3 additions & 3 deletions python/tvm/relax/op/memory/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _normalize(expr, relax_cls):
return _ffi_api.view(data, shape, dtype, relative_byte_offset) # type: ignore


def ensure_aligned(data: Expr) -> Expr:
def ensure_zero_offset(data: Expr) -> Expr:
"""
Ensure the tensor has elem_offset == 0. A copy will be made if necessary.
Expand All @@ -106,6 +106,6 @@ def ensure_aligned(data: Expr) -> Expr:
Returns
-------
result : relax.Expr
The aligned tensor
The tensor with elem_offset == 0
"""
return _ffi_api.ensure_aligned(data) # type: ignore
return _ffi_api.ensure_zero_offset(data) # type: ignore
2 changes: 1 addition & 1 deletion python/tvm/relax/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
transform.RewriteCUDAGraph(),
transform.LowerAllocTensor(),
transform.KillAfterLastUse(),
transform.VMBuiltinLower(),
transform.LowerRuntimeBuiltin(),
transform.ComputePrimValue(),
transform.VMShapeLower(),
transform.AttachGlobalSymbol(),
Expand Down
24 changes: 11 additions & 13 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
# under the License.
"""Relax transformations. """

# Import to register the legalization functions.
from . import legalize_ops, tuning_api
from .attach_external_modules import AttachExternModules
from .fast_math import FastMathTransform
from .ipc_allreduce_rewrite import IPCAllReduceRewrite
from .lazy_transform_params import LazyTransformParams
from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
from .transform import (
AdjustMatmulOrder,
AllocateWorkspace,
Expand Down Expand Up @@ -55,6 +64,7 @@
LegalizeOps,
LiftTransformParams,
LowerAllocTensor,
LowerRuntimeBuiltin,
MergeCompositeFunctions,
MetaScheduleApplyDatabase,
MetaScheduleTuneIRMod,
Expand All @@ -64,8 +74,8 @@
PatternCheckContext,
RealizeVDevice,
RemovePurityChecking,
RemoveUnusedParameters,
RemoveUnusedOutputs,
RemoveUnusedParameters,
ReorderPermuteDimsAfterConcat,
ReorderTakeAfterMatmul,
RewriteCUDAGraph,
Expand All @@ -78,19 +88,7 @@
TopologicalSort,
UpdateParamStructInfo,
UpdateVDevice,
VMBuiltinLower,
VMShapeLower,
dataflowblock_pass,
function_pass,
)

from .ipc_allreduce_rewrite import IPCAllReduceRewrite
from .lazy_transform_params import LazyTransformParams
from .lower_gpu_ipc_alloc_storage import LowerGPUIPCAllocStorage
from .optimize_layout_transform import OptimizeLayoutTransform
from .remove_redundant_reshape import RemoveRedundantReshape
from .fast_math import FastMathTransform
from .attach_external_modules import AttachExternModules

# Import to register the legalization functions.
from . import legalize_ops, tuning_api
4 changes: 2 additions & 2 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,14 +586,14 @@ def ComputePrimValue() -> tvm.ir.transform.Pass:
return _ffi_api.ComputePrimValue() # type: ignore


def VMBuiltinLower() -> tvm.ir.transform.Pass:
def LowerRuntimeBuiltin() -> tvm.ir.transform.Pass:
"""Lowering generic intrinsic to VM intrinsics.
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.VMBuiltinLower() # type: ignore
return _ffi_api.LowerRuntimeBuiltin() # type: ignore


def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
* under the License.
*/
/*!
* \file src/relax/backend/vm/vm_builtin_lower.cc
* \file src/relax/backend/vm/lower_runtime_builtin.cc
* \brief Lowers most builtin functions and packed calls.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/backend.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/type.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/op.h>
Expand All @@ -33,11 +34,12 @@ namespace relax {

// This pass lowers most ops to VM specific builtins.
// TODO(relax-team): revisit after PrimValue.
class VMBuiltinLowerMutator : public ExprMutator {
class LowerRuntimeBuiltinMutator : public ExprMutator {
public:
using ExprMutator::VisitExpr_;

Expr VisitExpr_(const CallNode* call_node) final {
static const auto& lower_builtin_fmap = Op::GetAttrMap<FLowerBuiltin>("FLowerBuiltin");
// post-order mutation
Call call = Downcast<Call>(VisitExprPostOrder_(call_node));

Expand All @@ -47,10 +49,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
return Reshape(call);
} else if (call->op == shape_of_op_) {
return ShapeOf(call);
} else if (call->op == view_op_) {
return View(call);
} else if (call->op == ensure_aligned_op_) {
return EnsureAligned(call);
} else if (call->op == to_vdevice_op_) {
return ToDevice(call);
} else if (call->op == make_closure_op_) {
Expand All @@ -68,9 +66,13 @@ class VMBuiltinLowerMutator : public ExprMutator {
return MakeMemAllocTensor(call);
} else if (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_) {
return MakeMemKillObject(call);
} else {
return call;
} else if (const auto* op_node = call->op.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
if (lower_builtin_fmap.count(op)) {
return lower_builtin_fmap[op](builder_, call);
}
}
return call;
}

Expr MakeMemAllocStorage(const Call& call) {
Expand Down Expand Up @@ -128,19 +130,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
}
}

Expr View(const Call& view_node) {
StructInfoDeriveFunc infer_sinfo_env_func;
infer_sinfo_env_func = EnvFunc::Get("tvm.relax.struct_info.infer_view_sinfo");
auto runtime_view_sinfo = FuncStructInfo::OpaqueFunc(infer_sinfo_env_func, true);
ExternFunc runtime_view_func("runtime.TVMArrayCreateView", runtime_view_sinfo);
return Call(runtime_view_func, view_node->args, view_node->attrs, {runtime_view_sinfo});
}

Expr EnsureAligned(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
return Call(builtin_ensure_aligned_, call_node->args, Attrs(), {GetStructInfo(call_node)});
}

Expr ShapeOf(const Call& call_node) {
ICHECK(call_node->args.size() == 1);
ICHECK(call_node->struct_info_.defined());
Expand Down Expand Up @@ -205,8 +194,6 @@ class VMBuiltinLowerMutator : public ExprMutator {
const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
const Op& reshape_op_ = Op::Get("relax.reshape");
const Op& shape_of_op_ = Op::Get("relax.shape_of");
const Op& view_op_ = Op::Get("relax.memory.view");
const Op& ensure_aligned_op_ = Op::Get("relax.memory.ensure_aligned");
const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
const Op& make_closure_op_ = Op::Get("relax.make_closure");
const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
Expand All @@ -227,20 +214,20 @@ class VMBuiltinLowerMutator : public ExprMutator {
const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
const ExternFunc builtin_ensure_aligned_{"vm.builtin.ensure_aligned"};

};

Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); }
Expr LowerRuntimeBuiltin(const Expr& e) { return LowerRuntimeBuiltinMutator().VisitExpr(e); }

namespace transform {

Pass VMBuiltinLower() {
Pass LowerRuntimeBuiltin() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); };
return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {});
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(LowerRuntimeBuiltin(f)); };
return CreateFunctionPass(pass_func, 0, "LowerRuntimeBuiltin", {});
}

TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower);
TVM_REGISTER_GLOBAL("relax.transform.LowerRuntimeBuiltin").set_body_typed(LowerRuntimeBuiltin);

} // namespace transform
} // namespace relax
Expand Down
Loading

0 comments on commit c23d027

Please sign in to comment.