Skip to content

Commit

Permalink
[IREE][EP] Release device just after creating session
Browse files Browse the repository at this point in the history
This patch releases hal device just after creating the runtime session
as it's not needed after that. Late release of hal device was causing
memory issue while running a set of models consecutively.

Signed-Off-by: Gaurav Shukla <[email protected]>
  • Loading branch information
Shukla-Gaurav committed Sep 30, 2024
1 parent ecebd73 commit cd92bfe
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 17 deletions.
22 changes: 5 additions & 17 deletions onnxruntime/core/providers/iree/iree_ep_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ Instance::~Instance() {
if (instance) {
iree_runtime_instance_release(instance);
}
if (device) {
iree_hal_device_release(device);
}
}

iree_status_t Instance::Initialize(std::string device_str) {
Expand Down Expand Up @@ -234,22 +231,13 @@ common::Status Session::Call(const char* entrypoint_name, const OrtApi* ort_api,
auto output_tensor = context.GetOutput(i, shape.data(), shape.size());
ORT_ENFORCE(output_tensor.IsTensor());

iree_hal_buffer_t* ret_buffer = iree_hal_buffer_view_buffer(ret.bv);
// TODO: Synchronous mapping read, like everything in this function, is not a
// great idea. It isn't supported on all device types and will need a scrub.
iree_string_view_t device_val = iree_hal_device_id(device);
auto device_str = std::string(device_val.data, device_val.size);
if (device_str == "hip") {
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
ret_buffer, 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
return common::Status::OK();
}
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_buffer_map_read(ret_buffer, /*source_offset=*/0,
output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv))));
ORT_RETURN_IF_ERROR(HandleIREEStatus(iree_hal_device_transfer_d2h(
iree_runtime_session_device(session),
iree_hal_buffer_view_buffer(ret.bv), 0, output_tensor.GetTensorMutableRawData(),
iree_hal_buffer_view_byte_length(ret.bv), IREE_HAL_TRANSFER_BUFFER_FLAG_DEFAULT,
iree_infinite_timeout())));
}

return common::Status::OK();
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/providers/iree/iree_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ common::Status IREEExecutionProvider::Compile(const std::vector<FusedNodeAndGrap
// In case device info is absent, set `local-task` as default device.
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->Initialize()));

// Release hal device after session initialization.
if (rt_instance_->device) {
iree_hal_device_release(rt_instance_->device);
}

// Load the compiled module, releasing our ownership of the CompilerOutput.
ORT_RETURN_IF_ERROR(iree_ep_rt::HandleIREEStatus(rt_session->AppendBytecodeModule(vmfb_path,
vmfb_output.Release(vmfb_path))));
Expand Down

0 comments on commit cd92bfe

Please sign in to comment.