From 99eca3f2bdf5f3df7dcc3a76f75d136e8a784773 Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Fri, 25 Oct 2024 18:06:29 +0000 Subject: [PATCH] Replace hipMemcpyWithStream with hipMemcpyAsync. hipMemcpyAsync will be blocking when the host memory is not pinned, and hence there is not need to treat unpinned memory specially --- .../core/providers/migraphx/gpu_data_transfer.cc | 10 +++++----- .../migraphx/migraphx_execution_provider.cc | 14 +++++++++----- .../core/providers/rocm/gpu_data_transfer.cc | 2 ++ 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 51625b83b8f61..6bd4a3e5a1ce7 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// TODO: Unify this file with onnxruntime/core/providers/rocm/gpu_data_transfer.cc + #include "core/providers/shared_library/provider_api.h" #include "gpu_data_transfer.h" #include "migraphx_call.h" @@ -49,15 +51,13 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { - // copy from pinned memory to GPU, this is non-blocking + if (src_device.Type() == OrtDevice::CPU) { + // copy from host memory to GPU, + // this is non-blocking if src is pinned, otherwise it is blocking HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast(stream.GetHandle()))); - } else { - // copy from other CPU memory to GPU, this is blocking - HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast(stream.GetHandle()))); } } else if (src_device.Type() == OrtDevice::GPU) { HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index dca38480434fe..f97c07b408398 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1445,11 +1445,15 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::vector ort_shape{res_lens.begin(), res_lens.end()}; auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size()); void* output_data = output_tensor.GetTensorMutableRawData(); - HIP_CALL_THROW(hipMemcpyWithStream(output_data, - gpu_res.data(), - res_shape.bytes(), - hipMemcpyDeviceToDevice, - static_cast(rocm_stream))); + // Prefer hipMemcpyAsync over hipMemcpyWithStream, due to + // 1. hipMemcpyAsync will automatically block when the host memory is not pinned + // 2. hipMemcpyWithStream has a known performance problem: + // https://github.com/ROCm/clr/issues/78 + HIP_CALL_THROW(hipMemcpyAsync(output_data, + gpu_res.data(), + res_shape.bytes(), + hipMemcpyDeviceToDevice, + static_cast(rocm_stream))); } } }; diff --git a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc index 635a25480b646..f1d0f709dfabd 100644 --- a/onnxruntime/core/providers/rocm/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/rocm/gpu_data_transfer.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// TODO: Unify this file with onnxruntime/core/providers/migraphx/gpu_data_transfer.cc + #include "core/providers/shared_library/provider_api.h" #include "core/providers/rocm/gpu_data_transfer.h"