[Relax] NDArray Cache Update with DLTensor Support (#16464)
Since an NDArray accessed on an RPC device is only returned as a DLTensor, this adds
DLTensor support to the NDArray cache.

It is not easy to add test cases, since we cannot create a raw DLTensor
through the Python interface.
Hzfengsy authored Jan 24, 2024
1 parent 0e8e421 commit 593a4bd
1 changed file: src/runtime/relax_vm/ndarray_cache_support.cc (22 additions, 1 deletion)
@@ -267,7 +267,28 @@ class NDArrayCache {
};

 TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get);
-TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body_typed(NDArrayCache::Update);
+TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body([](TVMArgs args, TVMRetValue* rv) {
+  CHECK(args.size() == 2 || args.size() == 3);
+  String name = args[0];
+  bool is_override = args.size() == 2 ? false : args[2];
+
+  NDArray arr;
+  if (args[1].type_code() == kTVMNDArrayHandle) {
+    arr = args[1];
+  } else {
+    // We support converting DLTensors to NDArrays as RPC references are always DLTensors
+    DLTensor* tensor = args[1];
+    std::vector<int64_t> shape;
+    for (int64_t i = 0; i < tensor->ndim; i++) {
+      shape.push_back(tensor->shape[i]);
+    }
+    // Assign to the outer `arr`; redeclaring `NDArray arr` here would shadow it
+    // and leave the cached array undefined.
+    arr = NDArray::Empty(shape, tensor->dtype, tensor->device);
+    arr.CopyFrom(tensor);
+    TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr);
+  }
+
+  NDArrayCache::Update(name, arr, is_override);
+});
 TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove);
 TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear);
 TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load);
