Skip to content

Commit

Permalink
[VM] [Hexagon] Introduce 2D Discontiguous vtcm alloc tensor
Browse files Browse the repository at this point in the history
Adds 2D Discontiguous alloc tensor hexagon builtin to support 2D
allocations for hexagon at relax level. This is needed when the ops are
implemented to take advantage of 2d indirections and enables
memory manager optimizations to try utilize VTCM memory efficiently.

This patch also introduces the `R.vm.copy_tensor` op to support copies
between different tensors, specifically planned to be used when copying
tensors from one memory scope to another

Co-authored-by: arangasa <[email protected]>
  • Loading branch information
quic-sanirudh and arangasa committed Feb 20, 2024
1 parent dd70941 commit 6902f5a
Show file tree
Hide file tree
Showing 15 changed files with 669 additions and 27 deletions.
4 changes: 4 additions & 0 deletions include/tvm/runtime/memory/memory_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ class StorageObj : public Object {
/*! \brief The index into the VM function table. */
Buffer buffer;

/* \brief Common function to create an NDArray container with the provided offset, shape and dtype
*/
NDArray::Container* CreateNDArrayContainer(int64_t offset, ShapeTuple shape, DLDataType dtype);

/*! \brief Allocate an NDArray from a given piece of storage. */
NDArray AllocNDArray(int64_t offset, ShapeTuple shape, DLDataType dtype);

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/vm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
# under the License.
"""Relax vm primitives."""

from .vm import alloc_storage, alloc_tensor, call_tir_dyn, kill_object
from .vm import alloc_storage, alloc_tensor, call_tir_dyn, copy_tensor_from_to, kill_object
20 changes: 20 additions & 0 deletions python/tvm/relax/op/vm/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,23 @@ def call_tir_dyn(func: Expr, args: Tuple) -> Call:
args = Tuple(args)

return _ffi_api.call_tir_dyn(func, args) # type: ignore


@args_converter.auto
def copy_tensor_from_to(src: Expr, dst: Expr) -> Call:
"""Construct a call to copy one tensor to another.
Parameters
----------
src : Expr
Source tensor for copy.
dst : Expr
Destination tensor for copy.
Returns
-------
result : Call
A relax Call, which performs the copy.
"""
return _ffi_api.copy_tensor_from_to(src, dst) # type: ignore
13 changes: 13 additions & 0 deletions src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
EmitAllocStorage(call, dst_reg);
} else if (call_node->op == alloc_tensor_op_) {
EmitAllocTensor(call, dst_reg);
} else if (call_node->op == copy_tensor_from_to_op_) {
EmitCopyTensor(call, dst_reg);
} else if (call_node->op == kill_object_op_) {
dst_reg = EmitKillObject(call);
} else {
Expand Down Expand Up @@ -361,6 +363,16 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
builder_->EmitCall("vm.builtin.alloc_tensor", args, dst_reg);
}

void EmitCopyTensor(const Call& call_node, RegName dst_reg) {
ICHECK_EQ(call_node->args.size(), 2);
std::vector<Instruction::Arg> args;
args.reserve(2);
for (Expr arg : call_node->args) {
args.push_back(this->VisitExpr(arg));
}
builder_->EmitCall("vm.builtin.copy_tensor_from_to", args, dst_reg);
}

RegName EmitKillObject(const Call& call_node) {
ICHECK_EQ(call_node->args.size(), 1);
Instruction::Arg arg = this->VisitExpr(call_node->args[0]);
Expand Down Expand Up @@ -430,6 +442,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
/*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
const Op& copy_tensor_from_to_op_ = Op::Get("relax.vm.copy_tensor_from_to");
const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
const Op& null_value_op_ = Op::Get("relax.null_value");
Expand Down
13 changes: 13 additions & 0 deletions src/relax/backend/vm/codegen_vm_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
EmitAllocStorage(call, dst_reg);
} else if (call_node->op == alloc_tensor_op_) {
EmitAllocTensor(call, dst_reg);
} else if (call_node->op == copy_tensor_from_to_op_) {
EmitCopyTensor(call, dst_reg);
} else if (call_node->op == kill_object_op_) {
dst_reg = EmitKillObject(call);
} else {
Expand Down Expand Up @@ -414,6 +416,16 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
this->EmitCallPacked("vm.builtin.alloc_tensor", args, dst_reg);
}

void EmitCopyTensor(const Call& call_node, int64_t dst_reg) {
ICHECK_EQ(call_node->args.size(), 2);
Array<PrimExpr> args;
args.reserve(2);
for (Expr arg : call_node->args) {
args.push_back(this->VisitExpr(arg).value());
}
this->EmitCallPacked("vm.builtin.copy_tensor_from_to", args, dst_reg);
}

int64_t EmitKillObject(const Call& call_node) {
ICHECK_EQ(call_node->args.size(), 1);
PrimExpr arg = this->VisitExpr(call_node->args[0]).value();
Expand Down Expand Up @@ -519,6 +531,7 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
/*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
const Op& alloc_storage_op_ = Op::Get("relax.vm.alloc_storage");
const Op& alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor");
const Op& copy_tensor_from_to_op_ = Op::Get("relax.vm.copy_tensor_from_to");
const Op& kill_object_op_ = Op::Get("relax.vm.kill_object");
const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
const Op& null_value_op_ = Op::Get("relax.null_value");
Expand Down
16 changes: 16 additions & 0 deletions src/relax/op/op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,22 @@ Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm d

TVM_REGISTER_GLOBAL("relax.op.vm.alloc_tensor").set_body_typed(MakeVMAllocTensor);

// vm copy_tensor_from_to

RELAY_REGISTER_OP("relax.vm.copy_tensor_from_to")
.set_num_inputs(2)
.add_argument("src", "Expr", "The tensor to copy from")
.add_argument("dst", "Expr", "The tensor to copy to")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
.set_attr<Bool>("FPurity", Bool(true));

Expr MakeVMCopyTensor(Expr src, Expr dst) {
static const Op& op = Op::Get("relax.vm.copy_tensor_from_to");
return Call(op, {src, dst}, Attrs(), {});
}

TVM_REGISTER_GLOBAL("relax.op.vm.copy_tensor_from_to").set_body_typed(MakeVMCopyTensor);

// vm kill_object

TVM_REGISTER_OP("relax.vm.kill_object")
Expand Down
8 changes: 4 additions & 4 deletions src/runtime/hexagon/hexagon_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ std::vector<MemoryCopy> MemoryCopy::MergeAdjacent(std::vector<MemoryCopy> micro_
return macro_copies;
}

void hexagon_buffer_copy_across_regions(const BufferSet& dest, const BufferSet& src,
void HexagonBufferCopyAcrossRegions(const BufferSet& dest, const BufferSet& src,
size_t bytes_to_copy, bool src_is_hexbuff,
bool dest_is_hexbuff) {
// First, determine all copies that do not cross boundaries in
Expand Down Expand Up @@ -268,23 +268,23 @@ void HexagonBuffer::CopyTo(void* data, size_t nbytes) const {
BufferSet src(allocations_.data(), allocations_.size(), nbytes_per_allocation_);
BufferSet dest(&data, 1, nbytes);

hexagon_buffer_copy_across_regions(dest, src, nbytes, true /* src_is_hexbuff */,
HexagonBufferCopyAcrossRegions(dest, src, nbytes, true /* src_is_hexbuff */,
false /* dest_is_hexbuff */);
}

void HexagonBuffer::CopyFrom(void* data, size_t nbytes) {
BufferSet src(&data, 1, nbytes);
BufferSet dest(allocations_.data(), allocations_.size(), nbytes_per_allocation_);

hexagon_buffer_copy_across_regions(dest, src, nbytes, false /* src_is_hexbuff */,
HexagonBufferCopyAcrossRegions(dest, src, nbytes, false /* src_is_hexbuff */,
true /* dest_is_hexbuff */);
}

void HexagonBuffer::CopyFrom(const HexagonBuffer& other, size_t nbytes) {
BufferSet src(other.allocations_.data(), other.allocations_.size(), other.nbytes_per_allocation_);
BufferSet dest(allocations_.data(), allocations_.size(), nbytes_per_allocation_);

hexagon_buffer_copy_across_regions(dest, src, nbytes, true /* src_is_hexbuff */,
HexagonBufferCopyAcrossRegions(dest, src, nbytes, true /* src_is_hexbuff */,
true /* dest_is_hexbuff */);
}

Expand Down
13 changes: 13 additions & 0 deletions src/runtime/hexagon/hexagon_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,19 @@ struct BufferSet {
size_t region_size_bytes;
};

/**
* @brief Single function to handle copying potentially discontiguous buffers efficiently
*
* @param The destination buffer
* @param The source buffer
* @param Number of bytes to copy. This should be less than both source and dest buffer size
* @param Boolean to specify whether the source is a hexagon buffer
* @param Boolean to specify whether the destination is a hexagon buffer
*/
void HexagonBufferCopyAcrossRegions(const BufferSet& dest, const BufferSet& src,
size_t bytes_to_copy, bool src_is_hexbuff,
bool dest_is_hexbuff);

} // namespace hexagon
} // namespace runtime
} // namespace tvm
Expand Down
74 changes: 55 additions & 19 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@
#include <cstring>

#include "../workspace_pool.h"
#include "hexagon_buffer.h"
#include "hexagon_common.h"
#include "qurt_memory.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -91,23 +93,29 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap
CHECK(runtime_hexbuffs) << "Attempted to allocate Hexagon data with "
<< "HexagonDeviceAPI::AllocDataSpace before initializing resources. "
<< "Please call HexagonDeviceAPI::AcquireResources";

void* base_ptr;
PhysicalShape physical_shape;
if (ndim == 0) {
// Allocate storage for a single scalar value.
return runtime_hexbuffs->AllocateHexagonBuffer(typesize, kHexagonAllocAlignment, mem_scope);
base_ptr = runtime_hexbuffs->AllocateHexagonBuffer(typesize, kHexagonAllocAlignment, mem_scope);
physical_shape = {1, 1, typesize};
} else if (ndim == 1) {
// Allocate a single, contiguous memory region.
size_t nbytes = shape[0] * typesize;
return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, kHexagonAllocAlignment, mem_scope);
base_ptr = runtime_hexbuffs->AllocateHexagonBuffer(nbytes, kHexagonAllocAlignment, mem_scope);
physical_shape = {1, 1, nbytes};
} else if (ndim == 2) {
// Allocate the region(s) needed for Hexagon's indirect-tensor format.
size_t nallocs = shape[0];
size_t nbytes = shape[1] * typesize;
return runtime_hexbuffs->AllocateHexagonBuffer(nallocs, nbytes, kHexagonAllocAlignment,
mem_scope);
base_ptr =
runtime_hexbuffs->AllocateHexagonBuffer(nallocs, nbytes, kHexagonAllocAlignment, mem_scope);
physical_shape = {2, nallocs, nbytes};
} else {
return nullptr; // unreachable
}
SetPhysicalShape(base_ptr, physical_shape);
return base_ptr;
}

void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment,
Expand All @@ -121,7 +129,10 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignme
CHECK(runtime_hexbuffs) << "Attempted to allocate Hexagon data with "
<< "HexagonDeviceAPI::AllocDataSpace before initializing resources. "
<< "Please call HexagonDeviceAPI::AcquireResources";
return runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, String("global"));
void* base_ptr = runtime_hexbuffs->AllocateHexagonBuffer(nbytes, alignment, String("global"));
PhysicalShape physical_shape = {1, 1, nbytes};
SetPhysicalShape(base_ptr, physical_shape);
return base_ptr;
}

void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) {
Expand All @@ -134,6 +145,7 @@ void HexagonDeviceAPI::FreeDataSpace(Device dev, void* ptr) {
// occur in the normal course of shutdown, log a message and continue.
DLOG(INFO) << "FreeDataSpace called outside a session for " << ptr;
}
ndarray_physical_shape.erase(ptr);
}

// WorkSpace: runtime allocations for Hexagon
Expand All @@ -157,6 +169,8 @@ void HexagonDeviceAPI::FreeWorkspace(Device dev, void* data) {
dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->FreeWorkspace(dev, data);
}

void* get_data_start(DLTensor* tensor) { return (reinterpret_cast<uint8_t*>(tensor->data)); }

void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) {
CHECK_EQ(from->byte_offset, 0);
CHECK_EQ(to->byte_offset, 0);
Expand All @@ -165,22 +179,44 @@ void HexagonDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHan
<< "HexagonDeviceAPI::CopyDataFromTo before initializing resources. "
<< "Please call HexagonDeviceAPI::AcquireResources";

auto lookup_hexagon_buffer = [this](void* ptr) -> HexagonBuffer* {
return runtime_hexbuffs->FindHexagonBuffer(ptr);
};
auto numBytes = GetDataSize(*from);

size_t FlatShape = 1;
for (auto i = 0; i < from->ndim; ++i) FlatShape *= from->shape[i];

PhysicalShape source_shape = {1, 1, FlatShape};
PhysicalShape dest_shape = {1, 1, FlatShape};
auto it1 = ndarray_physical_shape.find(from->data);
if (it1 != ndarray_physical_shape.end()) source_shape = it1->second;
size_t src_rank = source_shape.ndim;
void* src_start = get_data_start(from);
void* dst_start = get_data_start(to);
BufferSet src((src_rank == 1) ? &(src_start) : static_cast<void**>(src_start),
source_shape.nblocks, numBytes / source_shape.nblocks);
auto it2 = ndarray_physical_shape.find(to->data);
if (it2 != ndarray_physical_shape.end()) dest_shape = it2->second;
size_t dest_rank = dest_shape.ndim;
BufferSet dest((dest_rank == 1) ? &(dst_start) : static_cast<void**>(dst_start),
dest_shape.nblocks, numBytes / dest_shape.nblocks);
HexagonBufferCopyAcrossRegions(dest, src, numBytes, (it1 != ndarray_physical_shape.end()),
(it2 != ndarray_physical_shape.end()));
return;
}

HexagonBuffer* hex_from_buf = lookup_hexagon_buffer(from->data);
HexagonBuffer* hex_to_buf = lookup_hexagon_buffer(to->data);
void HexagonDeviceAPI::SetPhysicalShape(const DLTensor* tensor, const int64_t ndim,
const int64_t* shape) {
PhysicalShape physical_shape = {static_cast<size_t>(ndim), static_cast<size_t>(shape[0]),
static_cast<size_t>(shape[1])};
SetPhysicalShape(tensor->data, physical_shape);
}

if (hex_from_buf && hex_to_buf) {
hex_to_buf->CopyFrom(*hex_from_buf, GetDataSize(*from));
} else if (hex_to_buf) {
hex_to_buf->CopyFrom(from->data, GetDataSize(*from));
} else if (hex_from_buf) {
hex_from_buf->CopyTo(to->data, GetDataSize(*to));
void HexagonDeviceAPI::SetPhysicalShape(const void* data, const PhysicalShape& physical_shape) {
auto it = ndarray_physical_shape.find(const_cast<void*>(data));
if (it != ndarray_physical_shape.end()) {
ndarray_physical_shape[const_cast<void*>(data)] = physical_shape;
} else {
CHECK(false) << "CopyDataFromTo requested between src and dst which are not managed by the "
"hexagon device api.";
ndarray_physical_shape.insert(
std::pair<void*, PhysicalShape>(const_cast<void*>(data), physical_shape));
}
}

Expand Down
20 changes: 19 additions & 1 deletion src/runtime/hexagon/hexagon_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

#include <tvm/runtime/device_api.h>

#include <map>
#include <memory>
#include <string>
#include <unordered_map>
Expand All @@ -40,6 +39,12 @@ namespace tvm {
namespace runtime {
namespace hexagon {

struct PhysicalShape {
size_t ndim;
size_t nblocks;
size_t block_size;
};

/*!
* \brief Hexagon Device API that is compiled and run on Hexagon.
*/
Expand Down Expand Up @@ -148,6 +153,11 @@ class HexagonDeviceAPI final : public DeviceAPI {
*/
void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final;

/*!
* \brief set physical shape of tensor
*/
void SetPhysicalShape(const DLTensor* tensor, const int64_t ndim, const int64_t* shape);

HexagonThreadManager* ThreadManager() {
CHECK(runtime_threads) << "runtime_threads has not been created";
return runtime_threads.get();
Expand Down Expand Up @@ -178,6 +188,11 @@ class HexagonDeviceAPI final : public DeviceAPI {
return (dev.device_type == kDLHexagon) || (dev.device_type == kDLCPU);
}

/*!
* \brief set physical shape of tensor - private helper
*/
void SetPhysicalShape(const void* data, const PhysicalShape&);

//! \brief Manages runtime HexagonBuffer allocations
// runtime_hexbuffs is used for runtime allocations. It is created with a call to
// AcquireResources, and destroyed on ReleaseResources. The buffers in this manager are scoped
Expand All @@ -199,6 +214,9 @@ class HexagonDeviceAPI final : public DeviceAPI {

//! \brief Hexagon power manager
std::unique_ptr<HexagonPowerManager> runtime_power_manager;

//! \brief NDArray base -> Physical Shape map
std::unordered_map<void*, PhysicalShape> ndarray_physical_shape;
};
} // namespace hexagon
} // namespace runtime
Expand Down
Loading

0 comments on commit 6902f5a

Please sign in to comment.