Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Fix NumericLimits #22738

Merged
merged 2 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct TopK {
__device__ __forceinline__ void Init() {
for (int i = 0; i < max_k; i++) {
key[i] = -1;
value[i] = NumericLimits<T>::Min();
value[i] = NumericLimits<T>::Lowest();
}
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include <cub/cub.cuh>


#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cuda/cu_inc/common.cuh"

Expand All @@ -19,7 +20,10 @@ struct TopOne {
int32_t key;
T value;

__device__ __host__ __forceinline__ TopOne(int32_t key = -1, T value = NumericLimits<T>::Min()) : key(key), value(value) {
__device__ __host__ __forceinline__ TopOne() : key(-1), value(NumericLimits<T>::Lowest()) {
}

__device__ __host__ __forceinline__ TopOne(int32_t key, T value) : key(key), value(value) {
}

__device__ __forceinline__ void Reduce(int32_t k, T v) {
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/cuda/math/topk_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute,
if (aligned_dimension <= GridDim::maxThreadsPerBlock) {
BitonicTopK<CudaT><<<N, GridDim::maxThreadsPerBlock, aligned_dimension * sizeof(KV<CudaT>), stream>>>(
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, aligned_K, largest, sorted, dimension,
aligned_dimension, NumericLimits<T>::Min(), NumericLimits<T>::Max());
aligned_dimension, NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
} else if (K <= BT * 16 || 0 == sorted) {
if (use_deterministic_compute) {
static std::once_flag log_warning;
Expand All @@ -425,19 +425,19 @@ Status TopKImpl(const CudaKernel* kernel, bool use_deterministic_compute,
if (BT * 2 >= K || 0 == sorted) {
RadixTopK<CudaT, BT, 2><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
NumericLimits<T>::Min(), NumericLimits<T>::Max());
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
} else if (BT * 4 >= K) {
RadixTopK<CudaT, BT, 4><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
NumericLimits<T>::Min(), NumericLimits<T>::Max());
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
} else if (BT * 8 >= K) {
RadixTopK<CudaT, BT, 8><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
NumericLimits<T>::Min(), NumericLimits<T>::Max());
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
} else {
RadixTopK<CudaT, BT, 16><<<N, BT, 256 * sizeof(uint32_t), stream>>>(
input_x_ptr, output_v_ptr, output_i, elem_nums, size, axis, K, largest, sorted, dimension, XPT,
NumericLimits<T>::Min(), NumericLimits<T>::Max());
NumericLimits<CudaT>::Lowest(), NumericLimits<CudaT>::Max());
}
} else {
auto input_key_buffer = kernel->GetScratchBuffer<CudaT>(dimension, ort_stream);
Expand Down
44 changes: 10 additions & 34 deletions onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <memory>
#include <type_traits>
#include <vector>
#include <limits>

#include <gsl/gsl>
#include "core/framework/float16.h"
Expand Down Expand Up @@ -120,51 +121,26 @@ constexpr int kNumBitsPerBitmaskElement = std::numeric_limits<BitmaskElementType

template <typename T>
struct NumericLimits {
__inline__ __host__ __device__ static T Min() {
__inline__ __host__ __device__ static T Lowest() {
return std::numeric_limits<T>::lowest();
}
__inline__ __host__ __device__ static T Max() {
return std::numeric_limits<T>::max();
}
};

template <>
struct NumericLimits<MLFloat16> {
__inline__ __host__ __device__ static half Min() {
return -65504.0;
}
__inline__ __host__ __device__ static half Max() {
return 65504.0;
}
};

template <>
struct NumericLimits<half> {
__inline__ __host__ __device__ static half Min() {
return -65504.0;
}
__inline__ __host__ __device__ static half Max() {
return 65504.0;
__inline__ __host__ __device__ static half Lowest() {
return -65504.0f;
}
};

template <>
struct NumericLimits<float> {
__inline__ __host__ __device__ static float Min() {
return -INFINITY;
}
__inline__ __host__ __device__ static float Max() {
return INFINITY;
}
};

template <>
struct NumericLimits<double> {
__inline__ __host__ __device__ static double Min() {
return -HUGE_VAL;
}
__inline__ __host__ __device__ static double Max() {
return HUGE_VAL;
__inline__ __host__ __device__ static half Max() {
#ifdef CUDART_MAX_NORMAL_FP16 // defined in cuda 12.3 or later
return CUDART_MAX_NORMAL_FP16;
#else
return 65504.0f;
#endif
}
};

Expand Down
Loading