Skip to content

Commit

Permalink
CUDA EP vs ROCM EP hipify audit
Browse files Browse the repository at this point in the history
Migrate all CUDA EP improvements and changes to ROCM EP. The process
involves running hipify against all CUDA EP files (i.e. not excluding any
files in onnxruntime_rocm_hipify.cmake), then using vimdiff to compare the
output against the ROCM EP files that are under source control and pulling
in most changes. These changes include functional as well as formatting
changes, and make comparing CUDA EP and ROCM EP easier, though they make
the PR diff somewhat less obvious due to the formatting changes.

- hipify audit of onnxruntime/core/providers/rocm, enable ops
  - Loop
  - Scan
- hipify audit of onnxruntime/contrib_ops/rocm
- fix contrib ops search implementation
- enable more contrib ops
  - Affine
  - ComplexMul
  - ConvTransposeWithDynamicPads
  - Crop
  - DynamicSlice
  - FFT [Rfft, Irfft]
  - GreedySearch
  - ImageScaler
  - ParametricSoftplus
  - ScaledTanh
  - ThresholdRelu
  • Loading branch information
jeffdaily committed Oct 3, 2023
1 parent d11e053 commit b1f3135
Show file tree
Hide file tree
Showing 46 changed files with 1,501 additions and 1,504 deletions.
3 changes: 2 additions & 1 deletion cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,7 @@ if (onnxruntime_USE_ROCM)
find_package(hiprand REQUIRED)
find_package(rocblas REQUIRED)
find_package(MIOpen REQUIRED)
find_package(hipfft REQUIRED)

# MIOpen version
if(NOT DEFINED ENV{MIOPEN_PATH})
Expand Down Expand Up @@ -1554,7 +1555,7 @@ if (onnxruntime_USE_ROCM)

find_library(RCCL_LIB rccl REQUIRED)
find_library(ROCTRACER_LIB roctracer64 REQUIRED)
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB})
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB})

file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h"
Expand Down
27 changes: 0 additions & 27 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ set(contrib_ops_excluded_files
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
"math/complex_mul_impl.h"
"math/cufft_plan_cache.h"
"math/fft_ops.cc"
"math/fft_ops.h"
"math/fft_ops_impl.cu"
"math/fft_ops_impl.h"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
Expand Down Expand Up @@ -86,19 +77,6 @@ set(contrib_ops_excluded_files
"quantization/qordered_ops/qordered_unary_ops.cc"
"quantization/qordered_ops/qordered_unary_ops_impl.h"
"quantization/qordered_ops/qordered_unary_ops_impl.cu"
"tensor/crop.cc"
"tensor/crop.h"
"tensor/crop_impl.cu"
"tensor/crop_impl.h"
"tensor/dynamicslice.cc"
"tensor/image_scaler.cc"
"tensor/image_scaler.h"
"tensor/image_scaler_impl.cu"
"tensor/image_scaler_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"conv_transpose_with_dynamic_pads.cc"
"conv_transpose_with_dynamic_pads.h"
"cuda_contrib_kernels.cc"
"cuda_contrib_kernels.h"
"inverse.cc"
Expand All @@ -114,10 +92,6 @@ endif()

set(provider_excluded_files
"atomic/common.cuh"
"controlflow/loop.cc"
"controlflow/loop.h"
"controlflow/scan.cc"
"controlflow/scan.h"
"cu_inc/common.cuh"
"math/einsum_utils/einsum_auxiliary_ops.cc"
"math/einsum_utils/einsum_auxiliary_ops.h"
Expand Down Expand Up @@ -165,7 +139,6 @@ set(provider_excluded_files
"cuda_memory_check.h"
"cuda_fence.cc"
"cuda_fence.h"
"cuda_fwd.h"
"cuda_kernel.h"
"cuda_pch.cc"
"cuda_pch.h"
Expand Down
12 changes: 6 additions & 6 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -240,7 +240,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
device_copy_int32_func_,
update_gpt_feeds_fp16_func_,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down Expand Up @@ -271,7 +271,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -293,7 +293,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_,
expand_buffer_float16_func_,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -320,7 +320,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -341,7 +341,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_float_func_,
expand_buffer_float16_func_,
create_beam_scorer_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, init_cache_indir_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/beam_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class BeamSearch : public IControlFlowKernel {
create_beam_scorer_func_ = create_beam_scorer_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void SetDeviceHelpers_Cuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func) {
Expand Down Expand Up @@ -96,7 +96,7 @@ class BeamSearch : public IControlFlowKernel {
expand_buffer_float16_func_ = expand_buffer_float16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
const void* cuda_device_prop_ = nullptr;
int cuda_device_arch_ = 0;
#endif
Expand All @@ -115,7 +115,7 @@ class BeamSearch : public IControlFlowKernel {
GenerationDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
GenerationDeviceHelper::CreateBeamScorer create_beam_scorer_func_;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class BeamSearchGpt : public BeamSearchBase<T> {
update_feeds_func_(update_feeds_func),
create_beam_scorer_func_(create_beam_scorer_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const void* cuda_device_prop,
Expand Down Expand Up @@ -100,7 +100,7 @@ class BeamSearchGpt : public BeamSearchBase<T> {
GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
Expand Down Expand Up @@ -336,7 +336,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
// Increase sequence length after a new token is generated.
++current_length;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedSelfAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
expand_buffer_float16_func_(expand_buffer_float16_func),
create_beam_scorer_func_(create_beam_scorer_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
Expand Down Expand Up @@ -87,7 +87,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
// Device specific functions
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
#endif
Expand Down Expand Up @@ -280,7 +280,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class BeamSearchWhisper : public BeamSearchBase<T> {
expand_buffer_float16_func_(expand_buffer_float16_func),
create_beam_scorer_func_(create_beam_scorer_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
Expand Down Expand Up @@ -85,7 +85,7 @@ class BeamSearchWhisper : public BeamSearchBase<T> {
// Device specific functions
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
#endif
Expand Down Expand Up @@ -272,7 +272,7 @@ Status BeamSearchWhisper<T>::Execute(const FeedsFetchesManager& encoder_feeds_fe
auto cross_attention_past_key_sz = first_cross_attention_key->Shape().Size();
beam_state.EnsurePastStateReorderStagingBuffer(this->temp_space_allocator_, cross_attention_past_key_sz);

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Here we only need to reorder the past key for self-attention and cross-attention.
for (size_t i = 0; i < 2 * static_cast<size_t>(decoder_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum DeviceCopyDirection {

namespace GenerationDeviceHelper {

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
using ReorderPastStateFunc = std::function<Status(
const void* cuda_device_prop, // cudaDeviceProp
Tensor& past_state,
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -227,7 +227,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
init_greedy_state_fp16_func_,
device_copy_func_,
update_gpt_feeds_fp16_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, cuda_device_prop_, cuda_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/greedy_search.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GreedySearch : public IControlFlowKernel {
init_greedy_state_fp16_func_ = init_greedy_state_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) {
reorder_past_state_func_ = reorder_past_state_func;
}
Expand All @@ -73,7 +73,7 @@ class GreedySearch : public IControlFlowKernel {
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
const void* cuda_device_prop_ = nullptr;
int cuda_device_arch_ = 0;
#endif
Expand All @@ -90,7 +90,7 @@ class GreedySearch : public IControlFlowKernel {
GenerationDeviceHelper::InitGreedyStateFunc<float> init_greedy_state_func_;
GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16> init_greedy_state_fp16_func_;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class GreedySearchGpt : public GreedySearchBase<T, ParametersT> {
init_greedy_state_func_(init_greedy_state_func),
update_feeds_func_(update_feeds_func) {}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
Status InitializeCuda(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const void* cuda_device_prop,
Expand Down Expand Up @@ -109,7 +109,7 @@ class GreedySearchGpt : public GreedySearchBase<T, ParametersT> {
GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitGreedyStateFunc<T> init_greedy_state_func_;
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
Expand Down Expand Up @@ -336,7 +336,7 @@ Status GreedySearchGpt<T, ParametersT>::Execute(const FeedsFetchesManager* init_
// Increase sequence length after a new token is generated.
++current_length;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedSelfAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_attention_) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/transformers/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const {
init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand All @@ -163,7 +163,7 @@ Status Sampling::Compute(OpKernelContext* ctx) const {
init_greedy_state_fp16_func_,
device_copy_func_,
update_gpt_feeds_fp16_func_};
#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
ORT_RETURN_IF_ERROR(impl.InitializeCuda(reorder_past_state_func_, gpu_device_prop_, gpu_device_arch_));
#endif
ORT_RETURN_IF_ERROR(impl.Initialize());
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/cpu/transformers/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Sampling : public IControlFlowKernel {
init_greedy_state_fp16_func_ = init_greedy_state_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
void SetDeviceHelpers_Cuda(const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func) {
reorder_past_state_func_ = reorder_past_state_func;
}
Expand All @@ -70,7 +70,7 @@ class Sampling : public IControlFlowKernel {
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
const void* gpu_device_prop_ = nullptr;
int gpu_device_arch_ = 0;
#endif
Expand All @@ -87,7 +87,7 @@ class Sampling : public IControlFlowKernel {
GenerationDeviceHelper::InitGreedyStateFunc<float> init_greedy_state_func_;
GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16> init_greedy_state_fp16_func_;

#ifdef USE_CUDA
#if defined(USE_CUDA) || defined(USE_ROCM)
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
#endif

Expand Down
Loading

0 comments on commit b1f3135

Please sign in to comment.