Skip to content

Commit

Permalink
[VM] Memory Manager moved up to runtime (#15833)
Browse files Browse the repository at this point in the history
* [VM] memory Manager moved up to runtime

Now graph runtime also uses the same memory manager
This accommodates a common memory manager with pooled and naive support.

As a follow-up, we can move the WorkspacePool to use this common memory manager.

* * update dependents with new file addition.

* *  define memory_manager under new namespace

* * use ShapeTuple across vm executor and memory_manager

* * ShapeTuple across the Allocators

* * GetDataSize is moved to DeviceAPI and memory_manager uses this interface.

* * review comments

* * Make compiler happy with unused variables

* * lint

* Update src/runtime/memory/memory_manager.cc

Co-authored-by: Egor Churaev <[email protected]>

* * allow multiple allocators to coexist for the same device.
Using an available allocator instead of the requested one is leading to an unexpected crash

---------

Co-authored-by: Egor Churaev <[email protected]>
  • Loading branch information
srkreddy1238 and echuraev authored Oct 3, 2023
1 parent a38053e commit b8abff9
Show file tree
Hide file tree
Showing 19 changed files with 159 additions and 94 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
src/runtime/memory/*.cc
src/runtime/disco/*.cc
src/runtime/minrpc/*.cc
)
Expand Down
1 change: 1 addition & 0 deletions apps/android_camera/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "../src/runtime/graph_executor/graph_executor.cc"
#include "../src/runtime/library_module.cc"
#include "../src/runtime/logging.cc"
#include "../src/runtime/memory/memory_manager.cc"
#include "../src/runtime/minrpc/minrpc_logger.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc"
Expand Down
1 change: 1 addition & 0 deletions apps/android_deploy/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "../src/runtime/graph_executor/graph_executor.cc"
#include "../src/runtime/library_module.cc"
#include "../src/runtime/logging.cc"
#include "../src/runtime/memory/memory_manager.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc"
#include "../src/runtime/object.cc"
Expand Down
1 change: 1 addition & 0 deletions apps/android_rpc/app/src/main/jni/tvm_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "../src/runtime/graph_executor/graph_executor_factory.cc"
#include "../src/runtime/library_module.cc"
#include "../src/runtime/logging.cc"
#include "../src/runtime/memory/memory_manager.cc"
#include "../src/runtime/minrpc/minrpc_logger.cc"
#include "../src/runtime/module.cc"
#include "../src/runtime/ndarray.cc"
Expand Down
1 change: 1 addition & 0 deletions apps/bundle_deploy/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "../../src/runtime/graph_executor/graph_executor.cc"
#include "../../src/runtime/library_module.cc"
#include "../../src/runtime/logging.cc"
#include "../../src/runtime/memory/memory_manager.cc"
#include "../../src/runtime/module.cc"
#include "../../src/runtime/ndarray.cc"
#include "../../src/runtime/object.cc"
Expand Down
1 change: 1 addition & 0 deletions apps/howto_deploy/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
// Graph executor
#include "../../src/runtime/graph_executor/graph_executor.cc"
#include "../../src/runtime/graph_executor/graph_executor_factory.cc"
#include "../../src/runtime/memory/memory_manager.cc"

// Uncomment the following lines to enable RPC
// #include "../../src/runtime/rpc/rpc_session.cc"
Expand Down
1 change: 1 addition & 0 deletions golang/src/tvm_runtime_pack.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

// Graph executor
#include "src/runtime/graph_executor/graph_executor.cc"
#include "src/runtime/memory/memory_manager.cc"

// Uncomment the following lines to enable RPC
// #include "../../src/runtime/rpc/rpc_session.cc"
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ class TVM_DLL DeviceAPI {
*/
virtual void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) = 0;

/*!
* \brief Get the physical memory size required.
* \param arr the tensor object.
* \param mem_scope the memory scope if any
* \return the memory size.
*/
virtual size_t GetDataSize(const DLTensor& arr, Optional<String> mem_scope = NullOpt);

/*!
* \brief Query the device for specified properties.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
*/

/*!
* \file tvm/runtime/vm/memory_manager.h
* \file tvm/runtime/memory/memory_manager.h
* \brief Abstract device memory management API
*/
#ifndef TVM_RUNTIME_VM_MEMORY_MANAGER_H_
#define TVM_RUNTIME_VM_MEMORY_MANAGER_H_
#ifndef TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_
#define TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/ndarray.h>
Expand All @@ -37,22 +37,22 @@

namespace tvm {
namespace runtime {
namespace vm {
namespace memory {

enum AllocatorType {
kNaive = 1,
kPooled,
};

struct Buffer {
/*! \brief The pointer to the allocated block of memory. */
void* data{nullptr};
/*! \brief The size of the block. */
size_t size{0};
/*! \brief The shape of the tensor. */
std::vector<int64_t> shape;
/*! \brief The context of the allocated buffers. */
Device device;
};

enum AllocatorType {
kNaive = 1,
kPooled,
/*! \brief The allocator that created this buffer. */
AllocatorType alloc_type;
};

class Allocator {
Expand All @@ -63,9 +63,11 @@ class Allocator {
* \param shape The shape of the NDArray.
* \param dtype The datatype of the NDArray.
* \param dev The device where the array is allocated.
* \param mem_scope The device memory scope hint.
* \return The empty NDArray.
*/
NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev);
NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev,
Optional<String> mem_scope = NullOpt);
/*! \brief Return the allocator type. */
inline AllocatorType type() const { return type_; }
/*! \brief Allocate a buffer given a size, alignment and type.
Expand All @@ -76,13 +78,12 @@ class Allocator {
*/
virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0;
/*! \brief Allocate a buffer given a shape and type.
* \param ndims The rank of the tensor.
* \param shape The shape of the tensor.
* \param type_hint A type hint to the allocator.
* \param mem_scope A memory scope of the buffer.
* \return A sized allocation in the form of a buffer.
*/
virtual Buffer Alloc(int ndims, int64_t* shape, DLDataType type_hint,
virtual Buffer Alloc(ShapeTuple shape, DLDataType type_hint,
const std::string& mem_scope = "") = 0;
/*! \brief Free a buffer allocated by the allocator.
* \param buffer The buffer to free.
Expand All @@ -94,7 +95,7 @@ class Allocator {
virtual size_t UsedMemory() const = 0;

protected:
virtual Buffer Alloc(Device dev, int ndims, int64_t* shape, DLDataType type_hint,
virtual Buffer Alloc(Device dev, ShapeTuple shape, DLDataType type_hint,
const std::string& mem_scope);

private:
Expand All @@ -114,16 +115,18 @@ class MemoryManager {
/*!
* \brief Get an allocator given the context.
* \param dev The TVM device
* \param type The allocator type
* \return The memory allocator.
*/
static Allocator* GetAllocator(Device dev);
static Allocator* GetAllocator(Device dev, AllocatorType type);

private:
MemoryManager() {}

protected:
std::mutex mu_;
std::unordered_map<Device, std::unique_ptr<Allocator>> allocators_;
std::unordered_map<Device, std::unordered_map<AllocatorType, std::unique_ptr<Allocator>>>
allocators_;
};

/*! \brief An object representing a storage allocation. */
Expand All @@ -133,13 +136,13 @@ class StorageObj : public Object {
Buffer buffer;

/*! \brief Allocate an NDArray from a given piece of storage. */
NDArray AllocNDArray(size_t offset, std::vector<int64_t> shape, DLDataType dtype);
NDArray AllocNDArray(size_t offset, ShapeTuple shape, DLDataType dtype);

/*! \brief The deleter for an NDArray when allocated from underlying storage. */
static void Deleter(Object* ptr);

~StorageObj() {
auto alloc = MemoryManager::Global()->GetAllocator(buffer.device);
auto alloc = MemoryManager::Global()->GetAllocator(buffer.device, buffer.alloc_type);
alloc->Free(buffer);
}

Expand All @@ -156,8 +159,8 @@ class Storage : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Storage, ObjectRef, StorageObj);
};

} // namespace vm
} // namespace memory
} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_VM_MEMORY_MANAGER_H_
#endif // TVM_RUNTIME_MEMORY_MEMORY_MANAGER_H_
9 changes: 8 additions & 1 deletion include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
#define TVM_RUNTIME_VM_VM_H_

#include <tvm/runtime/container/closure.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/vm/bytecode.h>
#include <tvm/runtime/vm/executable.h>
#include <tvm/runtime/vm/memory_manager.h>

#include <memory>
#include <string>
Expand All @@ -41,6 +41,13 @@

namespace tvm {
namespace runtime {

using memory::Allocator;
using memory::AllocatorType;
using memory::MemoryManager;
using memory::Storage;
using memory::StorageObj;

namespace vm {

/*!
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
#include <utility>
#include <vector>

#include "../../../runtime/vm/naive_allocator.h"
#include "../../../runtime/memory/naive_allocator.h"
#include "../../../runtime/vm/profiler/vm.h"
#include "../../transforms/pass_utils.h"
#include "../te_compiler.h"
Expand Down
14 changes: 14 additions & 0 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ static size_t GetDataAlignment(const DLDataType dtype) {
return align;
}

size_t DeviceAPI::GetDataSize(const DLTensor& arr, Optional<String> mem_scope) {
if (!mem_scope.defined() || mem_scope.value().empty() || mem_scope.value() == "global") {
size_t size = 1;
for (tvm_index_t i = 0; i < arr.ndim; ++i) {
size *= static_cast<size_t>(arr.shape[i]);
}
size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
return size;
}
LOG(FATAL) << "Device does not support physical mem computation with "
<< "specified memory scope: " << mem_scope.value();
return 0;
}

void* DeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype,
Optional<String> mem_scope) {
if (!mem_scope.defined() || mem_scope.value() == "" || mem_scope.value() == "global") {
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,8 @@ void GraphExecutor::SetupStorage() {
if (!pit.scope.empty()) {
mem_scope = String(pit.scope);
}
storage_pool_.push_back(NDArray::Empty(shape, pit.dtype, dev, mem_scope));
storage_pool_.push_back(MemoryManager::GetOrCreateAllocator(dev, AllocatorType::kNaive)
->Empty(shape, pit.dtype, dev, mem_scope));
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/runtime/graph_executor/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <dlpack/dlpack.h>
#include <dmlc/json.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/memory/memory_manager.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

Expand All @@ -42,6 +43,9 @@
namespace tvm {
namespace runtime {

using memory::AllocatorType;
using memory::MemoryManager;

/*! \brief macro to do C API call */
#define TVM_CCALL(func) \
{ \
Expand Down
Loading

0 comments on commit b8abff9

Please sign in to comment.