diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc index ce028f4d7dd3..b389030cfe37 100644 --- a/src/runtime/relax_vm/ndarray_cache_support.cc +++ b/src/runtime/relax_vm/ndarray_cache_support.cc @@ -267,7 +267,28 @@ class NDArrayCache { }; TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get); -TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body_typed(NDArrayCache::Update); +TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body([](TVMArgs args, TVMRetValue* rv) { + CHECK(args.size() == 2 || args.size() == 3); + String name = args[0]; + bool is_override = args.size() == 2 ? false : args[2]; + + NDArray arr; + if (args[1].type_code() == kTVMNDArrayHandle) { + arr = args[1]; + } else { + // We support converting DLTensors to NDArrays as RPC references are always DLTensors + DLTensor* tensor = args[1]; + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + NDArray arr = NDArray::Empty(shape, tensor->dtype, tensor->device); + arr.CopyFrom(tensor); + TVMSynchronize(arr->device.device_type, arr->device.device_id, nullptr); + } + + NDArrayCache::Update(name, arr, is_override); +}); TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove); TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear); TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load);