Skip to content

Commit

Permalink
Implement cuda::std::numeric_limits for __half and __nv_bfloat16 (
Browse files Browse the repository at this point in the history
#3361)

* implement `cuda::std::numeric_limits` for `__half` and `__nv_bfloat16`
  • Loading branch information
davebayer authored and bernhardmgruber committed Jan 24, 2025
1 parent 6d735b6 commit 83fd7d0
Show file tree
Hide file tree
Showing 36 changed files with 563 additions and 201 deletions.
206 changes: 199 additions & 7 deletions libcudacxx/include/cuda/std/limits
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
#endif // no system header

#include <cuda/std/__bit/bit_cast.h>
#include <cuda/std/__type_traits/is_arithmetic.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/__type_traits/is_integral.h>
#include <cuda/std/climits>
#include <cuda/std/version>

Expand All @@ -46,7 +49,46 @@ enum float_denorm_style
denorm_present = 1
};

template <class _Tp, bool = is_arithmetic<_Tp>::value>
enum class __numeric_limits_type
{
__integral,
__bool,
__floating_point,
__other,
};

template <class _Tp>
_LIBCUDACXX_HIDE_FROM_ABI constexpr __numeric_limits_type __make_numeric_limits_type()
{
#if !defined(_CCCL_NO_IF_CONSTEXPR)
_CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_same, _Tp, bool))
{
return __numeric_limits_type::__bool;
}
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_integral, _Tp))
{
return __numeric_limits_type::__integral;
}
else _CCCL_IF_CONSTEXPR (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp))
{
return __numeric_limits_type::__floating_point;
}
else
{
return __numeric_limits_type::__other;
}
#else // ^^^ !_CCCL_NO_IF_CONSTEXPR ^^^ // vvv _CCCL_NO_IF_CONSTEXPR vvv
return _CCCL_TRAIT(is_same, _Tp, bool)
? __numeric_limits_type::__bool
: (_CCCL_TRAIT(is_integral, _Tp)
? __numeric_limits_type::__integral
: (_CCCL_TRAIT(is_floating_point, _Tp) || _CCCL_TRAIT(__is_extended_floating_point, _Tp)
? __numeric_limits_type::__floating_point
: __numeric_limits_type::__other));
#endif // _CCCL_NO_IF_CONSTEXPR
}

template <class _Tp, __numeric_limits_type = __make_numeric_limits_type<_Tp>()>
class __numeric_limits_impl
{
public:
Expand Down Expand Up @@ -135,7 +177,7 @@ struct __int_min<_Tp, __digits, false>
};

template <class _Tp>
class __numeric_limits_impl<_Tp, true>
class __numeric_limits_impl<_Tp, __numeric_limits_type::__integral>
{
public:
using type = _Tp;
Expand Down Expand Up @@ -212,7 +254,7 @@ public:
};

template <>
class __numeric_limits_impl<bool, true>
class __numeric_limits_impl<bool, __numeric_limits_type::__bool>
{
public:
using type = bool;
Expand Down Expand Up @@ -286,7 +328,7 @@ public:
};

template <>
class __numeric_limits_impl<float, true>
class __numeric_limits_impl<float, __numeric_limits_type::__floating_point>
{
public:
using type = float;
Expand Down Expand Up @@ -381,7 +423,7 @@ public:
};

template <>
class __numeric_limits_impl<double, true>
class __numeric_limits_impl<double, __numeric_limits_type::__floating_point>
{
public:
using type = double;
Expand Down Expand Up @@ -476,7 +518,7 @@ public:
};

template <>
class __numeric_limits_impl<long double, true>
class __numeric_limits_impl<long double, __numeric_limits_type::__floating_point>
{
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE

Expand Down Expand Up @@ -551,6 +593,156 @@ public:
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE
};

#if defined(_LIBCUDACXX_HAS_NVFP16)
template <>
class __numeric_limits_impl<__half, __numeric_limits_type::__floating_point>
{
public:
using type = __half;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
static constexpr int digits = 11;
static constexpr int digits10 = 3;
static constexpr int max_digits10 = 5;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
{
return type(__half_raw{0x0400u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
{
return type(__half_raw{0x7bffu});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
{
return type(__half_raw{0xfbffu});
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
{
return type(__half_raw{0x1400u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
{
return type(__half_raw{0x3800u});
}

static constexpr int min_exponent = -13;
static constexpr int min_exponent10 = -4;
static constexpr int max_exponent = 16;
static constexpr int max_exponent10 = 4;

static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
{
return type(__half_raw{0x7c00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
{
return type(__half_raw{0x7e00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
{
return type(__half_raw{0x7d00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
{
return type(__half_raw{0x0001u});
}

static constexpr bool is_iec559 = true;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
template <>
class __numeric_limits_impl<__nv_bfloat16, __numeric_limits_type::__floating_point>
{
public:
using type = __nv_bfloat16;

static constexpr bool is_specialized = true;

static constexpr bool is_signed = true;
static constexpr int digits = 8;
static constexpr int digits10 = 2;
static constexpr int max_digits10 = 4;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type min() noexcept
{
return type(__nv_bfloat16_raw{0x0080u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type max() noexcept
{
return type(__nv_bfloat16_raw{0x7f7fu});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type lowest() noexcept
{
return type(__nv_bfloat16_raw{0xff7fu});
}

static constexpr bool is_integer = false;
static constexpr bool is_exact = false;
static constexpr int radix = __FLT_RADIX__;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type epsilon() noexcept
{
return type(__nv_bfloat16_raw{0x3c00u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type round_error() noexcept
{
return type(__nv_bfloat16_raw{0x3f00u});
}

static constexpr int min_exponent = -125;
static constexpr int min_exponent10 = -37;
static constexpr int max_exponent = 128;
static constexpr int max_exponent10 = 38;

static constexpr bool has_infinity = true;
static constexpr bool has_quiet_NaN = true;
static constexpr bool has_signaling_NaN = true;
static constexpr float_denorm_style has_denorm = denorm_present;
static constexpr bool has_denorm_loss = false;
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type infinity() noexcept
{
return type(__nv_bfloat16_raw{0x7f80u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type quiet_NaN() noexcept
{
return type(__nv_bfloat16_raw{0x7fc0u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type signaling_NaN() noexcept
{
return type(__nv_bfloat16_raw{0x7fa0u});
}
_LIBCUDACXX_HIDE_FROM_ABI static constexpr type denorm_min() noexcept
{
return type(__nv_bfloat16_raw{0x0001u});
}

static constexpr bool is_iec559 = true;
static constexpr bool is_bounded = true;
static constexpr bool is_modulo = false;

static constexpr bool traps = false;
static constexpr bool tinyness_before = false;
static constexpr float_round_style round_style = round_to_nearest;
};
#endif // _LIBCUDACXX_HAS_NVBF16

template <class _Tp>
class numeric_limits : public __numeric_limits_impl<_Tp>
{};
Expand Down
15 changes: 15 additions & 0 deletions libcudacxx/test/libcudacxx/std/containers/views/mdspan/my_int.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef _MY_INT_HPP
#define _MY_INT_HPP

#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include "test_macros.h"

struct my_int_non_convertible;
Expand All @@ -22,6 +25,10 @@ template <>
struct cuda::std::is_integral<my_int> : cuda::std::true_type
{};

template <>
class cuda::std::numeric_limits<my_int> : public cuda::std::numeric_limits<int>
{};

// Wrapper type that's not implicitly convertible

struct my_int_non_convertible
Expand All @@ -43,6 +50,10 @@ template <>
struct cuda::std::is_integral<my_int_non_convertible> : cuda::std::true_type
{};

template <>
class cuda::std::numeric_limits<my_int_non_convertible> : public cuda::std::numeric_limits<int>
{};

// Wrapper type that's not nothrow-constructible

struct my_int_non_nothrow_constructible
Expand All @@ -62,4 +73,8 @@ template <>
struct cuda::std::is_integral<my_int_non_nothrow_constructible> : cuda::std::true_type
{};

template <>
class cuda::std::numeric_limits<my_int_non_nothrow_constructible> : public cuda::std::numeric_limits<int>
{};

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,13 @@ int main(int, char**)
#ifndef _LIBCUDACXX_HAS_NO_LONG_DOUBLE
test<long double>();
#endif
#if defined(_LIBCUDACXX_HAS_NVFP16)
test<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#if defined(_LIBCUDACXX_HAS_NVBF16)
test<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16

static_assert(!cuda::std::numeric_limits<cuda::std::complex<double>>::is_specialized,
"!cuda::std::numeric_limits<cuda::std::complex<double> >::is_specialized");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef NUMERIC_LIMITS_MEMBERS_COMMON_H
#define NUMERIC_LIMITS_MEMBERS_COMMON_H

// Disable all the extended floating point operations and conversions
#define __CUDA_NO_HALF_CONVERSIONS__ 1
#define __CUDA_NO_HALF_OPERATORS__ 1
#define __CUDA_NO_BFLOAT16_CONVERSIONS__ 1
#define __CUDA_NO_BFLOAT16_OPERATORS__ 1

#include <cuda/std/limits>

template <class T>
__host__ __device__ bool float_eq(T x, T y)
{
return x == y;
}

#if defined(_LIBCUDACXX_HAS_NVFP16)
__host__ __device__ inline bool float_eq(__half x, __half y)
{
return __heq(x, y);
}
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
__host__ __device__ inline bool float_eq(__nv_bfloat16 x, __nv_bfloat16 y)
{
return __heq(x, y);
}
#endif // _LIBCUDACXX_HAS_NVBF16

#endif // NUMERIC_LIMITS_MEMBERS_COMMON_H
Loading

0 comments on commit 83fd7d0

Please sign in to comment.