diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 027fd6f824db..f3544d8613c8 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -578,9 +578,12 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2); * * Any binding blocks that are left empty will be removed by the normalizer. * + * \param entry_functions Names of functions that should be considered + * as entry points, in addition to any externally exposed functions. + * * \return The Pass. */ -TVM_DLL Pass DeadCodeElimination(Array entry_functions); +TVM_DLL Pass DeadCodeElimination(Array entry_functions = {}); /*! * \brief Pass that changes calls to operators that can be done in-place diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index 4ad291e91cce..11785ab73ac6 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -961,57 +961,73 @@ std::vector GetTupleAccessedIndices(const FunctionNode* func, const Var& */ class TIRFuseMutator : public ExprMutator { public: - static IRModule Transform(const IRModule& mod) { - Map funcs_to_keep; - for (const auto& [gv, func] : mod->functions) { - // 1. If a TIR function has global symbol, we keep the function. - // 2. Always keep ExternFunc. 
- if (const auto* prim_func = func.as()) { - if (prim_func->GetAttr("global_symbol").defined()) { - funcs_to_keep.Set(gv, func); + static IRModule Transform(IRModule mod) { + // Collect all primitive relax functions + Map primitive_relax; + for (const auto& [gvar, base_func] : mod->functions) { + // Only fuse primitive relax functions + if (base_func->HasNonzeroAttr(attr::kPrimitive)) { + if (auto func = base_func.as()) { + primitive_relax.Set(gvar, func.value()); } - } else if (func->IsInstance()) { - funcs_to_keep.Set(gv, func); } } + + if (primitive_relax.empty()) { + return mod; + } + + mod.CopyOnWrite(); + + IRModule updates; + std::unordered_map replacements; + // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder. - TIRFuseMutator mutator(mod); + // Step 1. Fuse all primitive relax functions, store the result in `fused_tir_funcs_` - for (const auto& [gv, func] : mod->functions) { - // Only fuse primitive relax functions - if (func->IsInstance() && func->HasNonzeroAttr(attr::kPrimitive)) { - const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, gv); - mutator.fused_tir_funcs_.Set(gv, prim_func); - if (!indices.empty()) { - mutator.inplace_indices_.Set(gv, indices); - } - } + for (const auto& [old_gvar, func] : primitive_relax) { + const auto& [prim_func, indices] = FusedTIRConstructor::GetFusedTIR(mod, old_gvar); + + GlobalVar new_gvar(old_gvar->name_hint); + UpdateStructInfo(new_gvar, + FuncStructInfo::OpaqueFunc(StructInfoFromType(prim_func->ret_type))); + + mod->Remove(old_gvar); + updates->Add(new_gvar, prim_func); + replacements[old_gvar] = Replacement{new_gvar, func, indices}; } + TIRFuseMutator mutator(replacements); + // Step 2. 
Update all non-primitive relax functions and add it, with the dependent function, // into the new IRModule + for (const auto& [gv, func] : mod->functions) { - if (func->IsInstance() && !func->HasNonzeroAttr(attr::kPrimitive)) { + if (func->IsInstance()) { + ICHECK(!func->HasNonzeroAttr(attr::kPrimitive)) + << "Module should not contain any primitive relax functions at this point"; relax::Function update_func = Downcast(mutator.VisitExpr(func)); - mutator.builder_->AddFunction(update_func, gv->name_hint); - } - } - - // Step 3. Add all functions that need to be kept. - auto modified_mod = mutator.builder_->GetContextIRModule(); - for (const auto& [gv, func] : funcs_to_keep) { - if (!modified_mod->ContainGlobalVar(gv->name_hint)) { - modified_mod->Add(gv, func); + if (!update_func.same_as(func)) { + updates->Add(gv, update_func); + } } } - // Step 4. Copy over module attributes and return. - if (mod->attrs.defined()) modified_mod = WithAttrs(modified_mod, mod->attrs->dict); - return modified_mod; + // Step 3. Copy over updated functions and return. + mod->Update(updates); + return mod; } private: - explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {} + struct Replacement { + GlobalVar fused_tir_gvar; + Function original_function; + Array inplace_indices; + }; + + explicit TIRFuseMutator( + std::unordered_map replacements) + : replacements_(replacements) {} using ExprMutator::VisitExpr_; @@ -1035,92 +1051,86 @@ class TIRFuseMutator : public ExprMutator { Call call = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(op))); - if (call->op->IsInstance()) { - // Case 1. It is a relax cross function call - GlobalVar old_gv = Downcast(call->op); - auto relax_func = Downcast(mod_->Lookup(old_gv)); - auto it = fused_tir_funcs_.find(old_gv); - if (it != fused_tir_funcs_.end()) { - const tir::PrimFunc& fused_tir = (*it).second; - // Case 1.1. 
It calls a primitive relax function, update the call into a call_tir - GlobalVar fused_tir_gv = this->builder_->AddFunction(fused_tir, old_gv->name_hint); - // Step a. Flatten all args since call_tir does not support Tuple value. - Array arg_list; - Array tir_vars; - for (size_t i = 0; i < call->args.size(); ++i) { - auto arg = call->args[i]; - auto sinfo = GetStructInfo(arg); - - ICHECK(!relax_func->params[i]->struct_info_->IsInstance() && - !sinfo.as()) - << "InternalError: " - << "All tuple parameters should be expanded before this point in FuseTIR. " - << "However, argument " << arg << " with struct info " << arg->struct_info_ - << " is passed as argument " << i << " to Primitive Relax function " << old_gv - << ", which expects parameter " << relax_func->params[i] << " to have struct info " - << relax_func->params[i]->struct_info_; - - if (const auto* shape = sinfo.as()) { - CHECK(shape->values.defined()) - << "FuseTIR requires all shape input has struct_info value."; - for (const PrimExpr& prim_value : shape->values.value()) { - CHECK(prim_value->IsInstance()) - << "All shape inputs are expected to be single tir var."; - tir_vars.push_back(prim_value); - } - } else if (const auto* prim_value = sinfo.as()) { - CHECK(prim_value->value.defined()) - << "FuseTIR requires all R.Prim arguments to have a known value."; - PrimExpr expr = prim_value->value.value(); - CHECK(expr->IsInstance()) << "FuseTIR currently requires all R.Prim " - "arguments to provide a single tir::Var."; - tir_vars.push_back(expr); - - } else { - arg_list.push_back(arg); - } - } - // Step b. 
Create call_tir or call_tir_inplace - Array call_args = {fused_tir_gv, Tuple(arg_list)}; - if (!tir_vars.empty()) { - call_args.push_back(ShapeExpr(tir_vars)); - } - Op call_op = call_tir_op_; - Attrs call_attrs = call->attrs; - if (auto it = inplace_indices_.find(old_gv); it != inplace_indices_.end()) { - call_op = call_tir_inplace_op_; - auto inplace_attrs = make_object(); - inplace_attrs->inplace_indices = (*it).second; - call_attrs = Attrs(inplace_attrs); + auto opt_gvar = call->op.as(); + if (!opt_gvar) { + // Case 1. The Call isn't a relax-to-relax function call, no need to update. + return call; + } + GlobalVar old_gvar = opt_gvar.value(); + + auto it = replacements_.find(old_gvar); + if (it == replacements_.end()) { + // Case 2. The callee function is not a primitive relax + // function, no need to update. + return call; + } + const Replacement& replacement = it->second; + const GlobalVar& fused_tir_gv = replacement.fused_tir_gvar; + const Function& relax_func = replacement.original_function; + + // Case 3. It calls a primitive relax function, update the call + // into a call_tir or call_tir_inplace. + + // Step a. Collect all relax/symbolic arguments. Tuple arguments + // are not supported by PrimFunc, so this step verifies that + // ExpandTupleArguments has already removed them. + Array arg_list; + Array tir_vars; + for (size_t i = 0; i < call->args.size(); ++i) { + auto arg = call->args[i]; + auto sinfo = GetStructInfo(arg); + + ICHECK(!relax_func->params[i]->struct_info_->IsInstance() && + !sinfo.as()) + << "InternalError: " + << "All tuple parameters should be expanded before this point in FuseTIR. 
" + << "However, argument " << arg << " with struct info " << arg->struct_info_ + << " is passed as argument " << i << " to Primitive Relax function " << old_gvar + << ", which expects parameter " << relax_func->params[i] << " to have struct info " + << relax_func->params[i]->struct_info_; + + if (const auto* shape = sinfo.as()) { + CHECK(shape->values.defined()) << "FuseTIR requires all shape input has struct_info value."; + for (const PrimExpr& prim_value : shape->values.value()) { + CHECK(prim_value->IsInstance()) + << "All shape inputs are expected to be single tir var."; + tir_vars.push_back(prim_value); } - return Call(call_op, call_args, call_attrs, {GetStructInfo(call)}); + } else if (const auto* prim_value = sinfo.as()) { + CHECK(prim_value->value.defined()) + << "FuseTIR requires all R.Prim arguments to have a known value."; + PrimExpr expr = prim_value->value.value(); + CHECK(expr->IsInstance()) << "FuseTIR currently requires all R.Prim " + "arguments to provide a single tir::Var."; + tir_vars.push_back(expr); + } else { - // Case 1.2. The callee function is not primitive, nothing to do. - return call; - } - } else if (call->op == call_tir_op_ || call->op == call_tir_inplace_op_) { - // Case 2. It is a call_tir or call_tir_inplace, re-emit the PrimFunc. - if (const auto* gv = call->args[0].as()) { - tir::PrimFunc func = Downcast(mod_->Lookup(GetRef(gv))); - GlobalVar new_gv = this->builder_->AddFunction(func, gv->name_hint); - Array new_args = call->args; - new_args.Set(0, new_gv); - return Call(call->op, new_args, call->attrs, call->sinfo_args, call->span); + arg_list.push_back(arg); } } - // Case 3. CallNode in other types. Leave it as it is. - return call; + // Step b. 
Create call_tir or call_tir_inplace + Array call_args = {fused_tir_gv, Tuple(arg_list)}; + if (!tir_vars.empty()) { + call_args.push_back(ShapeExpr(tir_vars)); + } + Op call_op = call_tir_op_; + Attrs call_attrs = call->attrs; + if (replacement.inplace_indices.size()) { + call_op = call_tir_inplace_op_; + auto inplace_attrs = make_object(); + inplace_attrs->inplace_indices = replacement.inplace_indices; + call_attrs = Attrs(inplace_attrs); + } + return Call(call_op, call_args, call_attrs, {GetStructInfo(call)}); } private: - /*! \brief The IRModule */ - const IRModule& mod_; - /*! \brief The map from global var of primitive relax function to generated prim func. */ - Map fused_tir_funcs_; - /*! \brief The map from global var of primitive relax function to in-place indices - * (if there are any). */ - Map> inplace_indices_; + /*! \brief The map from global var to how it should be replaced + * + * Has one entry for each primitive relax function in the IRModule. + */ + std::unordered_map replacements_; }; IRModule FuseTIR(IRModule mod) { @@ -1142,6 +1152,7 @@ Pass FuseTIR() { ExpandTupleArguments(), RemoveUnusedParameters(), inner_pass, + DeadCodeElimination(), }, "FuseTIR"); } diff --git a/tests/python/relax/test_transform_fuse_tir.py b/tests/python/relax/test_transform_fuse_tir.py index c0a6f4448b5c..90baeaad04bb 100644 --- a/tests/python/relax/test_transform_fuse_tir.py +++ b/tests/python/relax/test_transform_fuse_tir.py @@ -2254,5 +2254,65 @@ def main( _check(Module, Expected) +def test_private_nonprimitive_func(): + """Input IRModule may contain calls to non-primitive functions + + This is a regression test. Prior implementations did not preserve + relax-to-relax function calls. 
+ """ + + @I.ir_module + class Before: + @R.function + def main( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + cls = Before + with R.dataflow(): + gv = cls.fused_func(input_ids, input_embeds) + R.output(gv) + return gv + + @R.function(private=True) + def fused_func( + input_ids: R.Tensor((1,), dtype="int32"), + input_embeds: R.Tensor((4096, 4096), dtype="float16"), + ) -> R.Tensor((1, 4096), dtype="float16"): + cls = Before + with R.dataflow(): + lv = R.call_tir( + cls.add, (input_embeds,), out_sinfo=R.Tensor((4096, 4096), dtype="float16") + ) + gv = R.call_tir( + cls.take, (lv, input_ids), out_sinfo=R.Tensor((1, 4096), dtype="float16") + ) + R.output(gv) + return gv + + @T.prim_func(private=True) + def add( + A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + Out: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + ): + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("add"): + vi, vj = T.axis.remap("SS", [i, j]) + Out[vi, vj] = A[vi, vj] + T.float16(1.0) + + @T.prim_func(private=True) + def take( + A: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + B: T.Buffer((T.int64(1),), "int32"), + T_take: T.Buffer((T.int64(1), T.int64(4096)), "float16"), + ): + for ax0, ax1 in T.grid(T.int64(1), T.int64(4096)): + with T.block("T_take"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + T_take[v_ax0, v_ax1] = A[B[v_ax0], v_ax1] + + _check(Before, Before) + + if __name__ == "__main__": tvm.testing.main()