From 4ac4cd26687ce1b15c22206f9e281443639aa5e0 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous <107195283+TedThemistokleous@users.noreply.github.com> Date: Fri, 12 Jul 2024 00:21:38 -0400 Subject: [PATCH] Migraphx ep windows build (#21284) ### Description Repeat of #21084 with removal of policy CMP0144 to suppress warnings which uses CMake 3.27.0. ### Motivation and Context Already approved PR: https://github.com/microsoft/onnxruntime/pull/21084 Removed the added policy from CMake 3.27.0. --- cmake/CMakeLists.txt | 5 +- cmake/onnxruntime_providers_migraphx.cmake | 56 +++--- .../providers/migraphx/gpu_data_transfer.cc | 10 - ...hip_allocator.cc => migraphx_allocator.cc} | 14 +- .../{hip_allocator.h => migraphx_allocator.h} | 14 +- .../core/providers/migraphx/migraphx_call.cc | 25 +-- .../core/providers/migraphx/migraphx_call.h | 2 - .../migraphx/migraphx_execution_provider.cc | 95 ++++------ .../migraphx/migraphx_execution_provider.h | 16 +- .../migraphx_execution_provider_info.h | 2 +- .../migraphx_execution_provider_utils.h | 2 +- .../core/providers/migraphx/migraphx_inc.h | 2 +- .../migraphx/migraphx_provider_factory.cc | 19 +- .../migraphx/migraphx_provider_factory.h | 9 + .../migraphx/migraphx_stream_handle.cc | 171 ++++++++++++++++++ .../migraphx/migraphx_stream_handle.h | 48 +++++ .../providers/shared_library/provider_api.h | 3 + .../provider_bridge_provider.cc | 12 +- .../shared_library/provider_interfaces.h | 5 + .../core/session/provider_bridge_ort.cc | 21 +++ setup.py | 29 ++- tools/ci_build/build.py | 11 +- 22 files changed, 410 insertions(+), 161 deletions(-) rename onnxruntime/core/providers/migraphx/{hip_allocator.cc => migraphx_allocator.cc} (83%) rename onnxruntime/core/providers/migraphx/{hip_allocator.h => migraphx_allocator.h} (78%) create mode 100644 onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc create mode 100644 onnxruntime/core/providers/migraphx/migraphx_stream_handle.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index daacd221caa93..e2fc3da9de35e 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1488,9 +1488,6 @@ if (onnxruntime_USE_CUDA) endif() if (onnxruntime_USE_MIGRAPHX) - if (WIN32) - message(FATAL_ERROR "MIGraphX does not support build in Windows!") - endif() set(AMD_MIGRAPHX_HOME ${onnxruntime_MIGRAPHX_HOME}) endif() @@ -1560,7 +1557,7 @@ if (UNIX OR onnxruntime_USE_NCCL) if (onnxruntime_USE_NCCL) if (onnxruntime_USE_CUDA) set(NCCL_LIBNAME "nccl") - elseif (onnxruntime_USE_ROCM) + elseif (onnxruntime_USE_ROCM OR onnxruntime_USE_MIGRAPHX) set(NCCL_LIBNAME "rccl") endif() find_path(NCCL_INCLUDE_DIR diff --git a/cmake/onnxruntime_providers_migraphx.cmake b/cmake/onnxruntime_providers_migraphx.cmake index 01c4f8b2c8719..d7d83b0ce8d64 100644 --- a/cmake/onnxruntime_providers_migraphx.cmake +++ b/cmake/onnxruntime_providers_migraphx.cmake @@ -19,23 +19,25 @@ endif() # Add search paths for default rocm installation - list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH}) - find_package(hip) - find_package(migraphx PATHS ${AMD_MIGRAPHX_HOME}) + # Suppress the warning about the small capitals of the package name - Enable when support to CMake 3.27.0 is used + # cmake_policy(SET CMP0144 NEW) - find_package(miopen) - find_package(rocblas) + if(WIN32 AND NOT HIP_PLATFORM) + set(HIP_PLATFORM "amd") + endif() + + find_package(hip REQUIRED) + find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME}) - set(migraphx_libs migraphx::c hip::host MIOpen roc::rocblas) + set(migraphx_libs migraphx::c hip::host) file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.h" - "${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) @@ -46,18 +48,16 @@ set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) - target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) - set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") - set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") - target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp) - - include(CheckLibraryExists) - check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC) - if(HAS_STREAM_SYNC) - target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC) - message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED") + if(MSVC) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32) else() - message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED") + target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + endif() + if(UNIX) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") + target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs) endif() if (onnxruntime_ENABLE_TRAINING_OPS) @@ -68,8 +68,16 @@ endif() endif() - install(TARGETS onnxruntime_providers_migraphx - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - ) + if(CMAKE_SYSTEM_NAME STREQUAL "Windows") + install(TARGETS onnxruntime_providers_migraphx + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + else() + install(TARGETS onnxruntime_providers_migraphx + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + ) + endif() diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 72193ef6268c1..94480c308b99f 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -60,17 +60,7 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst, HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); } } else if (src_device.Type() == OrtDevice::GPU) { -#ifndef MIGRAPHX_STREAM_SYNC - if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) { - // copying from GPU to pinned memory, this is non-blocking - HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); - } else { - // copying from GPU to CPU memory, this is blocking - HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); - } -#else HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast(stream.GetHandle()))); -#endif } else { // copying between cpu memory memcpy(dst_data, src_data, bytes); diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc similarity index 83% rename from onnxruntime/core/providers/migraphx/hip_allocator.cc rename to onnxruntime/core/providers/migraphx/migraphx_allocator.cc index 53f10e318e65f..0693eea056416 100644 --- a/onnxruntime/core/providers/migraphx/hip_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -3,7 +3,7 @@ #include "core/providers/shared_library/provider_api.h" #include "migraphx_call.h" -#include "hip_allocator.h" +#include "migraphx_allocator.h" #include "core/common/status.h" #include "core/framework/float16.h" #include "core/common/status.h" @@ -11,7 +11,7 @@ namespace onnxruntime { -void HIPAllocator::CheckDevice() const { +void MIGraphXAllocator::CheckDevice() const { #ifndef NDEBUG // check device to match at debug build // if it's expected to change, call hipSetDevice instead of the check @@ -23,7 +23,7 @@ void HIPAllocator::CheckDevice() const { #endif } -void* HIPAllocator::Alloc(size_t size) { +void* MIGraphXAllocator::Alloc(size_t size) { CheckDevice(); void* p = nullptr; if (size > 0) { @@ -32,12 +32,12 @@ void* HIPAllocator::Alloc(size_t size) { return p; } -void HIPAllocator::Free(void* p) { +void MIGraphXAllocator::Free(void* p) { CheckDevice(); (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown } -void* HIPExternalAllocator::Alloc(size_t size) { +void* MIGraphXExternalAllocator::Alloc(size_t size) { void* p = nullptr; if (size > 0) { p = alloc_(size); @@ -49,7 +49,7 @@ void* HIPExternalAllocator::Alloc(size_t size) { return p; } -void HIPExternalAllocator::Free(void* p) { +void MIGraphXExternalAllocator::Free(void* p) { free_(p); std::lock_guard lock(lock_); auto it = reserved_.find(p); @@ -59,7 +59,7 @@ void HIPExternalAllocator::Free(void* p) { } } -void* HIPExternalAllocator::Reserve(size_t size) { +void* MIGraphXExternalAllocator::Reserve(size_t size) { void* p = Alloc(size); if (!p) return nullptr; std::lock_guard lock(lock_); diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.h b/onnxruntime/core/providers/migraphx/migraphx_allocator.h similarity index 78% rename from onnxruntime/core/providers/migraphx/hip_allocator.h rename to onnxruntime/core/providers/migraphx/migraphx_allocator.h index 3244f9f04ea70..64da844e8c714 100644 --- a/onnxruntime/core/providers/migraphx/hip_allocator.h +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.h @@ -9,12 +9,12 @@ namespace onnxruntime { -class HIPAllocator : public IAllocator { +class MIGraphXAllocator : public IAllocator { public: - HIPAllocator(int device_id, const char* name) + MIGraphXAllocator(int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id), + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(device_id)), device_id, OrtMemTypeDefault)) {} virtual void* Alloc(size_t size) override; @@ -24,14 +24,14 @@ class HIPAllocator : public IAllocator { void CheckDevice() const; }; -class HIPExternalAllocator : public HIPAllocator { +class MIGraphXExternalAllocator : public MIGraphXAllocator { typedef void* (*ExternalAlloc)(size_t size); typedef void (*ExternalFree)(void* p); typedef void (*ExternalEmptyCache)(); public: - HIPExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) - : HIPAllocator(device_id, name) { + MIGraphXExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + : MIGraphXAllocator(device_id, name) { alloc_ = reinterpret_cast(alloc); free_ = reinterpret_cast(free); empty_cache_ = reinterpret_cast(empty_cache); @@ -55,7 +55,7 @@ class HIPPinnedAllocator : public IAllocator { HIPPinnedAllocator(int device_id, const char* name) : IAllocator( OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, device_id), + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(device_id)), device_id, OrtMemTypeCPUOutput)) {} virtual void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc index 5248ac2f39214..9807cd646e51c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#ifdef _WIN32 +#include +#else #include -#include -#include -#include +#endif + +#include #include "core/common/common.h" #include "core/common/status.h" #include "core/providers/shared_library/provider_api.h" @@ -34,16 +37,20 @@ std::conditional_t RocmCall( ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) { if (retCode != successCode) { try { - char hostname[HOST_NAME_MAX]; - if (gethostname(hostname, HOST_NAME_MAX) != 0) - strcpy(hostname, "?"); +#ifdef _WIN32 + // According to the POSIX spec, 255 is the safe minimum value. + static constexpr int HOST_NAME_MAX = 255; +#endif + std::string hostname(HOST_NAME_MAX, 0); + if (gethostname(hostname.data(), HOST_NAME_MAX) != 0) + hostname = "?"; int currentHipDevice; (void)hipGetDevice(¤tHipDevice); (void)hipGetLastError(); // clear last HIP error static char str[1024]; snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; file=%s ; line=%d ; expr=%s; %s", libName, (int)retCode, RocmErrString(retCode), currentHipDevice, - hostname, + hostname.c_str(), file, line, exprString, msg); if constexpr (THRW) { // throw an exception with the error info @@ -68,9 +75,5 @@ std::conditional_t RocmCall( template Status RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); template void RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg, const char* file, const int line); -template Status RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); -template void RocmCall(rocblas_status retCode, const char* exprString, const char* libName, rocblas_status successCode, const char* msg, const char* file, const int line); -template Status RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); -template void RocmCall(miopenStatus_t retCode, const char* exprString, const char* libName, miopenStatus_t successCode, const char* msg, const char* file, const int line); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h index 15d385a636b76..f6a95cebf34b5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_call.h +++ b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -4,8 +4,6 @@ #pragma once #include "migraphx_inc.h" -#pragma once - namespace onnxruntime { // ----------------------------------------------------------------------- diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 6ee85c3a4c047..097b16ecde536 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -13,12 +13,11 @@ #include "core/common/logging/severity.h" #include "migraphx_execution_provider.h" #include "migraphx_execution_provider_utils.h" -#include "hip_allocator.h" +#include "migraphx_allocator.h" #include "gpu_data_transfer.h" #include "migraphx_inc.h" -// TODO: find a better way to share this -#include "core/providers/rocm/rocm_stream_handle.h" +#include "migraphx_stream_handle.h" #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -102,10 +101,10 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info) { InitProviderOrtApi(); // Set GPU device to be used - HIP_CALL_THROW(hipSetDevice(device_id_)); + HIP_CALL_THROW(hipSetDevice(info_.device_id)); t_ = migraphx::target(info.target_device.c_str()); // whether fp16 is enable @@ -181,16 +180,10 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); } - ROCBLAS_CALL_THROW(rocblas_create_handle(&external_rocblas_handle_)); - ROCBLAS_CALL_THROW(rocblas_set_stream(external_rocblas_handle_, stream_)); - - MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); - MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); - metadef_id_generator_ = ModelMetadefIdGenerator::Create(); LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " - << "device_id: " << device_id_ + << "device_id: " << info_.device_id << ", migraphx_fp16_enable: " << fp16_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ @@ -205,17 +198,14 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv } MIGraphXExecutionProvider::~MIGraphXExecutionProvider() { - ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(external_rocblas_handle_))); - ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(external_miopen_handle_))); } std::vector MIGraphXExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateROCMAllocator(device_id, onnxruntime::CUDA); }, device_id_); + [](OrtDevice::DeviceId device_id) { return CreateMIGraphXAllocator(device_id, onnxruntime::CUDA); }, info_.device_id); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - ORT_UNUSED_PARAMETER(device_id); - return CreateROCMPinnedAllocator(onnxruntime::CUDA_PINNED); + return CreateMIGraphXPinnedAllocator(device_id, onnxruntime::CUDA_PINNED); }, 0); return std::vector{CreateAllocator(default_memory_info), CreateAllocator(pinned_allocator_info)}; @@ -254,40 +244,40 @@ static bool getMIGraphXType(ONNXTensorElementDataType type, migraphx_shape_datatype_t& mgx_type) { mgx_type = migraphx_shape_float_type; switch (type) { - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: mgx_type = migraphx_shape_half_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: mgx_type = migraphx_shape_float_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_DOUBLE: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: mgx_type = migraphx_shape_double_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: mgx_type = migraphx_shape_int8_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: mgx_type = migraphx_shape_int16_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: mgx_type = migraphx_shape_int32_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: mgx_type = migraphx_shape_int64_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: mgx_type = migraphx_shape_uint8_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: mgx_type = migraphx_shape_uint16_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: mgx_type = migraphx_shape_uint32_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT64: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: mgx_type = migraphx_shape_uint64_type; break; - case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: mgx_type = migraphx_shape_bool_type; break; default: @@ -303,7 +293,7 @@ std::vector toVector(const ONNX_NAMESPACE::int64s& nums) { std::vector result; int num = nums.size(); for (int i = 0; i < num; ++i) { - result.push_back(nums[i]); + result.push_back(static_cast(nums[i])); } return result; @@ -501,16 +491,9 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co if (arg_s != nullptr) { const auto& tensor_dims = arg_s->dim(); std::vector dims; - std::transform(tensor_dims.begin(), - tensor_dims.end(), - std::back_inserter(dims), - [&](auto&& d) -> std::size_t { - if (d.has_dim_value()) { - return d.dim_value(); - } else { - return 0; - } - }); + for (auto&& dim : tensor_dims) { + dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 0); + } if (dims == std::vector{0}) { return true; } @@ -546,8 +529,8 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::vector>& clusters, - const logging::Logger& logger) { - // Then check whether a subgraph should fallback to CPU + [[maybe_unused]] const logging::Logger& logger) { + // Then check whether a subgraph should fall back to CPU // 1. Check whether a subgraph contains a RNN operator std::unordered_set rnn_names = {"RNN", "GRU", "LSTM"}; std::unordered_set op_names = {"AveragePool", "Conv", "Gemm", "LRN", "MatMul", "MaxPool"}; @@ -591,17 +574,10 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (arg_s == nullptr) return false; const auto& tensor_dims = arg_s->dim(); std::vector dims; - std::transform(tensor_dims.begin(), - tensor_dims.end(), - std::back_inserter(dims), - [&](auto&& d) -> std::size_t { - if (d.has_dim_value()) { - return d.dim_value(); - } else { - return 1; - } - }); - return (std::accumulate(dims.begin(), dims.end(), 1, std::multiplies{}) > 300); + for (auto&& dim : tensor_dims) { + dims.emplace_back(dim.has_dim_value() ? dim.dim_value() : 1); + } + return (std::accumulate(dims.begin(), dims.end(), 1ULL, std::multiplies{}) > 300); })) { return false; } @@ -623,7 +599,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v static bool IsNodeSupported(const std::set& op_set, const onnxruntime::GraphViewer& graph_viewer, const NodeIndex node_idx, - const logging::Logger& logger) { + [[maybe_unused]] const logging::Logger& logger) { const auto& node = graph_viewer.GetNode(node_idx); const auto& optype = node->OpType(); const auto& domain = node->Domain(); @@ -1442,14 +1418,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // lock to avoid race condition std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); -#ifdef MIGRAPHX_STREAM_SYNC void* rocm_stream; Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &rocm_stream)); auto prog_outputs = prog.run_async(m, static_cast(rocm_stream)); -#else - auto prog_outputs = prog.eval(m); - HIP_CALL_THROW(hipDeviceSynchronize()); -#endif + // In case of input parameters are reused as output parameter call hipMemcpy auto output_num = prog_outputs.size(); if (prog_output_indices.size() < output_num) { @@ -1478,8 +1450,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& void MIGraphXExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const { auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)]; - RegisterRocmStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, - false /*TODO:external_stream_*/, external_miopen_handle_, external_rocblas_handle_); + RegisterMIGraphXStreamHandles(stream_handle_registry, OrtDevice::GPU, allocator, true, stream_, false /*TODO:external_stream_*/); } OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { @@ -1487,7 +1458,6 @@ OrtDevice MIGraphXExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) if (mem_type == OrtMemTypeCPUOutput) return OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, 0 /*CPU device id always be 0*/); return default_device_; } -#ifdef MIGRAPHX_STREAM_SYNC Status MIGraphXExecutionProvider::Sync() const { HIP_CALL_THROW(hipStreamSynchronize(static_cast(nullptr))); @@ -1512,5 +1482,4 @@ Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxrunti return Status::OK(); } -#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 1977f71b8b1cf..f34ca320d0a5a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,9 +3,6 @@ #pragma once -#include -#include - #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" @@ -14,8 +11,6 @@ #include #include -// TODO: find a better way to share this -// #include "core/providers/cuda/rocm_stream_handle.h" namespace onnxruntime { @@ -62,13 +57,11 @@ class MIGraphXExecutionProvider : public IExecutionProvider { explicit MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info); ~MIGraphXExecutionProvider(); -#ifdef MIGRAPHX_STREAM_SYNC Status Sync() const override; Status OnRunStart(const onnxruntime::RunOptions& run_options) override; Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; -#endif std::vector> GetCapability(const onnxruntime::GraphViewer& graph_viewer, @@ -85,7 +78,13 @@ class MIGraphXExecutionProvider : public IExecutionProvider { OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; + int GetDeviceId() const override { return info_.device_id; } + ProviderOptions GetProviderOptions() const override { + return MIGraphXExecutionProviderInfo::ToProviderOptions(info_); + } + private: + MIGraphXExecutionProviderInfo info_; bool fp16_enable_ = false; bool int8_enable_ = false; std::string int8_calibration_cache_name_; @@ -98,7 +97,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { bool load_compiled_model_ = false; std::string load_compiled_path_; bool dump_model_ops_ = false; - int device_id_; migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; @@ -109,8 +107,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::unordered_map map_no_input_shape_; AllocatorPtr allocator_; - miopenHandle_t external_miopen_handle_ = nullptr; - rocblas_handle external_rocblas_handle_ = nullptr; std::unique_ptr metadef_id_generator_; }; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 8411e3eef096b..68d5d9af98ea4 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -14,7 +14,7 @@ namespace onnxruntime { // Information needed to construct trt execution providers. struct MIGraphXExecutionProviderInfo { std::string target_device; - int device_id{0}; + OrtDevice::DeviceId device_id{0}; bool fp16_enable{false}; bool int8_enable{false}; std::string int8_calibration_table_name{""}; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h index 071070e92a209..9274b5696185c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_utils.h @@ -28,7 +28,7 @@ bool IsGraphInput(const GraphViewer& graph, const std::string& name) { return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); } -bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) { +bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, [[maybe_unused]] bool check_outer_scope = true) { const ONNX_NAMESPACE::TensorProto* initializer = nullptr; return graph.GetInitializedTensor(name, initializer); } diff --git a/onnxruntime/core/providers/migraphx/migraphx_inc.h b/onnxruntime/core/providers/migraphx/migraphx_inc.h index 96b24051ace76..2b035b20f619f 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_inc.h +++ b/onnxruntime/core/providers/migraphx/migraphx_inc.h @@ -4,5 +4,5 @@ #pragma once #include -#include +#include #include diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index dd24dbdc76d2f..6d199930116e8 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -6,7 +6,7 @@ #include "core/providers/migraphx/migraphx_provider_factory.h" #include "migraphx_execution_provider.h" #include "migraphx_provider_factory_creator.h" -#include "hip_allocator.h" +#include "migraphx_allocator.h" #include "gpu_data_transfer.h" #include "core/framework/provider_options.h" @@ -33,10 +33,23 @@ std::unique_ptr MIGraphXProviderFactory::CreateProvider() { return std::make_unique(info_); } +struct ProviderInfo_MIGraphX_Impl final : ProviderInfo_MIGraphX { + std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); + } + + std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); + } + +} g_info; + struct MIGraphX_Provider : Provider { + void* GetInfo() override { return &g_info; } + std::shared_ptr CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; - info.device_id = device_id; + info.device_id = static_cast(device_id); info.target_device = "gpu"; return std::make_shared(info); } @@ -44,7 +57,7 @@ struct MIGraphX_Provider : Provider { std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { auto& options = *reinterpret_cast(provider_options); MIGraphXExecutionProviderInfo info; - info.device_id = options.device_id; + info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; info.int8_enable = options.migraphx_int8_enable; diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h index ac9834e64942a..b257a4318dc0e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -10,4 +10,13 @@ struct IExecutionProviderFactory; struct MIGraphXExecutionProviderInfo; enum class ArenaExtendStrategy : int32_t; struct MIGraphXExecutionProviderExternalAllocatorInfo; + +struct ProviderInfo_MIGraphX { + virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; + + protected: + ~ProviderInfo_MIGraphX() = default; // Can only be destroyed through a subclass instance +}; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc new file mode 100644 index 0000000000000..9c5bb4ecf5c97 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.cc @@ -0,0 +1,171 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "migraphx_stream_handle.h" + +namespace onnxruntime { + +struct MIGraphXNotification : public synchronize::Notification { + MIGraphXNotification(Stream& s) : Notification(s) { + HIP_CALL_THROW(hipEventCreateWithFlags(&event_, hipEventDisableTiming)); + } + + ~MIGraphXNotification() { + if (event_) + HIP_CALL_THROW(hipEventDestroy(event_)); + } + + void Activate() override { + // record event with hipEventBlockingSync so we can support sync on host without busy wait. + HIP_CALL_THROW(hipEventRecord(event_, static_cast(stream_.GetHandle()))); + } + + void wait_on_device(Stream& device_stream) { + ORT_ENFORCE(device_stream.GetDevice().Type() == OrtDevice::GPU, "Unexpected device:", device_stream.GetDevice().ToString()); + // launch a wait command to the migraphx stream + HIP_CALL_THROW(hipStreamWaitEvent(static_cast(device_stream.GetHandle()), event_, 0)); + }; + + void wait_on_host() { + // CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); + HIP_CALL_THROW(hipEventSynchronize(event_)); + } + + hipEvent_t event_; +}; + +MIGraphXStream::MIGraphXStream(hipStream_t stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream) + : Stream(stream, device), + cpu_allocator_(cpu_allocator), + release_cpu_buffer_on_migraphx_stream_(release_cpu_buffer_on_migraphx_stream) { +} + +MIGraphXStream::~MIGraphXStream() { + ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); + if (own_stream_) { + auto* handle = GetHandle(); + if (handle) + HIP_CALL_THROW(hipStreamDestroy(static_cast(handle))); + } +} + +std::unique_ptr MIGraphXStream::CreateNotification(size_t /*num_consumers*/) { + return std::make_unique(*this); +} + +void MIGraphXStream::Flush() { + if (own_stream_) + HIP_CALL_THROW(hipStreamSynchronize(static_cast(GetHandle()))); +} + +void MIGraphXStream::EnqueDeferredCPUBuffer(void* cpu_buffer) { + // stream is per thread, so don't need lock + deferred_cpu_buffers_.push_back(cpu_buffer); +} + +struct CpuBuffersInfo { + // This struct stores the information needed + // to release CPU buffers allocated for GPU kernels. + // It's used to enqueue their release after + // associated GPU kernels in a MIGraphX stream. + + // This is a CPU allocator in MIGraphX EP. + // It must be the one used to allocate the + // following pointers. + AllocatorPtr allocator; + // buffers[i] is the i-th pointer added by + // AddDeferredReleaseCPUPtr for a specific + // MIGraphX stream. For example, this fields + // should contain all values in + // deferred_release_buffer_pool_[my_stream] + // when release my_stream's buffers. + std::unique_ptr buffers; + // CPU buffer buffers[i]. + // Number of buffer points in "buffers". + size_t n_buffers; +}; + +static void ReleaseCpuBufferCallback(void* raw_info) { + std::unique_ptr info = std::make_unique(); + info.reset(reinterpret_cast(raw_info)); + for (size_t i = 0; i < info->n_buffers; ++i) { + info->allocator->Free(info->buffers[i]); + } +} + +Status MIGraphXStream::CleanUpOnRunEnd() { + if (deferred_cpu_buffers_.empty()) + return Status::OK(); + // Release the ownership of cpu_buffers_info so that the underlying + // object will keep alive until the end of ReleaseCpuBufferCallback. + if (release_cpu_buffer_on_migraphx_stream_ && cpu_allocator_->Info().alloc_type == OrtArenaAllocator) { + std::unique_ptr cpu_buffers_info = std::make_unique(); + cpu_buffers_info->allocator = cpu_allocator_; + cpu_buffers_info->buffers = std::make_unique(deferred_cpu_buffers_.size()); + for (size_t i = 0; i < deferred_cpu_buffers_.size(); ++i) { + cpu_buffers_info->buffers[i] = deferred_cpu_buffers_.at(i); + } + cpu_buffers_info->n_buffers = deferred_cpu_buffers_.size(); + HIP_RETURN_IF_ERROR(hipLaunchHostFunc(static_cast(GetHandle()), ReleaseCpuBufferCallback, cpu_buffers_info.release())); + } else { + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(GetHandle()))); + for (auto* buffer : deferred_cpu_buffers_) { + cpu_allocator_->Free(buffer); + } + } + + deferred_cpu_buffers_.clear(); + return Status::OK(); +} + +void* MIGraphXStream::GetResource(int version, int id) const { + ORT_ENFORCE(version <= ORT_ROCM_RESOUCE_VERSION, "resource version unsupported!"); + void* resource{}; + switch (id) { + case RocmResource::hip_stream_t: + return reinterpret_cast(GetHandle()); + default: + break; + } + return resource; +} + +// CPU Stream command handles +void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_device(stream); +} + +void WaitMIGraphXNotificationOnHost(Stream& /*stream*/, synchronize::Notification& notification) { + static_cast(¬ification)->wait_on_host(); +} + +void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, + const OrtDevice::DeviceType device_type, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream, + hipStream_t external_stream, + bool use_existing_stream) { + // wait migraphx notification on migraphx ep + stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitMIGraphXNotificationOnDevice); + // wait migraphx notification on cpu ep + stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitMIGraphXNotificationOnHost); + if (!use_existing_stream) + stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_migraphx_stream](const OrtDevice& device) { + HIP_CALL_THROW(hipSetDevice(device.Id())); + hipStream_t stream = nullptr; + HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); + }); + else + stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, + release_cpu_buffer_on_migraphx_stream, + external_stream](const OrtDevice& device) { + return std::make_unique(external_stream, device, cpu_allocator, release_cpu_buffer_on_migraphx_stream); + }); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h new file mode 100644 index 0000000000000..03a7c1607e3ad --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_stream_handle.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/framework/stream_handles.h" +#include "migraphx_inc.h" +#include "migraphx_call.h" + +#define HIP_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(HIP_CALL(expr)) + +namespace onnxruntime { +void WaitMIGraphXNotificationOnDevice(Stream& stream, synchronize::Notification& notification); + +struct MIGraphXStream : Stream { + MIGraphXStream(hipStream_t stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream); + + ~MIGraphXStream(); + + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override; + + void Flush() override; + + Status CleanUpOnRunEnd() override; + + void EnqueDeferredCPUBuffer(void* cpu_buffer); + + bool own_stream_{true}; + + virtual void* GetResource(int version, int id) const; + + virtual WaitNotificationFn GetWaitNotificationFn() const { return WaitMIGraphXNotificationOnDevice; } + + private: + std::vector deferred_cpu_buffers_; + AllocatorPtr cpu_allocator_; + bool release_cpu_buffer_on_migraphx_stream_{true}; +}; + +void RegisterMIGraphXStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, + const OrtDevice::DeviceType device_type, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_migraphx_stream, + hipStream_t external_stream, + bool use_existing_stream); +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 590bddabdba54..86e49627fe26b 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -279,6 +279,9 @@ std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info) std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name); std::unique_ptr CreateCUDAPinnedAllocator(const char* name); +std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name); +std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name); + std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name); std::unique_ptr CreateROCMPinnedAllocator(const char* name); diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 6e6a80f097c12..7fb9fd3c8cfd5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -353,16 +353,12 @@ std::unique_ptr CreateGPUDataTransfer() { #endif #ifdef USE_MIGRAPHX -std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) { - return g_host->CreateROCMAllocator(device_id, name); +std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) { + return g_host->CreateMIGraphXAllocator(device_id, name); } -std::unique_ptr CreateROCMPinnedAllocator(const char* name) { - return g_host->CreateROCMPinnedAllocator(name); -} - -std::unique_ptr CreateGPUDataTransfer() { - return g_host->CreateGPUDataTransfer(); +std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) { + return g_host->CreateMIGraphXPinnedAllocator(device_id, name); } #endif diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index bc6dac1a2f27f..4d40fcafaeea1 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -179,6 +179,11 @@ struct ProviderHost { virtual void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) = 0; #endif +#ifdef USE_MIGRAPHX + virtual std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) = 0; +#endif + #ifdef USE_ROCM virtual std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateROCMPinnedAllocator(const char* name) = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index b53e70926cd5d..bd5a68152fb71 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -132,6 +132,8 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); +ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX(); +ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX(); ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} @@ -243,6 +245,11 @@ struct ProviderHostImpl : ProviderHost { void CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg, const char* file, const int line) override { GetProviderInfo_CUDA().CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } #endif +#ifdef USE_MIGRAPHX + std::unique_ptr CreateMIGraphXAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXAllocator(device_id, name); } + std::unique_ptr CreateMIGraphXPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGraphX().CreateMIGraphXPinnedAllocator(device_id, name); } +#endif + #ifdef USE_ROCM std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_ROCM().CreateROCMAllocator(device_id, name); } std::unique_ptr CreateROCMPinnedAllocator(const char* name) override { return GetProviderInfo_ROCM().CreateROCMPinnedAllocator(name); } @@ -1954,6 +1961,20 @@ ProviderInfo_ROCM& GetProviderInfo_ROCM() { ORT_THROW("ROCM Provider not available, can't get interface for it"); } +ProviderInfo_MIGraphX* TryGetProviderInfo_MIGraphX() try { + return reinterpret_cast(s_library_migraphx.Get().GetInfo()); +} catch (const std::exception& exception) { + LOGS_DEFAULT(ERROR) << exception.what(); + return nullptr; +} + +ProviderInfo_MIGraphX& GetProviderInfo_MIGraphX() { + if (auto* info = TryGetProviderInfo_MIGraphX()) + return *info; + + ORT_THROW("MIGraphX Provider not available, can't get interface for it"); +} + void CopyGpuToCpu( void* dst_ptr, const void* src_ptr, diff --git a/setup.py b/setup.py index 5750833ce35de..51feedcfd3286 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ def parse_arg_remove_string(argv, arg_name_equal): cuda_version = None rocm_version = None +is_migraphx = False is_rocm = False is_openvino = False # The following arguments are mutually exclusive @@ -64,8 +65,9 @@ def parse_arg_remove_string(argv, arg_name_equal): cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=") elif parse_arg_remove_boolean(sys.argv, "--use_rocm"): is_rocm = True - package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=") +elif parse_arg_remove_boolean(sys.argv, "--use_migraphx"): + is_migraphx = True elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): is_openvino = True package_name = "onnxruntime-openvino" @@ -87,6 +89,9 @@ def parse_arg_remove_string(argv, arg_name_equal): elif parse_arg_remove_boolean(sys.argv, "--use_qnn"): package_name = "onnxruntime-qnn" +if is_rocm or is_migraphx: + package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" + # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 # PEP 599 defines the following platform tags: @@ -280,10 +285,21 @@ def finalize_options(self): return ret -providers_cuda_or_rocm = "libonnxruntime_providers_" + ("rocm.so" if is_rocm else "cuda.so") -providers_tensorrt_or_migraphx = "libonnxruntime_providers_" + ("migraphx.so" if is_rocm else "tensorrt.so") -providers_openvino = "libonnxruntime_providers_openvino.so" -providers_cann = "libonnxruntime_providers_cann.so" +providers_cuda_or_rocm = "onnxruntime_providers_" + ("rocm" if is_rocm else "cuda") +providers_tensorrt_or_migraphx = "onnxruntime_providers_" + ("migraphx" if is_migraphx else "tensorrt") +providers_openvino = "onnxruntime_providers_openvino" +providers_cann = "onnxruntime_providers_cann" + +if platform.system() == "Linux": + providers_cuda_or_rocm = "lib" + providers_cuda_or_rocm + ".so" + providers_tensorrt_or_migraphx = "lib" + providers_tensorrt_or_migraphx + ".so" + providers_openvino = "lib" + providers_openvino + ".so" + providers_cann = "lib" + providers_cann + ".so" +elif platform.system() == "Windows": + providers_cuda_or_rocm = providers_cuda_or_rocm + ".dll" + providers_tensorrt_or_migraphx = providers_tensorrt_or_migraphx + ".dll" + providers_openvino = providers_openvino + ".dll" + providers_cann = providers_cann + ".dll" # Additional binaries dl_libs = [] @@ -335,6 +351,9 @@ def finalize_options(self): "dnnl.dll", "mklml.dll", "libiomp5md.dll", + providers_cuda_or_rocm, + providers_tensorrt_or_migraphx, + providers_cann, "onnxruntime.dll", ] # DNNL, TensorRT & OpenVINO EPs are built as shared libs diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index ae4c9b27544ba..75fbf5d0851ae 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -613,6 +613,7 @@ def convert_arg_line_to_args(self, arg_line): "MinGW Makefiles", "Ninja", "NMake Makefiles", + "NMake Makefiles JOM", "Unix Makefiles", "Visual Studio 17 2022", "Xcode", @@ -2211,6 +2212,7 @@ def build_python_wheel( use_cuda, cuda_version, use_rocm, + use_migraphx, rocm_version, use_dnnl, use_tensorrt, @@ -2262,6 +2264,8 @@ def build_python_wheel( args.append("--use_rocm") if rocm_version: args.append(f"--rocm_version={rocm_version}") + elif use_migraphx: + args.append("--use_migraphx") elif use_openvino: args.append("--use_openvino") elif use_dnnl: @@ -2587,9 +2591,6 @@ def main(): if args.use_tensorrt: args.use_cuda = True - if args.use_migraphx: - args.use_rocm = True - if args.build_wheel or args.gen_doc or args.use_tvm or args.enable_training: args.enable_pybind = True @@ -2885,7 +2886,8 @@ def main(): # fail unexpectedly. Similar, if your packaging step forgot to copy a file into the package, we don't know it # either. if args.build: - # TODO: find asan DLL and copy it to onnxruntime/capi folder when args.enable_address_sanitizer is True and the target OS is Windows + # TODO: find asan DLL and copy it to onnxruntime/capi folder when args.enable_address_sanitizer is True and + # the target OS is Windows if args.build_wheel: nightly_build = bool(os.getenv("NIGHTLY_BUILD") == "1") default_training_package_device = bool(os.getenv("DEFAULT_TRAINING_PACKAGE_DEVICE") == "1") @@ -2896,6 +2898,7 @@ def main(): args.use_cuda, args.cuda_version, args.use_rocm, + args.use_migraphx, args.rocm_version, args.use_dnnl, args.use_tensorrt,