Skip to content

Commit

Permalink
[QoL][IR] Provide default constructor for NameSupply/GlobalVarSupply (#…
Browse files Browse the repository at this point in the history
…17135)

Prior to this commit, a `tvm::NameSupply` needed to be constructed
with an explicit `const String& prefix` argument.  Omitting this
argument would fall back to the default constructor provided by the
`TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS` macro, producing a
`NameSupply` holding a nullptr.  This then leads to a segfault when
the null `NameSupply` is used.

The vast majority of usages of `NameSupply::NameSupply` (29 out of 31)
initialize it with an empty `prefix` string.  The remaining two use
cases initialize it with a non-empty `prefix` string.  There are no
cases in which a null `NameSupply` is initialized.

This commit updates `NameSupply` to use the
`TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS` macro instead of
`TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS`.  This allows the default
constructor to provide the common usage of a `NameSupply` with an
empty prefix, rather than the error-prone usage of a null `NameSupply`

A similar change is also made for `GlobalVarSupply`, as the majority
of its uses also default to an empty prefix (11 out of 13).
  • Loading branch information
Lunderberg authored Jul 12, 2024
1 parent 0346266 commit f60b08c
Show file tree
Hide file tree
Showing 23 changed files with 34 additions and 40 deletions.
7 changes: 4 additions & 3 deletions include/tvm/ir/global_var_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GlobalVarSupplyNode : public Object {
/*!
* \brief Empty constructor. Will use an empty NameSupply.
*/
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply("")) {}
GlobalVarSupplyNode() : GlobalVarSupplyNode(NameSupply()) {}

/*!
* \brief Constructor.
Expand Down Expand Up @@ -100,7 +100,7 @@ class GlobalVarSupply : public ObjectRef {
* \param name_supply The NameSupply to be used when generating new GlobalVars.
* \param name_to_var_map An optional map.
*/
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply,
TVM_DLL explicit GlobalVarSupply(const NameSupply& name_supply = NameSupply(),
std::unordered_map<std::string, GlobalVar> name_to_var_map = {});

/*!
Expand All @@ -117,7 +117,8 @@ class GlobalVarSupply : public ObjectRef {
*/
TVM_DLL explicit GlobalVarSupply(const IRModule module);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef, GlobalVarSupplyNode);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(GlobalVarSupply, ObjectRef,
GlobalVarSupplyNode);
};

} // namespace tvm
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/name_supply.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class NameSupply : public ObjectRef {
* \param prefix The prefix to be used with this NameSupply.
* \param name_map An optional map.
*/
TVM_DLL explicit NameSupply(const String& prefix,
TVM_DLL explicit NameSupply(const String& prefix = "",
std::unordered_map<std::string, int> name_map = {});

/*!
Expand All @@ -129,7 +129,7 @@ class NameSupply : public ObjectRef {
TVM_DLL explicit NameSupply(Iter begin, Iter end, Lambda f)
: NameSupply("", GetNameMap(begin, end, f)) {}

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(NameSupply, ObjectRef, NameSupplyNode);

private:
template <typename Iter, typename Lambda>
Expand Down
3 changes: 1 addition & 2 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1375,8 +1375,7 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
auto pass_ctx = tvm::transform::PassContext::Current();

auto mod = ScheduleToModule(sch, Array<ObjectRef>{tensors.begin(), tensors.end()}, name,
std::unordered_map<te::Tensor, te::Buffer>(),
GlobalVarSupply(NameSupply("")));
std::unordered_map<te::Tensor, te::Buffer>(), GlobalVarSupply());

bool disable_vectorize =
pass_ctx->GetConfig<Bool>("tir.disable_vectorize", Bool(false)).value();
Expand Down
2 changes: 1 addition & 1 deletion src/contrib/hybrid/codegen_hybrid.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
/*! \brief Print the current indent spaces. */
inline void PrintIndent();
/*! \brief NameSupply for allocated ids. */
NameSupply ids_allocated = NameSupply("");
NameSupply ids_allocated;
/*!
* \brief Keys are either (tensors, value_index) or (variables, 0).
* Values are the corresponding IDs.*/
Expand Down
6 changes: 2 additions & 4 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,7 @@ TVM_REGISTER_GLOBAL("driver.schedule_to_module")
c_binds.insert({kv.first, kv.second});
}
}
IRModule mod =
ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")));
IRModule mod = ScheduleToModule(std::move(sch), args, name, c_binds, GlobalVarSupply());
return mod;
});

Expand Down Expand Up @@ -400,8 +399,7 @@ TVM_REGISTER_GLOBAL("driver.lower_schedule")
c_binds.insert({kv.first, kv.second});
}
}
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(NameSupply("")),
simple_mode);
return LowerSchedule(std::move(sch), args, name, c_binds, GlobalVarSupply(), simple_mode);
});

/**
Expand Down
2 changes: 1 addition & 1 deletion src/ir/global_var_supply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ std::string GetModuleName(const IRModule& module) {
return module->GetAttr<String>(tvm::attr::kModuleName).value_or("tvmgen_default");
}

GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply(NameSupply("")) {
GlobalVarSupply::GlobalVarSupply(const Array<IRModule>& modules) : GlobalVarSupply() {
if (!modules.empty()) {
IRModule first_mod = modules.front();
this->operator->()->name_supply_->prefix_ = GetModuleName(first_mod);
Expand Down
2 changes: 1 addition & 1 deletion src/relax/backend/contrib/cutlass/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class CodegenCutlass : public relax::MemoizedExprTranslator<OutputType>,
public relay::contrib::CodegenCBase {
public:
CodegenCutlass(const std::string& id, const Map<Var, Expr>& bindings)
: ext_func_id_(id), bindings_(bindings), name_sup_("") {}
: ext_func_id_(id), bindings_(bindings) {}

void AddParm(Var param) {
ext_func_args_.push_back(param);
Expand Down
3 changes: 1 addition & 2 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ namespace relax {
//---------------------------------------
class BlockBuilderImpl : public BlockBuilderNode {
public:
explicit BlockBuilderImpl(IRModule context_mod)
: name_supply_(""), context_mod_(std::move(context_mod)) {}
explicit BlockBuilderImpl(IRModule context_mod) : context_mod_(std::move(context_mod)) {}

~BlockBuilderImpl() {
if (!block_stack_.empty()) {
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/allocate_workspace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class ExternFunctionRewriter : ExprMutator {
using ExprMutator::VisitExpr_;

ExternFunctionRewriter(IRModule mod, size_t max_workspace_size)
: ExprMutator(mod), name_sup_(""), max_workspace_size_(max_workspace_size) {}
: ExprMutator(mod), max_workspace_size_(max_workspace_size) {}

std::unordered_map<const GlobalVarNode*, Function> Run() {
std::unordered_map<const GlobalVarNode*, Function> ret;
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/normalize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class GlobalVarNormalizer : private ExprMutator {
}

private:
explicit GlobalVarNormalizer(const IRModule& m) : ExprMutator(), module_(m), name_supply_("") {}
explicit GlobalVarNormalizer(const IRModule& m) : ExprMutator(), module_(m) {}

using ExprMutator::VisitExpr_;

Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
/*! \brief function metadata */
Map<String, FunctionInfo> function_metadata_;
/*! \brief NameSupply */
NameSupply name_supply_ = NameSupply("");
NameSupply name_supply_;
};

class GraphExecutorCodegenModule : public runtime::ModuleNode {
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,

std::vector<std::tuple<std::string, Function, IRModule>> lower_results;

NameSupply constant_name_supply("");
NameSupply constant_name_supply;

PostOrderVisit(mod->Lookup("main"), [&](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Expand Down Expand Up @@ -129,7 +129,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(IRModule mod, Target target,

// Tasks are extracted via post order visit, return the reversed list.
std::reverse(tasks.begin(), tasks.end());
NameSupply name_supply = NameSupply("");
NameSupply name_supply;
for (ExtractedTask task : tasks) {
task->task_name = name_supply->FreshName(task->task_name);
}
Expand Down
5 changes: 2 additions & 3 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ TVM_REGISTER_OBJECT_TYPE(TECompilerNode);
class TECompilerImpl : public TECompilerNode {
public:
explicit TECompilerImpl(Optional<IRModule> opt_mod, Optional<String> opt_mod_name)
: global_var_supply_(GlobalVarSupply(NameSupply(opt_mod_name.value_or("")))),
constant_name_supply_(NameSupply("")) {
: global_var_supply_(GlobalVarSupply(NameSupply(opt_mod_name.value_or("")))) {
// Make sure we don't collide with any existing globals in the module.
if (opt_mod) {
for (const auto& kv : opt_mod.value()->functions) {
Expand All @@ -160,7 +159,7 @@ class TECompilerImpl : public TECompilerNode {

// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
CCacheValue value = LowerInternal(key, GlobalVarSupply());
if (value->packed_func != nullptr) {
return value->packed_func;
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function&
}

tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply());
(void)_; // to suppress -Werror=unused-variable warning
if (f_opt) {
return f_opt.value();
Expand All @@ -1143,7 +1143,7 @@ TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
auto tgt = tvm::Target("ext_dev");
LowerToTECompute lower_te_compute(tgt, NameSupply(""));
LowerToTECompute lower_te_compute(tgt, NameSupply());
auto outputs = lower_te_compute.Lower(prim_func);
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
outputs, te::Schedule(), tir::PrimFunc(), {},
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/te_compiler_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
/*! \brief A specialization of PrimFuncFor, meant to be used when the names of constants do not
* matter. */
inline CachedFunc PrimFuncFor(const Function& source_func, const Target& target) {
return PrimFuncFor(source_func, target, GlobalVarSupply(NameSupply("")), NameSupply(""));
return PrimFuncFor(source_func, target, GlobalVarSupply(), NameSupply());
}

CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
std::unordered_map<GlobalVar, String> internal_functions_;

/* \brief Name supply to generate unique function names */
NameSupply func_name_supply_{""};
NameSupply func_name_supply_;
};

} // namespace codegen
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_source_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace tvm {
namespace codegen {

void CodeGenSourceBase::ClearFuncState() {
name_supply_ = NameSupply("");
name_supply_ = NameSupply();
ssa_assign_map_.clear();
var_idmap_.clear();
scope_mark_.clear();
Expand Down
2 changes: 1 addition & 1 deletion src/target/source/codegen_source_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ class CodeGenSourceBase {
/*! \brief name of each variable */
std::unordered_map<const tir::VarNode*, std::string> var_idmap_;
/*! \brief NameSupply for allocation */
NameSupply name_supply_ = NameSupply("");
NameSupply name_supply_;

private:
/*! \brief assignment map of ssa */
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/create_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct CreateFuncInfo {
/*! \brief The buffers should be allocated at function root. */
Array<Buffer> root_alloc;
/*! \brief The NameSupply to make block name unique. */
NameSupply name_supply = NameSupply("");
NameSupply name_supply;

String FreshName(String base_name) { return name_supply->FreshName(base_name); }

Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/index_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ IndexMap IndexMap::RenameVariables(
const std::function<Optional<String>(const Var& var)>& f_name_map) const {
std::unordered_set<std::string> used_names;
Map<Var, Var> var_remap;
NameSupply name_supply{""};
NameSupply name_supply;
const IndexMapNode* n = this->get();
if (f_name_map != nullptr) {
// Collect variables with pre-defined names provided by f_name_map.
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/build_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ TEST(BuildModule, Basic) {

auto target = Target("llvm");

auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply(NameSupply("")));
auto lowered = LowerSchedule(s, args, "func", binds, GlobalVarSupply());
auto module = build(lowered, target, Target());

auto mali_target = Target("opencl -model=Mali-T860MP4@800Mhz -device=mali");
Expand Down Expand Up @@ -121,7 +121,7 @@ TEST(BuildModule, Heterogeneous) {
auto args2 = Array<Tensor>({copy, C, elemwise_sub});

std::unordered_map<Tensor, Buffer> binds;
GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
GlobalVarSupply global_var_supply = GlobalVarSupply();
auto lowered_s1 = LowerSchedule(fcreate_s1(), args1, "elemwise_add", binds, global_var_supply);
auto lowered_s2 = LowerSchedule(fcreate_s2(), args2, "elemwise_sub", binds, global_var_supply);
Map<tvm::Target, IRModule> inputs = {{target_cuda, lowered_s1}, {target_llvm, lowered_s2}};
Expand Down
6 changes: 2 additions & 4 deletions tests/cpp/c_codegen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ TEST(CCodegen, MainFunctionOrder) {
auto args = Array<Tensor>({A, B, elemwise_add});

std::unordered_map<Tensor, Buffer> binds;
auto lowered =
LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply(NameSupply("")));
auto lowered = LowerSchedule(fcreate(), args, "elemwise_add", binds, GlobalVarSupply());
Map<tvm::Target, IRModule> inputs = {{target_c, lowered}};
runtime::Module module = build(inputs, Target());
Array<String> functions = module->GetFunction("get_func_names", false)();
Expand Down Expand Up @@ -82,8 +81,7 @@ auto BuildLowered(std::string op_name, tvm::Target target) {

auto args = Array<Tensor>({A, B, op});
std::unordered_map<Tensor, Buffer> binds;
auto lowered_s =
LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply(NameSupply("")));
auto lowered_s = LowerSchedule(fcreate_s(), args, op_name, binds, GlobalVarSupply());
return lowered_s;
}

Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/name_supply_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
using namespace tvm;

NameSupply preambleNameSupply() {
NameSupply name_supply = NameSupply("prefix");
NameSupply name_supply("prefix");
name_supply->FreshName("test");
return name_supply;
}
Expand Down Expand Up @@ -74,7 +74,7 @@ TEST(NameSupply, ReserveName) {
}

GlobalVarSupply preambleVarSupply() {
GlobalVarSupply global_var_supply = GlobalVarSupply(NameSupply(""));
GlobalVarSupply global_var_supply;
global_var_supply->FreshGlobal("test");
return global_var_supply;
}
Expand Down

0 comments on commit f60b08c

Please sign in to comment.