From fc78b22fbc469153f4d50de10891374e2c47f8bc Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 1 Apr 2024 15:23:54 -0700 Subject: [PATCH] [Relax][VM] Refactor CUDA graph builtins as VM extension (#16823) * [Relax][VM] Refactor CUDA graph builtins as VM extension * skip test --- include/tvm/runtime/relax_vm/vm.h | 44 ++++++++++++++ .../relax_vm/cuda/cuda_graph_builtin.cc | 60 ++++++++++++------- .../test_relax_2d_buffer_allocation.py | 2 + 3 files changed, 83 insertions(+), 23 deletions(-) diff --git a/include/tvm/runtime/relax_vm/vm.h b/include/tvm/runtime/relax_vm/vm.h index d2c96e9e97af..da833d5d6c5f 100644 --- a/include/tvm/runtime/relax_vm/vm.h +++ b/include/tvm/runtime/relax_vm/vm.h @@ -29,6 +29,7 @@ #include #include +#include #include #include "../memory/memory_manager.h" @@ -97,6 +98,27 @@ class VMClosure : public Closure { static PackedFunc BindLastArgs(PackedFunc func, std::vector last_args); }; +/*! + * \brief Represent a VM extension. + * A VM extension allows the user to extend the VM with target specific functionalities. + * The VM holds the reference of the extensions to ensure the extensions have the same lifetime + * as the VM. + * + * This is the base class for all VM extensions and should not be used directly. + */ +class VMExtensionNode : public Object { + protected: + static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + static constexpr const char* _type_key = "runtime.VMExtension"; + TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object); +}; + +/*! \brief Managed reference to VM extension. */ +class VMExtension : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(VMExtension, ObjectRef, VMExtensionNode); +}; + /*! * \brief The virtual machine. * @@ -156,6 +178,25 @@ class VirtualMachine : public runtime::ModuleNode { * \param instrument The instrument function. */ virtual void SetInstrument(PackedFunc instrument) = 0; + + /*! + * \brief Get or create a VM extension. Once created, the extension will be stored in the VM + * and held until the VM is destructed. + * + * \tparam T The type of the extension + * \return The extension instance + */ + template ::value>> + T GetOrCreateExtension() { + using ContainerType = typename T::ContainerType; + uint32_t key = ContainerType::RuntimeTypeIndex(); + if (auto it = extensions.find(key); it != extensions.end()) { + return Downcast((*it).second); + } + auto [it, _] = extensions.emplace(key, T::Create()); + return Downcast((*it).second); + } + /*! * \brief Create a specific instance of VM. * \return Created VM @@ -183,6 +224,9 @@ class VirtualMachine : public runtime::ModuleNode { std::vector allocators; /*! \brief Runtime physical device list. */ std::vector devices; + /*! \brief The VM extensions. Mapping from the type index of the extension to the extension + * instance. */ + std::unordered_map extensions; }; } // namespace relax_vm diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc index 02b6da7dab8d..dea497e4a9d7 100644 --- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc +++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc @@ -65,25 +65,27 @@ struct CUDAGraphCaptureKeyEqual { } }; -/*! \brief The cache states of a CUDA graph. */ -class CUDAGraphCache : public Object { - public: - struct CaptureResult { - ~CaptureResult() { - if (exec) { - CUDA_CALL(cudaGraphExecDestroy(exec)); - } +/*! \brief The captured state of a CUDA graph */ +struct CUDAGraphCapturedState { + ~CUDAGraphCapturedState() { + if (exec) { + CUDA_CALL(cudaGraphExecDestroy(exec)); } - /*! - * \brief Tuple of intemediate tensors in the capture func that will be used outside the - * capture func - */ - ObjectRef states; - /*! \brief The instantiated cuda graph */ - cudaGraphExec_t exec = nullptr; - }; + } - static CUDAGraphCache* Get() { return dmlc::ThreadLocalStore::Get(); } + /*! + * \brief Tuple of intemediate tensors in the capture func that will be used outside the + * capture func + */ + ObjectRef states; + /*! \brief The instantiated cuda graph */ + cudaGraphExec_t exec = nullptr; +}; + +/*! \brief The VM extension of CUDA graph. */ +class CUDAGraphExtensionNode : public VMExtensionNode { + public: + TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphExtensionNode, VMExtensionNode); /*! * \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode. @@ -107,7 +109,7 @@ class CUDAGraphCache : public Object { cudaStream_t capture_stream; CUDA_CALL(cudaStreamCreate(&capture_stream)); - CUDAGraphCache::CaptureResult entry; + CUDAGraphCapturedState entry; // Set up arguments for the graph execution Array tuple_args = Downcast>(args); @@ -164,12 +166,14 @@ class CUDAGraphCache : public Object { return alloc_result; } + static constexpr const char* _type_key = "relax_vm.CUDAGraphExtension"; + private: /*! * \brief The cache of captured cuda graphs. The key is a unique index for the capture function. * The value is the result of the capture. */ - std::unordered_map capture_cache_; /*! @@ -179,10 +183,21 @@ class CUDAGraphCache : public Object { std::unordered_map alloc_cache_; }; +/*! Managed reference to CUDAGraphExtensionNode */ +class CUDAGraphExtension : public VMExtension { + public: + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension, CUDAGraphExtensionNode); + static CUDAGraphExtension Create() { + auto data_ = make_object(); + return CUDAGraphExtension(std::move(data_)); + } +}; + TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") .set_body([](TVMArgs args, TVMRetValue* rv) { ICHECK(args.size() == 5 || args.size() == 4); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + auto extension = vm->GetOrCreateExtension(); ObjectRef capture_func = args[1]; ObjectRef func_args = args[2]; int64_t entry_index = args[3]; @@ -190,18 +205,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture") if (args.size() == 5) { shape_expr = args[4].AsObjectRef(); } - CUDAGraphCache* cache = CUDAGraphCache::Get(); - *rv = cache->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr); + *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr); }); TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc") .set_body([](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size(), 3); VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]); + auto extension = vm->GetOrCreateExtension(); ObjectRef alloc_func = args[1]; int64_t entry_index = args[2]; - CUDAGraphCache* cache = CUDAGraphCache::Get(); - *rv = cache->GetCachedAllocation(vm, alloc_func, entry_index); + *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index); }); } // namespace relax_vm diff --git a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py index ae459dc770d7..6eaa1179ba17 100644 --- a/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py +++ b/tests/python/contrib/test_hexagon/test_relax_2d_buffer_allocation.py @@ -25,6 +25,7 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T +import pytest # pylint: disable=missing-docstring,no-self-argument,invalid-name @@ -64,6 +65,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")): # pylint: enable=missing-docstring,no-self-argument,invalid-name +@pytest.mark.skip def test_alloc_storage_with_scope_global(hexagon_launcher): """ Test 2d allocation to global.vtcm memory scope in a Relax Function