Skip to content

Commit

Permalink
Replace hipMemcpyWithStream with hipMemcpyAsync.
Browse files Browse the repository at this point in the history
hipMemcpyAsync will be blocking when the host memory is not pinned, and
hence there is no need to treat unpinned memory specially
  • Loading branch information
xinyazhang committed Oct 25, 2024
1 parent cb8c4bd commit 99eca3f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/migraphx/gpu_data_transfer.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// TODO: Unify this file with onnxruntime/core/providers/rocm/gpu_data_transfer.cc

#include "core/providers/shared_library/provider_api.h"
#include "gpu_data_transfer.h"
#include "migraphx_call.h"
Expand Down Expand Up @@ -49,15 +51,13 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
auto& dst_device = dst.Location().device;

if (dst_device.Type() == OrtDevice::GPU) {
if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
// copy from pinned memory to GPU, this is non-blocking
if (src_device.Type() == OrtDevice::CPU) {
// copy from host memory to GPU,
// this is non-blocking if src is pinned, otherwise it is blocking
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
} else if (src_device.Type() == OrtDevice::GPU) {
// copying between GPU, this is non-blocking
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
} else {
// copy from other CPU memory to GPU, this is blocking
HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
}
} else if (src_device.Type() == OrtDevice::GPU) {
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1445,11 +1445,15 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
std::vector<int64_t> ort_shape{res_lens.begin(), res_lens.end()};
auto output_tensor = ctx.GetOutput(i, ort_shape.data(), ort_shape.size());
void* output_data = output_tensor.GetTensorMutableRawData();
HIP_CALL_THROW(hipMemcpyWithStream(output_data,
gpu_res.data(),
res_shape.bytes(),
hipMemcpyDeviceToDevice,
static_cast<hipStream_t>(rocm_stream)));
// Prefer hipMemcpyAsync over hipMemcpyWithStream, due to
// 1. hipMemcpyAsync will automatically block when the host memory is not pinned
// 2. hipMemcpyWithStream has a known performance problem:
// https://github.com/ROCm/clr/issues/78
HIP_CALL_THROW(hipMemcpyAsync(output_data,
gpu_res.data(),
res_shape.bytes(),
hipMemcpyDeviceToDevice,
static_cast<hipStream_t>(rocm_stream)));
}
}
};
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/gpu_data_transfer.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// TODO: Unify this file with onnxruntime/core/providers/migraphx/gpu_data_transfer.cc

#include "core/providers/shared_library/provider_api.h"

#include "core/providers/rocm/gpu_data_transfer.h"
Expand Down

0 comments on commit 99eca3f

Please sign in to comment.