Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport to 2.8: Implement cuda::std::numeric_limits for __half and __nv_bfloat16 (#3361) #3522

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 199 additions & 7 deletions libcudacxx/include/cuda/std/limits
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
#endif // no system header

#include <cuda/std/__bit/bit_cast.h>
#include <cuda/std/__type_traits/is_arithmetic.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/is_integral.h>
#include <cuda/std/climits>
#include <cuda/std/version>

Expand All @@ -46,7 +49,46 @@ enum float_denorm_style
denorm_present = 1
};

template <class _Tp, bool = is_arithmetic<_Tp>::value>
enum class __numeric_limits_type
{
__integral,
__bool,
__floating_point,
__other,
};

template <class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr __numeric_limits_type __make_numeric_limits_type()
{
#if !defined(_CCCL_NO_IF_CONSTEXPR)
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_same, _Tp, bool))
{
return __numeric_limits_type::__bool;
}
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_integral, _Tp))
{
return __numeric_limits_type::__integral;
}
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp))
{
return __numeric_limits_type::__floating_point;
}
else
{
return __numeric_limits_type::__other;
}
#else // ^^^ !_CCCL_NO_IF_CONSTEXPR ^^^ // vvv _CCCL_NO_IF_CONSTEXPR vvv
return _CCCL_TRAIT(is_same, _Tp, bool)
? __numeric_limits_type::__bool
: (_CCCL_TRAIT(is_integral, _Tp)
? __numeric_limits_type::__integral
: (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp)
? __numeric_limits_type::__floating_point
: __numeric_limits_type::__other));
#endif // _CCCL_NO_IF_CONSTEXPR
}

template <class _Tp, __numeric_limits_type = __make_numeric_limits_type<_Tp>()>
class __numeric_limits_impl
{
public:
Expand Down Expand Up @@ -135,7 +177,7 @@ struct __int_min<_Tp, __digits, false>
};

template <class _Tp>
class __numeric_limits_impl<_Tp, true>
class __numeric_limits_impl<_Tp, __numeric_limits_type::__integral>
{
public:
using type = _Tp;
Expand Down Expand Up @@ -212,7 +254,7 @@ public:
};

template <>
class __numeric_limits_impl<bool, true>
class __numeric_limits_impl<bool, __numeric_limits_type::__bool>
{
public:
using type = bool;
Expand Down Expand Up @@ -286,7 +328,7 @@ public:
};

template <>
class __numeric_limits_impl<float, true>
class __numeric_limits_impl<float, __numeric_limits_type::__floating_point>
{
public:
using type = float;
Expand Down Expand Up @@ -381,7 +423,7 @@ public:
};

template <>
class __numeric_limits_impl<double, true>
class __numeric_limits_impl<double, __numeric_limits_type::__floating_point>
{
public:
using type = double;
Expand Down Expand Up @@ -476,7 +518,7 @@ public:
};

template <>
class __numeric_limits_impl<long double, true>
class __numeric_limits_impl<long double, __numeric_limits_type::__floating_point>
{
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE

Expand Down Expand Up @@ -551,6 +593,156 @@ public:
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
};

#if defined(_LIBCUDACXX_HAS_NVFP16)
template <>
class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
{
public:
using type = __half;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
static constexpr int digits = 11;
static constexpr int digits10 = 3;
static constexpr int max_digits10 = 5;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
{
return type(__half_raw{0x0400u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
{
return type(__half_raw{0x7bffu});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
{
return type(__half_raw{0xfbffu});
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
{
return type(__half_raw{0x1400u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
{
return type(__half_raw{0x3800u});
}

static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;

static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
{
return type(__half_raw{0x7c00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
{
return type(__half_raw{0x7e00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
{
return type(__half_raw{0x7d00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
{
return type(__half_raw{0x0001u});
}

static constexpr bool is_iec559 = true;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
template <>
class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
{
public:
using type = __nv_bfloat16;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
{
return type(__nv_bfloat16_raw{0x0080u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
{
return type(__nv_bfloat16_raw{0x7f7fu});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
{
return type(__nv_bfloat16_raw{0xff7fu});
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
{
return type(__nv_bfloat16_raw{0x3c00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
{
return type(__nv_bfloat16_raw{0x3f00u});
}

static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;

static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
{
return type(__nv_bfloat16_raw{0x7f80u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
{
return type(__nv_bfloat16_raw{0x7fc0u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
{
return type(__nv_bfloat16_raw{0x7fa0u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
{
return type(__nv_bfloat16_raw{0x0001u});
}

static constexpr bool is_iec559 = true;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVBF16

template <class _Tp>
class numeric_limits : public __numeric_limits_impl<_Tp>
{};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef _MY_INT_HPP
#define _MY_INT_HPP

#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include "test_macros.h"

struct my_int_non_convertible;
Expand All @@ -22,6 +25,10 @@ template <>
struct cuda::std::is_integral<my_int> : cuda::std::true_type
{};

template <>
class cuda::std::numeric_limits<my_int> : public cuda::std::numeric_limits<int>
{};

// Wrapper type that's not implicitly convertible

struct my_int_non_convertible
Expand All @@ -43,6 +50,10 @@ template <>
struct cuda::std::is_integral<my_int_non_convertible> : cuda::std::true_type
{};

template <>
class cuda::std::numeric_limits<my_int_non_convertible> : public cuda::std::numeric_limits<int>
{};

// Wrapper type that's not nothrow-constructible

struct my_int_non_nothrow_constructible
Expand All @@ -62,4 +73,8 @@ template <>
struct cuda::std::is_integral<my_int_non_nothrow_constructible> : cuda::std::true_type
{};

template <>
class cuda::std::numeric_limits<my_int_non_nothrow_constructible> : public cuda::std::numeric_limits<int>
{};

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16

static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
"!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef NUMERIC_LIMITS_MEMBERS_COMMON_H
#define NUMERIC_LIMITS_MEMBERS_COMMON_H

// Disable all the extended floating point operations and conversions
#define __CUDA_NO_HALF_CONVERSIONS__ 1
#define __CUDA_NO_HALF_OPERATORS__ 1
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1

#include <cuda/std/limits>

template <class T>
__host__ __device__ bool float_eq(T x, T y)
{
return x == y;
}

#if defined(_LIBCUDACXX_HAS_NVFP16)
__host__ __device__ inline bool float_eq(__half x, __half y)
{
return __heq(x, y);
}
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
{
return __heq(x, y);
}
#endif // _LIBCUDACXX_HAS_NVBF16

#endif // NUMERIC_LIMITS_MEMBERS_COMMON_H
Loading
Loading