Skip to content

Commit

Permalink
[Relax][Transform] Handle identical PrimFunc with distinct VDevice (#…
Browse files Browse the repository at this point in the history
…16959)

* [Relax][Transform] Handle identical PrimFunc with distinct VDevice

Prior to this commit, if an `IRModule` contained two expressions,
where the types of the arguments differed only by the `VDevice`, these
would be legalized to produce a single PrimFunc.  This PrimFunc would
have a `tvm::attr::kTarget` annotation specific to one of those
expressions, and would be incorrect for use in the other location.

This commit updates the `LegalizeOps` transform to handle this case,
producing multiple TIR PrimFuncs if required by the `VDevice`
annotations.

* Fix breakage in tests, caused by unused PrimFunc without target attr
  • Loading branch information
Lunderberg authored May 13, 2024
1 parent 5b5f8d0 commit c2d14ae
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 23 deletions.
112 changes: 89 additions & 23 deletions src/relax/transform/legalize_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace relax {
Expand Down Expand Up @@ -74,16 +75,22 @@ class LegalizeMutator : public ExprMutator {
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
// Fill the "kTarget" attribute of PrimFunc
const auto& mod = builder_->GetContextIRModule();
for (const auto& gv : mod->GetGlobalVars()) {
const tir::PrimFuncNode* prim_func;
if (tmap_.count(gv) && (prim_func = mod->Lookup(gv).as<tir::PrimFuncNode>())) {
auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
builder_->UpdateFunction(gv, f);
}

IRModule output = builder_->GetContextIRModule();
if (generated_tir_with_target_attr_) {
// It is possible that every call to a legalized PrimFunc
// contains VDevice annotations. In that case, the PrimFunc
// without a target annotation no longer has any callers, and
// should be removed.
output = relax::transform::DeadCodeElimination()(output);

// Avoid accidental sharing of TIR variables in the legalized
// PrimFuncs, when kernels for multiple devices are generated
// from the same PrimFunc.
output = tir::transform::ConvertSSA()(output);
}
return builder_->GetContextIRModule();

return output;
}

private:
Expand Down Expand Up @@ -129,7 +136,7 @@ class LegalizeMutator : public ExprMutator {
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
}

Target GetTarget(const Array<StructInfo>& sinfos) {
Optional<Target> GetTarget(const Array<StructInfo>& sinfos) {
for (auto sinfo : sinfos) {
if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
if (tinfo->vdevice.defined()) {
Expand All @@ -142,18 +149,76 @@ class LegalizeMutator : public ExprMutator {
return GetTarget(tup_sinfo->fields);
}
}
return Target();
return NullOpt;
}

void SaveTarget(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
auto call = Downcast<Call>(expr);
auto target = GetTarget(call->sinfo_args);
const GlobalVarNode* gvar_node;
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
}
/*!
 * \brief Point a legalized call at a PrimFunc whose target matches the call's VDevice.
 *
 * If `expr` is a call whose `sinfo_args` carry a VDevice target, and whose
 * first argument is a GlobalVar naming a PrimFunc in the context IRModule
 * that lacks a `tvm::attr::kTarget` of the same target kind, a copy of that
 * PrimFunc with the required target attribute is added to the module and
 * the call is redirected to it.
 *
 * \param expr The expression returned by the FLegalize function.
 * \return The updated call, or `expr` unchanged when no update applies.
 */
Expr BindTarget(Expr expr) {
  // Only a relax::Call produced by FLegalize is post-processed.  For
  // any other result, FLegalize is trusted to have produced
  // vdevice-aware output on its own.
  if (!expr->IsInstance<CallNode>()) {
    return expr;
  }
  Call lowered_call = Downcast<Call>(expr);

  // Without a VDevice annotation, or without a callee argument to
  // inspect, there is nothing to rewrite.
  auto required_target = GetTarget(lowered_call->sinfo_args);
  if (!required_target.defined() || lowered_call->args.empty()) {
    return expr;
  }

  // The callee must be a GlobalVar, i.e. a function that lives in the
  // IRModule currently being built.
  auto callee_gvar = lowered_call->args[0].as<GlobalVar>();
  if (!callee_gvar.defined()) {
    return expr;
  }

  // Callees other than PrimFuncs (e.g. other Relax functions) are
  // left alone; legalizing their own bodies handles any annotations.
  auto callee_prim_func =
      builder_->GetContextIRModule()->Lookup(callee_gvar.value()).as<tir::PrimFunc>();
  if (!callee_prim_func) {
    return expr;
  }
  auto prim_func = callee_prim_func.value();

  // A PrimFunc already annotated with the same target kind can be
  // called as-is.
  auto existing_target = prim_func->GetAttr<Target>(tvm::attr::kTarget);
  if (existing_target && existing_target.value()->kind == required_target.value()->kind) {
    return expr;
  }

  // Otherwise, register a copy of the PrimFunc carrying the required
  // target, named "<original name>_<target kind>", and call that
  // instead.  In the future, reaching this point may be treated as a
  // bug in the FLegalize implementation rather than expected output.
  std::stringstream name_hint;
  name_hint << callee_gvar.value()->name_hint << "_" << required_target.value()->kind->name;

  auto retargeted_func = WithAttr(prim_func, tvm::attr::kTarget, required_target.value());
  auto retargeted_gvar = builder_->AddFunction(retargeted_func, name_hint.str());
  generated_tir_with_target_attr_ = true;

  lowered_call.CopyOnWrite()->args.Set(0, retargeted_gvar);
  return lowered_call;
}

Expr VisitExpr_(const CallNode* call) final {
Expand Down Expand Up @@ -268,8 +333,9 @@ class LegalizeMutator : public ExprMutator {
}
Expr legalized = legalization_func(builder_, visited_call);

// Save the expected target info. into tmap_
SaveTarget(legalized);
// Append the target attribute to any PrimFunc generated in
// legalization.
legalized = BindTarget(legalized);

legalized = builder_->Normalize(legalized);

Expand Down Expand Up @@ -303,8 +369,8 @@ class LegalizeMutator : public ExprMutator {
IRModule mod_;
/*! \brief The customized legalization function map. */
Map<String, PackedFunc> cmap_;
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
Map<GlobalVar, Target> tmap_;
/*! \brief Whether VDevice annotations produced at least one PrimFunc with a Target attr. */
bool generated_tir_with_target_attr_{false};
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
Expand Down
36 changes: 36 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,42 @@ class IRConvertSSA final : public StmtExprMutator {
return std::move(decl);
}

// Visit a Block.  Any IterVar whose variable is already recorded in
// `defined_` gets a replacement variable for the duration of this
// block, the read/write regions are passed through VisitBufferAccess,
// and the (possibly updated) block is then handed to the base mutator.
Stmt VisitStmt_(const BlockNode* op) final {
Block block = GetRef<Block>(op);

// The BlockNode is the point of definition for the IterVar
// instances. These re-defines must be present before visiting
// the body of the BlockNode.
std::vector<ScopedRedefine> redefines;
Array<IterVar> iter_vars = op->iter_vars.Map([&](IterVar iter_var) {
if (defined_.count(iter_var->var.get())) {
// Variable was seen before: swap in the new variable created by
// the ScopedRedefine.
redefines.emplace_back(this, iter_var->var);
iter_var.CopyOnWrite()->var = redefines.back().new_var;
} else {
// First occurrence: record it and keep the IterVar unchanged.
defined_.insert(iter_var->var.get());
}
return iter_var;
});
// The regions are visited after the redefines above have been created.
Array<BufferRegion> reads =
block->reads.Map([&](const auto& region) { return VisitBufferAccess(region); });
Array<BufferRegion> writes =
block->writes.Map([&](const auto& region) { return VisitBufferAccess(region); });

// Only copy the block when at least one of the three fields changed.
if (!reads.same_as(block->reads) || !writes.same_as(block->writes) ||
!iter_vars.same_as(op->iter_vars)) {
auto write_ptr = block.CopyOnWrite();
write_ptr->reads = reads;
write_ptr->writes = writes;
write_ptr->iter_vars = iter_vars;
}

// Delegate the remaining fields of the block to the base mutator.
Stmt output = Downcast<Block>(StmtExprMutator::VisitStmt_(block.get()));

// Release the redefines one at a time from the back, i.e. in the
// reverse of the order in which they were created.
// NOTE(review): presumably ScopedRedefine restores the previous
// variable mapping in its destructor -- confirm against its definition.
while (redefines.size()) redefines.pop_back();

return output;
}

template <typename Node>
Node VisitBufferAccess(Node node) {
Buffer new_buf = GetRemappedBuffer(node->buffer);
Expand Down
81 changes: 81 additions & 0 deletions tests/python/relax/test_transform_legalize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,86 @@ def main(
tvm.ir.assert_structural_equal(AfterFirstIter, AfterSecondIter)


def test_legalize_with_vdevice():
"""Legalization may generate kernels for multiple targets.

This is a regression test. In previous implementations, Relax
expressions whose argument types differed only by their `vdevice`
would be legalized to use the same `PrimFunc`.
"""

# Input module: two functions that both compute R.add(A, B).  They
# differ only in that func_llvm's arguments carry the "llvm" vdevice
# annotation, while func_cuda's arguments carry none.
@I.ir_module
class Before:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

@R.function
def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
C = R.add(A, B)
return C

@R.function
def func_llvm(
A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
):
C = R.add(A, B)
return C

# Expected module: func_cuda calls `add`, whose func_attr has no
# "target" entry, while func_llvm calls a separate `add_llvm` whose
# func_attr sets "target" to llvm.
@I.ir_module
class Expected:
I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

@R.function
def func_cuda(
A: R.Tensor((32, 32), dtype="float32"),
B: R.Tensor((32, 32), dtype="float32"),
):
cls = Expected
C = R.call_tir(cls.add, (A, B), out_sinfo=R.Tensor((32, 32), dtype="float32"))
return C

@T.prim_func(private=True)
def add(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
):
T.func_attr({"tir.noalias": T.bool(True)})
for iters in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", iters)
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

@R.function
def func_llvm(
A: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
B: R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
):
cls = Expected
C = R.call_tir(
cls.add_llvm,
(A, B),
out_sinfo=R.Tensor((32, 32), dtype="float32", vdevice="llvm"),
)
return C

@T.prim_func(private=True)
def add_llvm(
A: T.Buffer((T.int64(32), T.int64(32)), "float32"),
B: T.Buffer((T.int64(32), T.int64(32)), "float32"),
C: T.Buffer((T.int64(32), T.int64(32)), "float32"),
):
T.func_attr({"target": T.target("llvm"), "tir.noalias": T.bool(True)})
for iters in T.grid(T.int64(32), T.int64(32)):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", iters)
C[ax0, ax1] = A[ax0, ax1] + B[ax0, ax1]

# Run LegalizeOps inside a "cuda" target context.
with tvm.target.Target("cuda"):
After = tvm.relax.transform.LegalizeOps()(Before)

# The legalized module must match Expected structurally.
tvm.ir.assert_structural_equal(Expected, After)


# When executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c2d14ae

Please sign in to comment.