From cd92bfe7a7a40360b13a1afe0c05494795cf55a8 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 30 Sep 2024 20:40:53 +0530 Subject: [PATCH] [IREE][EP] Release device just after creating session This patch releases hal device just after creating the runtime session as it's not needed after that. Late release of hal device was causing memory issue while running a set of models consecutively. Signed-Off-by: Gaurav Shukla --- .../core/providers/iree/iree_ep_runtime.cc | 22 +++++-------------- .../providers/iree/iree_execution_provider.cc | 5 +++++ 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/iree/iree_ep_runtime.cc b/onnxruntime/core/providers/iree/iree_ep_runtime.cc index 086ef9962465a..7caee054d94dc 100644 --- a/onnxruntime/core/providers/iree/iree_ep_runtime.cc +++ b/onnxruntime/core/providers/iree/iree_ep_runtime.cc @@ -27,9 +27,6 @@ Instance::~Instance() { if (instance) { iree_runtime_instance_release(instance); } - if (device) { - iree_hal_device_release(device); - } } iree_status_t Instance::Initialize(std::string device_str) { @@ -234,22 +231,13 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api, auto output_tensor = context.GetOutput(i, shape.data(), shape.size()); ORT_ENFORCE(output_tensor.IsTensor()); - iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv); // TODO: Synchronous mapping read, like everything in this function, is not a // great idea. It isn't supported on all device types and will need a scrub. - iree_string_view_t device_val = iree_hal_device_id(device); - auto device_str = std::string(device_val.data, device_val.size); - if (device_str == "hip") { - ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h( - iree_runtime_session_device(session), - ret_buffer, 0, output_tensor.GetTensorMutableRawData(), - iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, - iree_infinite_timeout()))); - return common::Status::OK(); - } - ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0, - output_tensor.GetTensorMutableRawData(), - iree_hal_buffer_view_byte_length(ret.bv)))); + ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h( + iree_runtime_session_device(session), + iree_hal_buffer_view_buffer(ret.bv), 0, output_tensor.GetTensorMutableRawData(), + iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT, + iree_infinite_timeout()))); } return common::Status::OK(); diff --git a/onnxruntime/core/providers/iree/iree_execution_provider.cc b/onnxruntime/core/providers/iree/iree_execution_provider.cc index d504561707e60..a3e037eb04ac8 100644 --- a/onnxruntime/core/providers/iree/iree_execution_provider.cc +++ b/onnxruntime/core/providers/iree/iree_execution_provider.cc @@ -168,6 +168,11 @@ common::Status IREEExecutionProvider::Compile(const std::vectorInitialize())); + // Release hal device after session initialization. + if (rt_instance_->device) { + iree_hal_device_release(rt_instance_->device); + } + // Load the compiled module, releasing our ownership of the CompilerOutput. ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(vmfb_path, vmfb_output.Release(vmfb_path))));