From 25b76c0e7bdbc74e9802734e03ec0d74d43f6da1 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 1 Apr 2024 21:09:55 -0700 Subject: [PATCH] [Relax] Share storage allocs among functions after cuda graph rewriting (#16830) --- src/relax/transform/rewrite_cuda_graph.cc | 386 +++++++++++++----- .../test_transform_rewrite_cuda_graph.py | 241 ++++++++++- 2 files changed, 518 insertions(+), 109 deletions(-) diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc index 25b229ebce579..d0e20ffd766b7 100644 --- a/src/relax/transform/rewrite_cuda_graph.cc +++ b/src/relax/transform/rewrite_cuda_graph.cc @@ -49,17 +49,19 @@ * 2. Lift the regions identified in step 1 to a separate function and rewrite the original function * with `CUDAGraphRewriter`. */ - +#include #include #include #include #include #include +#include +#include + #include "../../support/arena.h" #include "../../support/ordered_set.h" #include "../../support/utils.h" - namespace tvm { namespace relax { @@ -79,9 +81,10 @@ struct LiftedFunctionRewritePlan { // Variable remappings between the original function and the lifted function // The bindings in the original function that are lifted - std::unordered_set lifted_bindings; + std::vector lifted_bindings; // The corresponding binding vars in the original function of the outputs of the lifted function - std::vector outputs; + // to the index of the element in the output tuple of the lifted function. + std::unordered_map outputs; // The corresponding binding vars in the original function of the inputs of the lifted function std::vector inputs; // The tir vars in the original function that are propagated to the lifted function @@ -170,13 +173,68 @@ class FuncBuilder : public ExprMutator { Map tir_var_remap_; }; +// Collect the storage objects that are used as the function output +class OutputStorageCollector : public ExprVisitor { + public: + static std::unordered_set Collect(const Function& func) { + OutputStorageCollector collector; + collector.VisitExpr(func); + return std::move(collector.output_storages_); + } + + private: + void VisitExpr_(const SeqExprNode* seq_expr) final { + auto output_vars = FreeVars(seq_expr->body); + for (const auto& var : output_vars) { + output_vars_.insert(var.get()); + } + // Visit the blocks in reverse order for backward propagation + for (auto it = seq_expr->blocks.rbegin(); it != seq_expr->blocks.rend(); ++it) { + VisitBindingBlock(*it); + } + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { + static const auto& mem_alloc_tensor_op = Op::Get("relax.memory.alloc_tensor"); + if (output_vars_.count(binding->var.get()) && call->op.same_as(mem_alloc_tensor_op)) { + output_storages_.insert(call->args[0].as()); + } + } + + void VisitBindingBlock_(const BindingBlockNode* binding_block) override { + // Visit the bindings in reverse order + for (auto it = binding_block->bindings.rbegin(); it != binding_block->bindings.rend(); ++it) { + VisitBinding(*it); + } + } + + void VisitBinding_(const VarBindingNode* binding, const VarNode* var) final { + if (output_vars_.count(binding->var.get())) { + output_vars_.insert(var); + } + } + + void VisitBinding_(const VarBindingNode* binding, const TupleNode* tuple) final { + if (output_vars_.count(binding->var.get())) { + for (const auto& field : tuple->fields) { + output_vars_.insert(field.as()); + } + } + } + + std::unordered_set output_storages_; + std::unordered_set output_vars_; +}; + /*! 
* \brief The planner for rewriting the function to enable cuda graph capturing. */ class CUDAGraphRewritePlanner : public ExprVisitor { public: - explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {} - std::vector Plan() { + explicit CUDAGraphRewritePlanner(const IRModule& mod, support::Arena* arena) + : mod_(mod), arena_(arena) {} + std::pair, std::vector> + Plan() { for (const auto& pair : mod_->functions) { if (pair.second->IsInstance()) { // If a function has the num_input attribute, the last func->params.size() - num_inputs @@ -188,41 +246,41 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } } CollectSymbolicVarHints(func); + disabled_storage_vars_ = OutputStorageCollector::Collect(func); VisitExpr(func); } } - std::vector plans; - - auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> LiftedFunctionRewritePlan { - LiftedFunctionRewritePlan plan; - plan.is_alloc = true; - plan.func = region->Build(); + auto region_to_plan = [&](FuncBuilder* region, bool is_alloc) -> LiftedFunctionRewritePlan* { + auto* plan = arena_->make(); + plan->is_alloc = true; + plan->func = region->Build(); ICHECK(region->size()); - plan.launch_point = region->bindings_.front()->var.get(); - plan.is_alloc = is_alloc; - for (const auto* binding : region->bindings_) { - plan.lifted_bindings.insert(binding->var.get()); - } + plan->launch_point = region->bindings_.front()->var.get(); + plan->is_alloc = is_alloc; + plan->lifted_bindings = std::move(region->bindings_); if (region->shape_expr_inputs_.size()) { Array tir_vars; for (const auto* var : region->shape_expr_inputs_) { tir_vars.push_back(GetRef(var)); } - plan.propogated_tir_vars = ShapeExpr(tir_vars); + plan->propogated_tir_vars = ShapeExpr(tir_vars); + } + plan->inputs.assign(region->inputs_.begin(), region->inputs_.end()); + for (const auto* var : region->outputs_) { + plan->outputs[var] = plan->outputs.size(); } - plan.inputs.assign(region->inputs_.begin(), region->inputs_.end()); - plan.outputs.assign(region->outputs_.begin(), region->outputs_.end()); return plan; }; - for (auto* region : alloc_storages_) { - plans.push_back(region_to_plan(region, /*is_alloc=*/true)); - } - - for (auto* region : captured_regions_) { - plans.push_back(region_to_plan(region, /*is_alloc=*/false)); - } - return plans; + std::vector alloc_plans, capture_plans; + alloc_plans.reserve(alloc_storages_.size()); + capture_plans.reserve(captured_regions_.size()); + std::transform(alloc_storages_.begin(), alloc_storages_.end(), std::back_inserter(alloc_plans), + [&](FuncBuilder* region) { return region_to_plan(region, /*is_alloc=*/true); }); + std::transform(captured_regions_.begin(), captured_regions_.end(), + std::back_inserter(capture_plans), + [&](FuncBuilder* region) { return region_to_plan(region, /*is_alloc=*/false); }); + return {std::move(alloc_plans), std::move(capture_plans)}; } /*! @@ -241,31 +299,36 @@ class CUDAGraphRewritePlanner : public ExprVisitor { *\brief Start a new static region. This method should be called when encountering a * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters. */ - void StartRegion() { current_.capture_builder = arena_.make(); } + void StartRegion() { current_block_scope_.capture_builder = arena_->make(); } /*! * \brief Finish a static region. This method should be called when non-static bindings or * unsupported operations are encountered. 
*/ void EndRegion() { - if (current_.capture_builder && current_.capture_builder->size()) { - captured_regions_.emplace_back(current_.capture_builder); + if (current_block_scope_.capture_builder && current_block_scope_.capture_builder->size()) { + captured_regions_.emplace_back(current_block_scope_.capture_builder); } - current_.capture_builder = nullptr; + current_block_scope_.capture_builder = nullptr; + } + + void VisitExpr_(const FunctionNode* func) final { + current_function_scope_.alloc_storage_builder = arena_->make(); + ExprVisitor::VisitExpr_(func); + if (current_function_scope_.alloc_storage_builder->outputs_.size()) { + alloc_storages_.emplace_back(current_function_scope_.alloc_storage_builder); + } + current_function_scope_.alloc_storage_builder = nullptr; } void VisitBindingBlock_(const BindingBlockNode* binding_block) final { - Scope new_scope; - std::swap(new_scope, current_); - current_.alloc_storage_builder = arena_.make(); + BindingBlockScope new_scope; + std::swap(new_scope, current_block_scope_); for (const auto& binding : binding_block->bindings) { VisitBinding(binding); } EndRegion(); - if (current_.alloc_storage_builder->outputs_.size()) { - alloc_storages_.emplace_back(current_.alloc_storage_builder); - } - std::swap(new_scope, current_); + std::swap(new_scope, current_block_scope_); } void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final { @@ -273,8 +336,10 @@ class CUDAGraphRewritePlanner : public ExprVisitor { static const auto& builtin_alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx"); - if (call->op.same_as(mem_alloc_storage_op) && IsStaticAllocStorage(binding)) { - AddStaticBinding(binding, /*is_alloc_storage=*/true); + if (call->op.same_as(mem_alloc_storage_op)) { + if (IsStaticAllocStorage(binding)) { + AddStaticBinding(binding, /*is_alloc_storage=*/true); + } return; } else if (call->op.same_as(builtin_alloc_tensor_op)) { return; @@ -321,7 +386,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor { } return false; }(); - if (current_.capture_builder == nullptr && is_kernel_launch) { + if (current_block_scope_.capture_builder == nullptr && is_kernel_launch) { StartRegion(); } AddStaticBinding(binding, /*is_alloc_storage=*/false); @@ -335,24 +400,24 @@ class CUDAGraphRewritePlanner : public ExprVisitor { void MarkAsFuncInput(const std::vector& vars, const std::vector& tir_vars = {}) { - if (current_.capture_builder == nullptr) { + if (current_block_scope_.capture_builder == nullptr) { return; } for (const VarNode* var : vars) { auto it = binding_to_region_.find(var); - if (it == binding_to_region_.end() || it->second != current_.capture_builder) { - current_.capture_builder->MarkInput(var); + if (it == binding_to_region_.end() || it->second != current_block_scope_.capture_builder) { + current_block_scope_.capture_builder->MarkInput(var); } } for (const tir::VarNode* tir_var : tir_vars) { - current_.capture_builder->MarkShapeExprInput(tir_var); + current_block_scope_.capture_builder->MarkShapeExprInput(tir_var); } } void MarkAsFuncOutput(const std::vector& vars) { for (const VarNode* var : vars) { if (auto it = binding_to_region_.find(var); - it != binding_to_region_.end() && it->second != current_.capture_builder) { + it != binding_to_region_.end() && it->second != current_block_scope_.capture_builder) { it->second->MarkOutput(var); } } @@ -476,6 +541,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor { private: bool 
IsStaticAllocStorage(const VarBindingNode* binding) { + if (disabled_storage_vars_.count(binding->var.get())) { + return false; + } // Check if the allocation has constant shape const auto* alloc_storage_call = binding->value.as(); auto shape = Downcast(alloc_storage_call->args[0]); @@ -491,33 +559,41 @@ class CUDAGraphRewritePlanner : public ExprVisitor { */ void AddStaticBinding(const VarBindingNode* binding, bool is_alloc_storage) { if (is_alloc_storage) { - current_.alloc_storage_builder->AddBinding(binding); - binding_to_region_[binding->var.get()] = current_.alloc_storage_builder; - } else if (current_.capture_builder != nullptr) { + current_function_scope_.alloc_storage_builder->AddBinding(binding); + binding_to_region_[binding->var.get()] = current_function_scope_.alloc_storage_builder; + } else if (current_block_scope_.capture_builder != nullptr) { // Add the binding if the capture builder exists. It is possible that capture builder is // null when it is not capturing. This is the case that there are not yet any kernel launches // encountered, in this case static bindings (e.g. binding of other non-kernel-launch // operations) are marked but are not lifted. - current_.capture_builder->AddBinding(binding); - binding_to_region_[binding->var.get()] = current_.capture_builder; + current_block_scope_.capture_builder->AddBinding(binding); + binding_to_region_[binding->var.get()] = current_block_scope_.capture_builder; } static_vars_.emplace(binding->var.get()); } - /*! \brief The states of the current scope (the BindingBlock) which is a pair of FuncBuilder. + /*! \brief The states of the current scope (the BindingBlock) which is a FuncBuilder. * The FuncBuilder are initialized with nullptr, meaning the planner is currently not doing any * lifting. They are initialized lazily when a binding that can be lifted is encountered. * They are reset to nullptr when an unsupported operation is encountered. */ - struct Scope { + struct BindingBlockScope { + FuncBuilder* capture_builder = nullptr; // The builder for the capture function + }; + + /*! \brief The states of the current function scope which is a FuncBuilder to build the storage + * allocation function. + */ + struct FunctionScope { FuncBuilder* alloc_storage_builder = nullptr; // The builder for the allocation function - FuncBuilder* capture_builder = nullptr; // The builder for the capture function }; // The IRModule IRModule mod_; - // States of the current scope - Scope current_; + // States of the current block scope + BindingBlockScope current_block_scope_; + // States of the current function scope + FunctionScope current_function_scope_; // Variables whose buffer address is fixed std::unordered_set static_vars_; // The name of the variables that are allowed to be symbolic @@ -529,64 +605,183 @@ class CUDAGraphRewritePlanner : public ExprVisitor { std::vector captured_regions_; // The regions for allocation. std::vector alloc_storages_; + // The binding variables that are not allowed to be captured. + std::unordered_set disabled_storage_vars_; // The arena. - support::Arena arena_; + support::Arena* arena_; }; +/*! + * \brief Merge storage allocations from different functions by reusing the largest allocation that + * can be shared among all the functions. The original rewriting plans are updated in-place to use + * the merged storage allocations. + * + * When multiple functions are rewritten to be executed with CUDA graph, the storage allocations + * from different functions can be reused. 
This function merges multiple storage allocation + * functions into a single function that allocates storage large enough to be shared among + * all the functions. + * + * \param alloc_plans The allocation plans of the functions to be merged. + * \return The new allocation function that merges the storage allocations. + */ +Function MergeAllocationPlans(const std::vector& alloc_plans) { + // The storage record that contains the size of the storage allocation and the binding of the + // storage allocation. + struct StorageRecord { + // The size of the storage object in bytes + int64_t size; + // The binding of the storage allocation + const VarBindingNode* binding; + // The source rewriting plan that the storage record is from + LiftedFunctionRewritePlan* src; + + bool operator<(const StorageRecord& other) const { return size < other.size; } + }; + // Using an (ordered) map to make sure the result is deterministic + std::map>> storage_records; + static const auto& mem_alloc_storage_op = Op::Get("relax.memory.alloc_storage"); + + // Collect the storage records for each storage scope. Storage records are stored separately + // for each original function. + for (int plan_id = 0; plan_id < static_cast(alloc_plans.size()); ++plan_id) { + LiftedFunctionRewritePlan* plan = alloc_plans[plan_id]; + ICHECK(plan->is_alloc); + for (const VarBindingNode* binding : plan->lifted_bindings) { + // Extract the storage record from the Call expr. + Call alloc_storage = Downcast(binding->value); + ICHECK(alloc_storage->op.same_as(mem_alloc_storage_op)); + auto storage_shape = Downcast(alloc_storage->args[0]); + ICHECK_EQ(storage_shape->values.size(), 1); + int64_t size = Downcast(storage_shape->values[0])->value; + int64_t virtual_device_id = + Downcast(Downcast(alloc_storage->args[1])->value)->value; + ICHECK_EQ(virtual_device_id, 0); + String storage_scope = Downcast(alloc_storage->args[2])->value; + auto [it, _] = storage_records.try_emplace(storage_scope, alloc_plans.size()); + it->second[plan_id].emplace_back(StorageRecord{size, binding, plan}); + } + } + + // Merge the storage records within each storage scope. + // This is achieved by sorting the storage records in descending order of size and then merging + // storage allocations from different functions to the largest allocation that can be shared + // among all the functions. + // This assumes that multiple functions will not run concurrently. + std::vector merged_allocs; + // Process the records of each storage scope independently. + for (auto& [storage_scope, curr_scope_records] : storage_records) { + // The number of storages needed for the current storage scope, which is the maximum number of + // storage records among all the functions. + int num_storages = 0; + for (auto& records_of_plan : curr_scope_records) { + // Sort descending by size, preserve the original order if the sizes are equal. + std::stable_sort(records_of_plan.rbegin(), records_of_plan.rend()); + num_storages = std::max(num_storages, static_cast(records_of_plan.size())); + } + // The iterators to scan the storage records of all functions from left to right + // at the same time. + std::vector iters(alloc_plans.size(), 0); + for (int i = 0; i < num_storages; i++) { + // The storage records from different functions that can be merged to the same storage.
+ std::vector to_merge; + for (int plan_index = 0; plan_index < static_cast(curr_scope_records.size()); + plan_index++) { + if (iters[plan_index] < static_cast(curr_scope_records[plan_index].size())) { + to_merge.push_back(curr_scope_records[plan_index][iters[plan_index]++]); + } + } + const StorageRecord& largest_storage = + *std::max_element(to_merge.begin(), to_merge.end(), + [](const auto& lhs, const auto& rhs) { return lhs < rhs; }); + // Merge the records to the largest allocation by updating the index of the output element + // to that of the new allocation function. + int storage_index = static_cast(merged_allocs.size()); + for (const StorageRecord& rec : to_merge) { + auto* plan = rec.src; + plan->outputs.at(rec.binding->var.get()) = storage_index; + } + merged_allocs.push_back(largest_storage.binding); + } + } + // Create the new allocation function for the merged allocations. + FuncBuilder builder; + for (const auto* binding : merged_allocs) { + builder.AddBinding(binding); + builder.MarkOutput(binding->var.get()); + } + return builder.Build(); +} + /*! \brief The rewriter for CUDA graph */ class CUDAGraphRewriter : public ExprMutator { public: explicit CUDAGraphRewriter(const IRModule& mod) : ExprMutator(mod) {} IRModule Rewrite() { - CUDAGraphRewritePlanner planner(builder_->GetContextIRModule()); - auto plans = planner.Plan(); - for (const auto& plan : plans) { - subgraph_launches_[plan.launch_point] = plan; - } + CUDAGraphRewritePlanner planner(builder_->GetContextIRModule(), &arena_); + // Collect the target functions for rewriting before any mutation. + std::vector> target_functions; for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) { if (func->IsInstance()) { - auto new_func = Downcast(VisitExpr(func)); - if (!new_func.same_as(func)) { - builder_->UpdateFunction(gv, new_func); - } + target_functions.emplace_back(gv, Downcast(func)); + } + } + + auto [alloc_plans, capture_plans] = planner.Plan(); + if (alloc_plans.size()) { + auto global_alloc_func = MergeAllocationPlans(alloc_plans); + gv_global_alloc_ = builder_->AddFunction(global_alloc_func, "cuda_graph_alloc"); + } + for (const auto* plan : alloc_plans) { + subgraph_launches_[plan->launch_point] = plan; + } + for (const auto* plan : capture_plans) { + subgraph_launches_[plan->launch_point] = plan; + } + + for (const auto& [gv, func] : target_functions) { + current_func_ = gv; + auto new_func = Downcast(VisitExpr(func)); + if (!new_func.same_as(func)) { + builder_->UpdateFunction(gv, new_func); } } return builder_->GetContextIRModule(); } - void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan& plan) { + void LaunchSubgraph(const VarBindingNode* op, const LiftedFunctionRewritePlan* plan) { static const auto& call_builtin_with_ctx_op = Op::Get("relax.call_builtin_with_ctx"); static const auto& builtin_run_or_capture = ExternFunc("vm.builtin.cuda_graph.run_or_capture"); static const auto& builtin_get_cached_alloc = ExternFunc("vm.builtin.cuda_graph.get_cached_alloc"); Expr launch_subgraph; - auto gv_func = - builder_->AddFunction(plan.func, plan.is_alloc ? "cuda_graph_alloc" : "cuda_graph_capture"); - if (plan.is_alloc) { + if (plan->is_alloc) { // Storage allocation should be fully static and shouldn't depend on any symbolic variables. 
- ICHECK(!plan.propogated_tir_vars.defined()); - ICHECK(plan.inputs.empty()); - launch_subgraph = - Call(call_builtin_with_ctx_op, - {builtin_get_cached_alloc, - Tuple({gv_func, PrimValue(IntImm(DataType::Int(64), index_alloc_++))})}, - Attrs(), {plan.func->ret_struct_info}); + ICHECK(!plan->propogated_tir_vars.defined()); + ICHECK(plan->inputs.empty()); + auto gv_alloc = gv_global_alloc_.value(); + auto ret_struct_info = Downcast(gv_alloc->struct_info_.value())->ret; + launch_subgraph = Call( + call_builtin_with_ctx_op, + {builtin_get_cached_alloc, Tuple({gv_alloc, PrimValue(IntImm(DataType::Int(64), 0))})}, + Attrs(), {ret_struct_info}); } else { - StructInfo call_sinfo = plan.func->ret_struct_info; + auto gv_func = builder_->AddFunction( + plan->func, current_func_.value()->name_hint + "_cuda_graph_capture"); + StructInfo call_sinfo = plan->func->ret_struct_info; // Arguments of the lifted function Array args; - for (const auto& arg : plan.inputs) { + for (const auto& arg : plan->inputs) { args.push_back(VisitExpr_(arg)); } - if (plan.propogated_tir_vars.defined()) { - ShapeExpr propogated_tir_vars = plan.propogated_tir_vars.value(); + if (plan->propogated_tir_vars.defined()) { + ShapeExpr propogated_tir_vars = plan->propogated_tir_vars.value(); args.push_back(propogated_tir_vars); // The ret_struct_info of the lifted function can contain symbolic variables. We need to // bind the symbolic parameters to the actual values. - const auto& shape_expr = plan.func->params.back(); + const auto& shape_expr = plan->func->params.back(); auto symbolic_params = Downcast(shape_expr->struct_info_.value())->values.value(); Map tir_var_remap; @@ -599,25 +794,23 @@ class CUDAGraphRewriter : public ExprMutator { // Arguments of builtin_run_or_capture Array tuple_arg_fields{gv_func, Tuple(args), PrimValue(IntImm(DataType::Int(64), index_capture_++))}; - if (plan.propogated_tir_vars.defined()) { + if (plan->propogated_tir_vars.defined()) { // The shape expr is explicitly passed twice, one as the last argument of the lifted // function, one as the last argument of builtin_run_or_capture as the cache key. Explicitly // passing it twice simplifies the handling during the capture phase. - tuple_arg_fields.push_back(plan.propogated_tir_vars.value()); + tuple_arg_fields.push_back(plan->propogated_tir_vars.value()); } launch_subgraph = Call(call_builtin_with_ctx_op, {builtin_run_or_capture, Tuple(tuple_arg_fields)}, Attrs(), {call_sinfo}); } Expr ret_value = builder_->Emit(launch_subgraph); - for (int i = 0; i < static_cast(plan.outputs.size()); ++i) { - // The unpacked result is saved in the var_redef_. It will be emitted when 1) the var - // definition is the original IR is visited, or 2) the var is used as an input to another - // lifted function, whichever comes first. 
- var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i); + for (const auto& [var, tuple_index] : plan->outputs) { + var_redef_[var] = TupleGetItem(ret_value, tuple_index); } - - lifted_bindings_.insert(plan.lifted_bindings.begin(), plan.lifted_bindings.end()); + std::transform(plan->lifted_bindings.begin(), plan->lifted_bindings.end(), + std::inserter(lifted_binding_vars_, lifted_binding_vars_.end()), + [](const BindingNode* binding) { return binding->var.get(); }); } void VisitBinding_(const VarBindingNode* op) final { @@ -629,7 +822,7 @@ class CUDAGraphRewriter : public ExprMutator { EmitRedef(op->var.get(), it->second); return; } - if (lifted_bindings_.count(op->var.get())) { + if (lifted_binding_vars_.count(op->var.get())) { // The binding is lifted to the subgraph and will be removed from the original function. return; } @@ -654,11 +847,14 @@ class CUDAGraphRewriter : public ExprMutator { return new_var; } - std::unordered_map subgraph_launches_; + std::unordered_map subgraph_launches_; std::unordered_map var_redef_; - std::unordered_set lifted_bindings_; + std::unordered_set lifted_binding_vars_; int index_alloc_ = 0; int index_capture_ = 0; + support::Arena arena_; + Optional gv_global_alloc_ = NullOpt; + Optional current_func_ = NullOpt; }; IRModule RewriteCUDAGraph(IRModule mod) { diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py index 43b26f110fa29..9db285fea609b 100644 --- a/tests/python/relax/test_transform_rewrite_cuda_graph.py +++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py @@ -107,7 +107,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object): return gv @R.function(private=True) - def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): + def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected _2: R.Tuple = cls.exp(alloc, alloc1) @@ -133,7 +133,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4), dtype="float32 storage1: R.Object = gv[1] alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) storage2: R.Object = gv[2] - gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) + gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0] alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0)) _6: R.Tuple = cls.exp(alloc3, alloc4) @@ -191,7 +191,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float3 _5: R.Tuple = R.memory.kill_tensor(alloc2) _6: R.Tuple = R.memory.kill_storage(storage) _7: R.Tuple = R.memory.kill_storage(storage1) - return alloc2 + return alloc3 @I.ir_module class Expected: @@ -217,7 +217,7 @@ def 
cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): return gv @R.function(private=True) - def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): + def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected _: R.Tuple = cls.exp(alloc, alloc1) @@ -242,14 +242,14 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2, 4), dtype="float3 _: R.Tuple = cls.exp(x, alloc) storage1: R.Object = gv[1] alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) - gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) + gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),)) alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0] alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0)) _4: R.Tuple = cls.exp(alloc2, alloc3) _5: R.Tuple = R.memory.kill_tensor(alloc2) _6: R.Tuple = R.memory.kill_storage(storage) _7: R.Tuple = R.memory.kill_storage(storage1) - return alloc2 + return alloc3 # fmt: on after = relax.transform.RewriteCUDAGraph()(Before) @@ -318,7 +318,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): return gv @R.function(private=True) - def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")): + def main_cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")): R.func_attr({"relax.force_pure": True}) cls = Expected _2: R.Tuple = cls.exp(alloc, alloc1) @@ -338,7 +338,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((2,4), dtype="float32 _1: R.Tuple = cls.exp(x, alloc) storage1: R.Object = gv[1] alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32")) - gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),)) + gv1: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.main_cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")),)) alloc2: R.Tensor((2, 4), dtype="float32") = gv1[1] lv: R.Tensor((2, 4), dtype="float32") = gv1[0] _4: R.Tuple = R.call_packed("vm.builtin.dummy", (x, lv), sinfo_args=(R.Tuple,)) @@ -528,7 +528,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, 
R.Object): return gv @R.function(private=True) - def cuda_graph_capture( + def main_cuda_graph_capture( lv: R.Tensor((16, 32, 32, 16), dtype="float16"), lv1: R.Tensor((16, 3, 3, 16), dtype="float16"), alloc1: R.Tensor((16, 32, 32, 16), dtype="float16"), @@ -635,7 +635,7 @@ def main( ) = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", ( - cls.cuda_graph_capture, + cls.main_cuda_graph_capture, (lv_1, lv1, alloc1, alloc, params, storage), R.prim_value(0), ), @@ -728,7 +728,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object): return gv @R.function(private=True) - def cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple: + def main_cuda_graph_capture(alloc0: R.Tensor((8,), dtype="float32")) -> R.Tuple: R.func_attr({"relax.force_pure": True}) _: R.Object = R.call_packed("dummy_func", alloc0, R.dtype("float32"), R.str("string")) gv: R.Tuple = R.tuple() @@ -748,7 +748,7 @@ def main() -> R.Tuple: ) gv1: R.Tuple = R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", - (cls.cuda_graph_capture, (alloc0,), R.prim_value(0)), + (cls.main_cuda_graph_capture, (alloc0,), R.prim_value(0)), sinfo_args=(R.Tuple,), ) return R.tuple() @@ -822,7 +822,7 @@ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): return gv @R.function(private=True) - def cuda_graph_capture( + def main_cuda_graph_capture( alloc1: R.Tensor(("m",), dtype="float32"), alloc2: R.Tensor(("m",), dtype="float32"), shape_expr: R.Shape(["m"]), @@ -858,7 +858,7 @@ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float3 R.call_builtin_with_ctx( "vm.builtin.cuda_graph.run_or_capture", ( - cls.cuda_graph_capture, + cls.main_cuda_graph_capture, (alloc1, alloc2, R.shape([m])), R.prim_value(0), R.shape([m]), @@ -875,5 +875,218 @@ def main(x: R.Tensor(("m",), dtype="float32")) -> R.Tensor(("m",), dtype="float3 tvm.ir.assert_structural_equal(mod, Expected) +class TestMergeAllocFuncs(BaseCompare): + @I.ir_module + class Before: + @R.function + def func1(): + R.func_attr({"relax.force_pure": True}) + storage1 = R.memory.alloc_storage(R.shape([128]), 0, "global", "float32") + storage2 = R.memory.alloc_storage(R.shape([256]), 0, "global", "float32") + storage3 = R.memory.alloc_storage(R.shape([512]), 0, "ipc_memory", "float32") + alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([128]), "float32") + alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([256]), "float32") + alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([512]), "float32") + R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,)) + return R.tuple() + + @R.function + def func2(): + R.func_attr({"relax.force_pure": True}) + storage1 = R.memory.alloc_storage(R.shape([192]), 0, "global", "float32") + storage2 = R.memory.alloc_storage(R.shape([64]), 0, "global", "float32") + storage3 = R.memory.alloc_storage(R.shape([1024]), 0, "ipc_memory", "float32") + storage4 = R.memory.alloc_storage(R.shape([512]), 0, "global", "float32") + alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([192]), "float32") + alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([64]), "float32") + alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([1024]), "float32") + alloc4 = R.memory.alloc_tensor(storage4, 0, R.shape([512]), "float32") + R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, sinfo_args=(R.Tuple,)) + return R.tuple() + + @I.ir_module + class Expected: + @R.function(private=True) + def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object, R.Object): + R.func_attr({"relax.force_pure": True}) + storage4: R.Object = 
R.memory.alloc_storage( + R.shape([512]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + storage1: R.Object = R.memory.alloc_storage( + R.shape([192]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + storage2: R.Object = R.memory.alloc_storage( + R.shape([64]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + storage3: R.Object = R.memory.alloc_storage( + R.shape([1024]), R.prim_value(0), R.str("ipc_memory"), R.dtype("float32") + ) + gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = ( + storage4, + storage1, + storage2, + storage3, + ) + return gv + + @R.function + def func1() -> R.Tuple: + R.func_attr({"relax.force_pure": True}) + cls = Expected + gv: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.get_cached_alloc", + (cls.cuda_graph_alloc, R.prim_value(0)), + sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),), + ) + storage1: R.Object = gv[1] + storage2: R.Object = gv[0] + storage3: R.Object = gv[3] + alloc1: R.Tensor((128,), dtype="float32") = R.memory.alloc_tensor( + storage1, R.prim_value(0), R.shape([128]), R.dtype("float32") + ) + alloc2: R.Tensor((256,), dtype="float32") = R.memory.alloc_tensor( + storage2, R.prim_value(0), R.shape([256]), R.dtype("float32") + ) + alloc3: R.Tensor((512,), dtype="float32") = R.memory.alloc_tensor( + storage3, R.prim_value(0), R.shape([512]), R.dtype("float32") + ) + R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.run_or_capture", + (cls.func1_cuda_graph_capture, (alloc1, alloc2, alloc3), R.prim_value(0)), + sinfo_args=(R.Tuple,), + ) + return R.tuple() + + @R.function(private=True) + def func1_cuda_graph_capture( + alloc1: R.Tensor((128,), dtype="float32"), + alloc2: R.Tensor((256,), dtype="float32"), + alloc3: R.Tensor((512,), dtype="float32"), + ) -> R.Tuple: + R.func_attr({"relax.force_pure": True}) + R.call_packed("dummy", alloc1, alloc2, alloc3, sinfo_args=(R.Tuple,)) + R.tuple() + return R.tuple() + + @R.function + def func2() -> R.Tuple: + R.func_attr({"relax.force_pure": True}) + cls = Expected + gv2: R.Tuple(R.Object, R.Object, R.Object, R.Object) = R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.get_cached_alloc", + (cls.cuda_graph_alloc, R.prim_value(0)), + sinfo_args=(R.Tuple(R.Object, R.Object, R.Object, R.Object),), + ) + storage11: R.Object = gv2[1] + storage21: R.Object = gv2[2] + storage31: R.Object = gv2[3] + storage4: R.Object = gv2[0] + alloc1: R.Tensor((192,), dtype="float32") = R.memory.alloc_tensor( + storage11, R.prim_value(0), R.shape([192]), R.dtype("float32") + ) + alloc2: R.Tensor((64,), dtype="float32") = R.memory.alloc_tensor( + storage21, R.prim_value(0), R.shape([64]), R.dtype("float32") + ) + alloc3: R.Tensor((1024,), dtype="float32") = R.memory.alloc_tensor( + storage31, R.prim_value(0), R.shape([1024]), R.dtype("float32") + ) + alloc4: R.Tensor((512,), dtype="float32") = R.memory.alloc_tensor( + storage4, R.prim_value(0), R.shape([512]), R.dtype("float32") + ) + R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.run_or_capture", + (cls.func2_cuda_graph_capture, (alloc1, alloc2, alloc3, alloc4), R.prim_value(1)), + sinfo_args=(R.Tuple,), + ) + return R.tuple() + + @R.function(private=True) + def func2_cuda_graph_capture( + alloc1: R.Tensor((192,), dtype="float32"), + alloc2: R.Tensor((64,), dtype="float32"), + alloc3: R.Tensor((1024,), dtype="float32"), + alloc4: R.Tensor((512,), dtype="float32"), + ) -> R.Tuple: + R.func_attr({"relax.force_pure": True}) + R.call_packed("dummy", alloc1, alloc2, alloc3, alloc4, 
sinfo_args=(R.Tuple,)) + R.tuple() + return R.tuple() + + +class TestDisableCaptureOutput(BaseCompare): + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor((8,), "float32")) -> R.Tuple(R.Tensor((8,), "float32")): + R.func_attr({"relax.force_pure": True}) + storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") + alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float32") + _ = R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,)) + storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") + alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float32") + _1 = R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,)) + storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float32") + alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float32") + _2 = R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,)) + gv = (alloc3,) + return gv + + @I.ir_module + class Expected: + @R.function(private=True) + def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object): + R.func_attr({"relax.force_pure": True}) + storage1: R.Object = R.memory.alloc_storage( + R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + storage2: R.Object = R.memory.alloc_storage( + R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + gv: R.Tuple(R.Object, R.Object) = storage1, storage2 + return gv + + @R.function(private=True) + def main_cuda_graph_capture( + alloc1: R.Tensor((8,), dtype="float32"), alloc2: R.Tensor((8,), dtype="float32") + ) -> R.Tuple: + R.func_attr({"relax.force_pure": True}) + R.call_packed("dummy", alloc1, alloc2, sinfo_args=(R.Tuple,)) + R.tuple() + return R.tuple() + + @R.function + def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="float32")): + R.func_attr({"relax.force_pure": True}) + cls = Expected + gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.get_cached_alloc", + (cls.cuda_graph_alloc, R.prim_value(0)), + sinfo_args=(R.Tuple(R.Object, R.Object),), + ) + storage1: R.Object = gv[0] + alloc1: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage1, R.prim_value(0), R.shape([8]), R.dtype("float32") + ) + R.call_packed("dummy", x, alloc1, sinfo_args=(R.Tuple,)) + storage2: R.Object = gv[1] + alloc2: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage2, R.prim_value(0), R.shape([8]), R.dtype("float32") + ) + R.call_builtin_with_ctx( + "vm.builtin.cuda_graph.run_or_capture", + (cls.main_cuda_graph_capture, (alloc1, alloc2), R.prim_value(0)), + sinfo_args=(R.Tuple,), + ) + storage3: R.Object = R.memory.alloc_storage( + R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float32") + ) + alloc3: R.Tensor((8,), dtype="float32") = R.memory.alloc_tensor( + storage3, R.prim_value(0), R.shape([8]), R.dtype("float32") + ) + R.call_packed("dummy", alloc2, alloc3, sinfo_args=(R.Tuple,)) + gv = (alloc3,) + return gv + + if __name__ == "__main__": tvm.testing.main()
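
# Illustrative sketch, not part of the test suite above: the per-scope merging heuristic of
# MergeAllocationPlans in rewrite_cuda_graph.cc, restated in plain Python. The helper name
# `merge_storage_sizes` and its list-of-lists input are hypothetical; only the
# sort-descending-then-take-the-pairwise-maximum rule and the sizes below mirror the C++
# code and TestMergeAllocFuncs.
def merge_storage_sizes(per_function_sizes):
    """For one storage scope, take one list of allocation sizes per function and return the
    sizes of the merged allocations, assuming the functions never run concurrently."""
    # Sort each function's allocations in descending order of size.
    sorted_sizes = [sorted(sizes, reverse=True) for sizes in per_function_sizes]
    num_storages = max((len(sizes) for sizes in sorted_sizes), default=0)
    merged = []
    for i in range(num_storages):
        # The i-th largest allocation of every function is mapped to the same merged storage,
        # so that storage must be as large as the biggest of them.
        merged.append(max(sizes[i] for sizes in sorted_sizes if i < len(sizes)))
    return merged


# Mirrors TestMergeAllocFuncs: in the "global" scope, func1 allocates storages of size 128 and
# 256 while func2 allocates 192, 64 and 512; in the "ipc_memory" scope they allocate 512 and 1024.
assert merge_storage_sizes([[128, 256], [192, 64, 512]]) == [512, 192, 64]
assert merge_storage_sizes([[512], [1024]]) == [1024]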
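
# A minimal usage sketch, assuming a module like TestMergeAllocFuncs.Before: running the pass
# produces a single shared "cuda_graph_alloc" function plus one "<name>_cuda_graph_capture"
# function per rewritten function, as in the Expected modules above. The helper name
# `apply_rewrite_cuda_graph` is hypothetical; `relax` is tvm.relax, imported at the top of
# this test file.
def apply_rewrite_cuda_graph(mod):
    # Lift static regions for CUDA graph capture and merge the per-function storage
    # allocations into one allocation function shared via
    # "vm.builtin.cuda_graph.get_cached_alloc".
    rewritten = relax.transform.RewriteCUDAGraph()(mod)
    # Print the merged allocation function that all rewritten functions now share.
    rewritten["cuda_graph_alloc"].show()
    return rewritten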