Skip to content

Commit

Permalink
[Relax][VM] Refactor CUDA graph builtins as VM extension (#16823)
Browse files Browse the repository at this point in the history
* [Relax][VM] Refactor CUDA graph builtins as VM extension

* skip test
  • Loading branch information
vinx13 authored Apr 1, 2024
1 parent 00395ae commit fc78b22
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 23 deletions.
44 changes: 44 additions & 0 deletions include/tvm/runtime/relax_vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "../memory/memory_manager.h"
Expand Down Expand Up @@ -97,6 +98,27 @@ class VMClosure : public Closure {
static PackedFunc BindLastArgs(PackedFunc func, std::vector<TVMRetValue> last_args);
};

/*!
* \brief Represent a VM extension.
* A VM extension allows the user to extend the VM with target specific functionalities.
* The VM holds the reference of the extensions to ensure the extensions have the same lifetime
* as the VM.
*
* This is the base class for all VM extensions and should not be used directly.
*/
class VMExtensionNode : public Object {
protected:
// Extensions are registered dynamically: the concrete type index is assigned at runtime.
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
// Globally unique type key identifying this base type in the TVM object protocol.
static constexpr const char* _type_key = "runtime.VMExtension";
// Declares VMExtensionNode as a base object type; concrete extensions derive from it.
TVM_DECLARE_BASE_OBJECT_INFO(VMExtensionNode, Object);
};

/*! \brief Managed reference to VM extension. */
class VMExtension : public ObjectRef {
public:
// Standard object-ref boilerplate: constructors, get(), and operator-> for VMExtensionNode.
TVM_DEFINE_OBJECT_REF_METHODS(VMExtension, ObjectRef, VMExtensionNode);
};

/*!
* \brief The virtual machine.
*
Expand Down Expand Up @@ -156,6 +178,25 @@ class VirtualMachine : public runtime::ModuleNode {
* \param instrument The instrument function.
*/
virtual void SetInstrument(PackedFunc instrument) = 0;

/*!
 * \brief Get or create a VM extension. Once created, the extension will be stored in the VM
 * and held until the VM is destructed.
 *
 * \tparam T The type of the extension
 * \return The extension instance
 */
template <typename T, typename = std::enable_if_t<std::is_base_of<VMExtension, T>::value>>
T GetOrCreateExtension() {
  using ContainerType = typename T::ContainerType;
  // Extensions are keyed by the runtime type index of their container type.
  const uint32_t type_index = ContainerType::RuntimeTypeIndex();
  auto it = extensions.find(type_index);
  if (it == extensions.end()) {
    // First request for this extension: create it and cache it for the VM's lifetime.
    it = extensions.emplace(type_index, T::Create()).first;
  }
  return Downcast<T>(it->second);
}

/*!
* \brief Create a specific instance of VM.
* \return Created VM
Expand Down Expand Up @@ -183,6 +224,9 @@ class VirtualMachine : public runtime::ModuleNode {
std::vector<Allocator*> allocators;
/*! \brief Runtime physical device list. */
std::vector<Device> devices;
/*! \brief The VM extensions. Mapping from the type index of the extension to the extension
* instance. */
std::unordered_map<uint32_t, VMExtension> extensions;
};

} // namespace relax_vm
Expand Down
60 changes: 37 additions & 23 deletions src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,27 @@ struct CUDAGraphCaptureKeyEqual {
}
};

/*! \brief The cache states of a CUDA graph. */
class CUDAGraphCache : public Object {
public:
struct CaptureResult {
~CaptureResult() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
}
/*! \brief The captured state of a CUDA graph */
struct CUDAGraphCapturedState {
~CUDAGraphCapturedState() {
if (exec) {
CUDA_CALL(cudaGraphExecDestroy(exec));
}
/*!
* \brief Tuple of intemediate tensors in the capture func that will be used outside the
* capture func
*/
ObjectRef states;
/*! \brief The instantiated cuda graph */
cudaGraphExec_t exec = nullptr;
};
}

static CUDAGraphCache* Get() { return dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
/*!
* \brief Tuple of intemediate tensors in the capture func that will be used outside the
* capture func
*/
ObjectRef states;
/*! \brief The instantiated cuda graph */
cudaGraphExec_t exec = nullptr;
};

/*! \brief The VM extension of CUDA graph. */
class CUDAGraphExtensionNode : public VMExtensionNode {
public:
TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphExtensionNode, VMExtensionNode);

/*!
* \brief Launch the cuda graph if it has been cached, otherwise execute it in capture mode.
Expand All @@ -107,7 +109,7 @@ class CUDAGraphCache : public Object {

cudaStream_t capture_stream;
CUDA_CALL(cudaStreamCreate(&capture_stream));
CUDAGraphCache::CaptureResult entry;
CUDAGraphCapturedState entry;

// Set up arguments for the graph execution
Array<ObjectRef> tuple_args = Downcast<Array<ObjectRef>>(args);
Expand Down Expand Up @@ -164,12 +166,14 @@ class CUDAGraphCache : public Object {
return alloc_result;
}

static constexpr const char* _type_key = "relax_vm.CUDAGraphExtension";

private:
/*!
* \brief The cache of captured cuda graphs. The key is a unique index for the capture function.
* The value is the result of the capture.
*/
std::unordered_map<CUDAGraphCaptureKey, CaptureResult, CUDAGraphCaptureKeyHash,
std::unordered_map<CUDAGraphCaptureKey, CUDAGraphCapturedState, CUDAGraphCaptureKeyHash,
CUDAGraphCaptureKeyEqual>
capture_cache_;
/*!
Expand All @@ -179,29 +183,39 @@ class CUDAGraphCache : public Object {
std::unordered_map<int64_t, ObjectRef> alloc_cache_;
};

/*! Managed reference to CUDAGraphExtensionNode */
class CUDAGraphExtension : public VMExtension {
 public:
  // Mutable ref methods: the underlying node's caches are updated through this handle.
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(CUDAGraphExtension, VMExtension, CUDAGraphExtensionNode);
  /*! \brief Create a fresh CUDA graph extension instance wrapping a new node. */
  static CUDAGraphExtension Create() {
    auto node = make_object<CUDAGraphExtensionNode>();
    return CUDAGraphExtension(std::move(node));
  }
};

TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.run_or_capture")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      // Expected args: (vm, capture_func, func_args, entry_index[, shape_expr]).
      ICHECK(args.size() == 4 || args.size() == 5);
      VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
      ObjectRef capture_func = args[1];
      ObjectRef func_args = args[2];
      int64_t entry_index = args[3];
      // The trailing shape expression is optional; absent for shape-independent captures.
      Optional<ShapeTuple> shape_expr =
          args.size() == 5 ? Optional<ShapeTuple>(args[4].AsObjectRef<ShapeTuple>())
                           : Optional<ShapeTuple>(NullOpt);
      auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
      *rv = extension->RunOrCapture(vm, capture_func, func_args, entry_index, shape_expr);
    });

TVM_REGISTER_GLOBAL("vm.builtin.cuda_graph.get_cached_alloc")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      // Expected args: (vm, alloc_func, entry_index).
      ICHECK_EQ(args.size(), 3);
      VirtualMachine* vm = VirtualMachine::GetContextPtr(args[0]);
      ObjectRef alloc_func = args[1];
      int64_t entry_index = args[2];
      // Look up (or lazily create) the CUDA graph extension held by this VM.
      auto extension = vm->GetOrCreateExtension<CUDAGraphExtension>();
      *rv = extension->GetCachedAllocation(vm, alloc_func, entry_index);
    });

} // namespace relax_vm
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
import pytest


# pylint: disable=missing-docstring,no-self-argument,invalid-name
Expand Down Expand Up @@ -64,6 +65,7 @@ def main(x: R.Tensor((2, 2), dtype="float32")):


# pylint: enable=missing-docstring,no-self-argument,invalid-name
@pytest.mark.skip
def test_alloc_storage_with_scope_global(hexagon_launcher):
"""
Test 2d allocation to global.vtcm memory scope in a Relax Function
Expand Down

0 comments on commit fc78b22

Please sign in to comment.