diff --git a/src/api/include/migraphx/migraphx.h b/src/api/include/migraphx/migraphx.h index cde517f0b28..6a1edd9b6b0 100644 --- a/src/api/include/migraphx/migraphx.h +++ b/src/api/include/migraphx/migraphx.h @@ -44,7 +44,8 @@ m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ - m(uint64_type, uint64_t) + m(uint64_type, uint64_t) \ + m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) // clang-format on #ifdef __cplusplus diff --git a/src/include/migraphx/bit_cast.hpp b/src/include/migraphx/bit_cast.hpp new file mode 100644 index 00000000000..b5fb6d472f6 --- /dev/null +++ b/src/include/migraphx/bit_cast.hpp @@ -0,0 +1,50 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ +#ifndef MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP +#define MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif +#include + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define MIGRAPHX_CONST_FOLD(x) (__builtin_constant_p(x) ? (x) : (x)) + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +template +inline constexpr To bit_cast(From fr) noexcept +{ + static_assert(sizeof(To) == sizeof(From)); +#if defined(__GNUC__) and !defined(__clang__) + return MIGRAPHX_CONST_FOLD(*reinterpret_cast(&fr)); +#else + return __builtin_bit_cast(To, fr); +#endif +} +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif +#endif // MIGRAPHX_GUARD_RTGLIB_BITCAST_HPP diff --git a/src/include/migraphx/float8.hpp b/src/include/migraphx/float8.hpp new file mode 100644 index 00000000000..b2d6fedc68c --- /dev/null +++ b/src/include/migraphx/float8.hpp @@ -0,0 +1,409 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP +#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP + +// We are clipping/saturation in down conversion by default. Unclipped version is not tested and +// shouldn't be used without having enough tests. +// logic is based on clipping table from here : https://onnx.ai/onnx/technical/float8.html#cast +// NOLINTNEXTLINE +#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace fp8 { + +enum class rounding_mode +{ + standard, // standard rounding is doing RNE -- round to nearest even + stochastic +}; + +enum class f8_type +{ + bf8 = 0, // s1e5m2 + fp8 = 1 // s1e4m3 +}; + +template +class numeric_limits; + +template +struct float8 +{ + uint8_t data = 0x00; + // default constructor + constexpr float8() = default; + // default copy constructor + constexpr float8(const float8& y) = default; + struct from_bits_t + { + }; + static constexpr from_bits_t from_bits() { return from_bits_t(); } + + explicit constexpr float8(uint8_t bits, from_bits_t) : data(bits) {} + + explicit constexpr float8( + float v, + migraphx::fp8::rounding_mode rm = migraphx::fp8::rounding_mode::standard, + uint32_t rng = 0) + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<3, 4, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING + } + else + { +#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, true /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#else // MIGRAPHX_F8_DOWNCAST_CLIPPING + data = migraphx::fp8::impl:: + cast_to_f8<2, 5, float, FNUZ /*negative_zero_nan*/, false /*clip*/>( + v, (rm == migraphx::fp8::rounding_mode::stochastic), rng); +#endif // rocblas_F8_downcast_clipping} + } + } + + inline constexpr operator float() const + { + if constexpr(T == migraphx::fp8::f8_type::fp8) + { + return migraphx::fp8::impl::cast_from_f8<3, 4, float, FNUZ /*negative_zero_nan*/>(data); + } // else + return migraphx::fp8::impl::cast_from_f8<2, 5, float, FNUZ /*negative_zero_nan*/>(data); + } + + inline constexpr bool is_zero() const + { + if constexpr(FNUZ) + { + return data == 0x00; + } + else + { + return (data == 0x00) or (data == 0x80); + } + } + + inline constexpr bool is_nan() const + { + if constexpr(FNUZ) + { + return data == 0x80; + } + else + { + if(T == migraphx::fp8::f8_type::bf8) + { + return (data == 0x7D) or (data == 0x7E) or (data == 0x7F) or (data == 0xFD) or + (data == 0xFE) or (data == 0xFF); + } + else + { + return (data == 0x7F) or (data == 0xFF); + } + } + } + + inline constexpr bool is_inf() const + { + if constexpr(FNUZ) + { + return data == 0x80; + } + else + { + if(T == migraphx::fp8::f8_type::bf8) + { + return (data == 0x7C) or (data == 0xFC); + } + else + { + // no infinities in e4m3fn, represent them as NaNs + return (data == 0x7F) or (data == 0xFF); + } + } + } + +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_UNARY_OP(unary_op, binary_op) \ + constexpr float8& operator unary_op(const float8& rhs) \ + { \ + const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ + *this = static_cast(tmp); \ + return *this; \ + } \ + constexpr float8& operator unary_op(const float& rhs) \ + { \ + const auto tmp = static_cast(*this) binary_op static_cast(rhs); \ + *this = static_cast(tmp); \ + return *this; \ + } + + MIGRAPHX_FP8_UNARY_OP(*=, *) + MIGRAPHX_FP8_UNARY_OP(-=, -) + MIGRAPHX_FP8_UNARY_OP(+=, +) + MIGRAPHX_FP8_UNARY_OP(/=, /) + + inline constexpr float8& operator=(const float8& rhs) = default; + inline constexpr float8& operator=(float8&& rhs) noexcept = default; + + inline constexpr float8& operator=(float rhs) + { + *this = static_cast(rhs); + return *this; + } + + inline constexpr bool operator==(const float8& rhs) const + { + if(rhs.is_nan() or rhs.is_inf() or this->is_nan() or this->is_inf()) + return false; + else if((rhs.is_zero() and this->is_zero()) or (this->data == rhs.data)) + return true; + return false; + } + + inline constexpr bool operator<(const float8& rhs) const + { + const auto we = static_cast(*this); + const auto them = static_cast(rhs); + return we < them; + } + + inline constexpr bool operator>(const float8& rhs) const + { + const auto we = static_cast(*this); + const auto them = static_cast(rhs); + return we > them; + } +}; + +// https://onnx.ai/onnx/technical/float8.html +using fp8e4m3fn = float8; +using fp8e5m2 = float8; +using fp8e4m3fnuz = float8; +using fp8e5m2fnuz = float8; +/* +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \ + inline constexpr U operator binary_op(const T& lhs, const T& rhs) \ + { \ + return U(static_cast(lhs) binary_op static_cast(rhs)); \ + } + +// TODO: these should return floats for binary ops +// NOLINTNEXTLINE +#define MIGRAPHX_FP8_BINARY_OP_GEN_FOR(T) \ + MIGRAPHX_FP8_BINARY_OP(*, T, T) \ + MIGRAPHX_FP8_BINARY_OP(-, T, T) \ + MIGRAPHX_FP8_BINARY_OP(/, T, T) \ + MIGRAPHX_FP8_BINARY_OP(+, T, T) \ + MIGRAPHX_FP8_BINARY_OP(==, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(>, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(<, T, bool) \ + MIGRAPHX_FP8_BINARY_OP(!=, T, bool) + +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2) +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fn) +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e5m2fnuz) +MIGRAPHX_FP8_BINARY_OP_GEN_FOR(fp8e4m3fnuz) +*/ + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fnuz& rhs) +{ + return os << static_cast(rhs); +} + +inline fp8e4m3fnuz fabs(fp8e4m3fnuz v) +{ + v.data = v.data & 0x7F; // NOLINT + return v; +} + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e4m3fn& rhs) +{ + return os << static_cast(rhs); +} + +inline fp8e4m3fn fabs(fp8e4m3fn v) +{ + v.data = v.data & 0x7F; // NOLINT + return v; +} + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e5m2fnuz& rhs) +{ + return os << static_cast(rhs); +} + +inline fp8e5m2fnuz fabs(fp8e5m2fnuz v) +{ + v.data = v.data & 0x7F; // NOLINT + return v; +} +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const fp8e5m2& rhs) +{ + return os << static_cast(rhs); +} + +inline fp8e5m2 fabs(fp8e5m2 v) +{ + v.data = v.data & 0x7F; // NOLINT + return v; +} +template <> +class numeric_limits +{ + public: + static constexpr bool has_infinity = false; + static constexpr fp8e4m3fnuz epsilon() { return fp8e4m3fnuz(0x28, fp8e4m3fnuz::from_bits()); } + // NOLINTNEXTLINE + static constexpr fp8e4m3fnuz quiet_NaN() { return fp8e4m3fnuz(0x80, fp8e4m3fnuz::from_bits()); } + + static constexpr fp8e4m3fnuz max() { return fp8e4m3fnuz(0x7F, fp8e4m3fnuz::from_bits()); } + // this is min value that is not DeNorm. DeNorm min is 0x01 + static constexpr fp8e4m3fnuz min() { return fp8e4m3fnuz(0x08, fp8e4m3fnuz::from_bits()); } + + static constexpr fp8e4m3fnuz lowest() { return fp8e4m3fnuz(0xFF, fp8e4m3fnuz::from_bits()); } +}; + +template <> +class numeric_limits +{ + public: + static constexpr bool has_infinity = false; + static constexpr fp8e4m3fn epsilon() { return fp8e4m3fn(0x20, fp8e4m3fn::from_bits()); } + // NOLINTNEXTLINE + static constexpr fp8e4m3fn quiet_NaN() { return fp8e4m3fn(0x7F, fp8e4m3fn::from_bits()); } + + static constexpr fp8e4m3fn max() { return fp8e4m3fn(0x7E, fp8e4m3fn::from_bits()); } + // this is min value that is not DeNorm. DeNorm min is 0x01 + static constexpr fp8e4m3fn min() { return fp8e4m3fn(0x08, fp8e4m3fn::from_bits()); } + + static constexpr fp8e4m3fn lowest() { return fp8e4m3fn(0xFE, fp8e4m3fn::from_bits()); } +}; + +template <> +class numeric_limits +{ + public: + static constexpr bool has_infinity = false; + static constexpr fp8e5m2fnuz epsilon() { return fp8e5m2fnuz(0x34, fp8e5m2fnuz::from_bits()); } + + static constexpr fp8e5m2fnuz quiet_NaN() // NOLINT + { + return fp8e5m2fnuz(0x80, fp8e5m2fnuz::from_bits()); + } + + static constexpr fp8e5m2fnuz max() { return fp8e5m2fnuz(0x7F, fp8e5m2fnuz::from_bits()); } + // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make + // this distinction. For the floating points we would end up using lowest most of the times. + static constexpr fp8e5m2fnuz min() { return fp8e5m2fnuz(0x4, fp8e5m2fnuz::from_bits()); } + + static constexpr fp8e5m2fnuz lowest() { return fp8e5m2fnuz(0xFF, fp8e5m2fnuz::from_bits()); } +}; + +template <> +class numeric_limits +{ + public: + static constexpr bool has_infinity = true; + static constexpr fp8e5m2 epsilon() { return fp8e5m2(0x34, fp8e5m2::from_bits()); } + // 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs + static constexpr fp8e5m2 quiet_NaN() { return fp8e5m2(0xFF, fp8e5m2::from_bits()); } // NOLINT + + static constexpr fp8e5m2 max() { return fp8e5m2(0x7B, fp8e5m2::from_bits()); } + // this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make + // this distinction. For the floating points we would end up using lowest most of the times. + static constexpr fp8e5m2 min() { return fp8e5m2(0x4, fp8e5m2::from_bits()); } + + static constexpr fp8e5m2 lowest() { return fp8e5m2(0xFB, fp8e5m2::from_bits()); } + // 7C and FC both are infinity + static constexpr fp8e5m2 infinity() { return fp8e5m2(0x7C, fp8e5m2::from_bits()); } +}; +} // namespace fp8 +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +// ================================================================================================= +// define numeric limits for the new data type +// NOLINTBEGIN +namespace std { +#define MIGRAPHX_FP8_STD_OVERLOADS(T) \ + inline bool isfinite(T x) { return not x.is_inf() and not x.is_nan(); } \ + inline bool isnan(T x) { return x.is_nan(); } \ + template <> \ + class numeric_limits : public migraphx::fp8::numeric_limits \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template \ + struct common_type : std::common_type \ + { \ + }; \ + template <> \ + struct common_type \ + { \ + using type = T; \ + }; + +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fn) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_FP8_STD_OVERLOADS(migraphx::fp8::fp8e5m2fnuz) +} // namespace std +// NOLINTEND +// ================================================================================================= +#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_HPP diff --git a/src/include/migraphx/float8_impl.hpp b/src/include/migraphx/float8_impl.hpp new file mode 100644 index 00000000000..e6423eea83a --- /dev/null +++ b/src/include/migraphx/float8_impl.hpp @@ -0,0 +1,328 @@ +/* ************************************************************************ + * Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop- + * ies of the Software, and to permit persons to whom the Software is furnished + * to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM- + * PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER + * IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE- + * CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + * + * ************************************************************************ */ + +#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP +#define MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL_HPP +#include +#include +#include +#include +#include +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace fp8 { +namespace impl { + +// NOLINTBEGIN +template +constexpr uint8_t cast_to_f8(T f_x, bool stoch = false, uint32_t rng = 0) +{ + constexpr bool is_float = std::is_same::value; + // half is not supported for now + constexpr bool is_half = false; + static_assert(Wm + We == 7, "Wm+We==7"); + static_assert(is_float or is_half, "Only float can be cast to f8"); + + const uint32_t mfmt = (sizeof(T) == 4) ? 23 : 10; + typename std::conditional::type x; + + if constexpr(sizeof(T) == 4) + x = migraphx::bit_cast(f_x); + else + x = migraphx::bit_cast(f_x); + + uint32_t head = 0; + uint32_t mantissa = 0; + int exponent = 0; + uint32_t bias = 0; + uint32_t sign = 0; + if constexpr(sizeof(T) == 4) + { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } + else + { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << We) - 1) << Wm); + uint32_t signed_all_ones = (sign << 7) + ((((1 << We) - 1) << Wm) + ((1 << Wm) - 1)); + + // Calcualte maximum singed value FLT_MAX, FLT_MIN + uint32_t signed_max = signed_all_ones; + if(not NegativeZeroNan) + signed_max = (Wm == 2) ? (signed_max - 4) : (signed_max - 1); + + // Deal with inf and NaNs + if(NegativeZeroNan) // For the FNUZ cases, it is simple just return NaNs + { + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) + return 0x80; + } + else + { + // calculate most common NaN mantissa for FP8, which is all Ones in binary + uint32_t nan_mantissa = 1; + for(auto i = 1; i < Wm; ++i) + { + nan_mantissa |= (nan_mantissa << 1); + } + if((sizeof(T) == 4 and ((x & 0x7F800000) == 0x7F800000)) or + (sizeof(T) == 2 and ((x & 0x7C00) == 0x7C00))) + { + // infinity + if(mantissa == 0) + { + if(sign == 0) + return (Wm == 2) ? 0x7B : 0x7E; + else + return (Wm == 2) ? 0xFB : 0xFE; + } + else // NaNs + return signed_inf + nan_mantissa; + } + } + // handle positive zero + if(x == 0) + return 0; + // handle negative zero + else if((sizeof(T) == 4 and x == 0x80000000) or (sizeof(T) == 2 and x == 0x8000)) + { + return NegativeZeroNan ? 0 : 0x80; // For FNUZ types neg zero is just positive zero + } + + /* First need to check if it is normal or denorm as there is a difference of implict 1 + Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + RNE, no need to add rng. Then probably need to check whether there is carry and adjust + exponent and mantissa again*/ + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits + const int f8_bias = (1 << (We - 1u)) - 1 + (NegativeZeroNan ? 1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + /* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + f8_exponent is the converted f8 exponent with bias encoding + exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + the difference needs to be adjusted and mantissa shifted*/ + int act_exponent = 0; + int f8_exponent = 0; + int exponent_diff = 0; + + if(exponent == 0 and mantissa != 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 + here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal + has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some + numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is + 2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 + are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = 1 - bias; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1u << mfmt); // Add the implicit 1 into mantissa + } + + // need to know whether the number is right in the middle of two adjacent fp8 numbers. use max + // value of 31 to avoid undefined behaviour + bool midpoint = (mantissa & ((1u << std::min(31u, mfmt - Wm + exponent_diff)) - 1)) == + (1u << std::min(31u, mfmt - Wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. + */ + + if(exponent_diff > 0) + mantissa >>= std::min(31u, uint32_t(exponent_diff)); + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1u << (mfmt - Wm)) - 1; + bool odd = + mantissa & (1u << (mfmt - Wm)); // if the least significant bit that is not truncated is 1 + /* + This part is doing rounding by adding mantissa part that is going to get dropped. + e.g. if the dropped part for less than 0.5 than it would round down. + if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained + mantissa. + For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and + `xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part. + For the odd case : + this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained + part making it RNE. + For the even case : this will add xy0:10000000 + 000:01111111 which would + round down and keep number Even + */ + mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0 and ((1 << mfmt) & mantissa)) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + else if((1 << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + + mantissa >>= (mfmt - Wm); + + // above range: quantize to maximum possible float of the same sign + // for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans + const int max_exp = (1 << We) - ((NegativeZeroNan or Wm == 3) ? 1 : 2); + if(f8_exponent > max_exp) + { + if(Clip) + return signed_max; + else + { + // https://onnx.ai/onnx/technical/float8.html#cast + if(NegativeZeroNan) + return 0x80; + else + return (Wm == 2) ? signed_inf : signed_all_ones; + } + } + + if(f8_exponent == 0 and mantissa == 0) + return NegativeZeroNan ? 0 : (sign << 7); + mantissa &= (1 << Wm) - 1; + return (sign << 7) | (f8_exponent << Wm) | mantissa; +} +// NOLINTEND + +template +constexpr T cast_from_f8(uint8_t x) +{ + // half is not supported for now + constexpr bool is_half = false; + constexpr bool is_float = std::is_same::value; + static_assert(is_float or is_half, "Only float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + // NOLINTNEXTLINE + T f_inf, f_neg_inf, f_nan, f_neg0; + + if constexpr(is_float) + { + const uint32_t if_inf = 0x7F800000; + const uint32_t if_neg_inf = 0xFF800000; + const uint32_t if_nan = 0x7F800001; + const uint32_t if_neg0 = 0x80000000; + f_inf = migraphx::bit_cast(if_inf); + f_neg_inf = migraphx::bit_cast(if_neg_inf); + f_nan = migraphx::bit_cast(if_nan); + f_neg0 = migraphx::bit_cast(if_neg0); + } + + if(x == 0) + return 0; + + uint32_t sign = x >> 7; // NOLINT + uint32_t mantissa = x & ((1 << Wm) - 1); // NOLINT + int exponent = (x & 0x7F) >> Wm; // NOLINT + if(NegativeZeroNan) + { + if(x == 0x80) + return f_nan; + } + else + { + if(x == 0x80) + return f_neg0; + if(exponent == ((1 << We) - 1) and Wm == 2) // NOLINT + return (mantissa == 0) ? (sign ? f_neg_inf : f_inf) : f_nan; + else if(Wm == 3 and (x == 0x7F or x == 0xFF)) + return f_nan; + } + typename std::conditional::type retval; + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (We - 1)) + 1 - (NegativeZeroNan ? 1 : 0); // NOLINT + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __builtin_clz(mantissa) - (32 - Wm); + mantissa <<= sh; // NOLINT + exponent += 1 - sh; + mantissa &= ((1 << Wm) - 1); // NOLINT + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - Wm; // NOLINT + + // subnormal output (occurs when T=half, We=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << wmo; // NOLINT + mantissa >>= 1 - exponent; // NOLINT + exponent = 0; + } + + if(sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; // NOLINT + else + retval = (sign << 31) | (exponent << 23) | mantissa; // NOLINT + return migraphx::bit_cast(retval); +} + +} // namespace impl +} // namespace fp8 +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_RTGLIB_FLOAT8_IMPL diff --git a/src/include/migraphx/half.hpp b/src/include/migraphx/half.hpp index 10cc7e4289c..0f6516d9bda 100644 --- a/src/include/migraphx/half.hpp +++ b/src/include/migraphx/half.hpp @@ -27,6 +27,7 @@ #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -67,6 +68,18 @@ struct common_type : std::common_type // NOLINT { }; +template <> +struct common_type +{ + using type = float; +}; + +template <> +struct common_type +{ + using type = float; +}; + template <> struct common_type { diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 3cf5785087c..d596398ca78 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -60,7 +61,8 @@ struct MIGRAPHX_EXPORT shape m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ - m(uint64_type, uint64_t) + m(uint64_type, uint64_t) \ + m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) // clang-format on #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x, diff --git a/src/include/migraphx/type_traits.hpp b/src/include/migraphx/type_traits.hpp index 1512c38f203..44b5e0573cc 100644 --- a/src/include/migraphx/type_traits.hpp +++ b/src/include/migraphx/type_traits.hpp @@ -28,25 +28,35 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +#define MIGRAPHX_DETAIL_DEFINE_TRAIT(trait) \ + template \ + struct trait : std::trait \ + { \ + }; + #define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ - template \ - struct trait : std::trait \ - { \ - }; \ - \ template <> \ struct trait : std::true_type \ { \ }; +MIGRAPHX_DETAIL_DEFINE_TRAIT(is_floating_point); +MIGRAPHX_DETAIL_DEFINE_TRAIT(is_arithmetic); +MIGRAPHX_DETAIL_DEFINE_TRAIT(is_signed); + MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, half) MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, half) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e4m3fnuz) +MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e4m3fnuz) + template using accumulator_type = std::conditional_t{}, diff --git a/src/py/migraphx_py.cpp b/src/py/migraphx_py.cpp index 4b6de6c19d0..3b95959f98d 100644 --- a/src/py/migraphx_py.cpp +++ b/src/py/migraphx_py.cpp @@ -40,7 +40,7 @@ #include #include #include - +#include #ifdef HAVE_GPU #include #endif @@ -144,6 +144,18 @@ struct npy_format_descriptor static constexpr auto name() { return _("half"); } }; +template <> +struct npy_format_descriptor +{ + static std::string format() + { + // following: https://docs.python.org/3/library/struct.html#format-characters + // TODO: need to figure out correct encoding + return "z"; + } + static constexpr auto name() { return _("fp8e4m3fnuz"); } +}; + } // namespace detail } // namespace pybind11 diff --git a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp index 355cc4477b1..28a4b2939d7 100644 --- a/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp +++ b/src/targets/gpu/device/include/migraphx/gpu/device/types.hpp @@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x) // Hip doens't support __fp16 inline __device__ __host__ float to_hip_type(gpu_half x) { return x; } -#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ - template \ - struct trait : std::trait \ - { \ - }; \ - \ - template <> \ - struct trait : std::true_type \ - { \ +#define MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(trait, T) \ + template \ + struct trait : std::trait \ + { \ + }; \ + \ + template <> \ + struct trait : std::true_type \ + { \ }; -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16) -MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16) +MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, __fp16) +MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_signed, __fp16) +MIGRAPHX_DEVICE_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, __fp16) } // namespace device } // namespace gpu diff --git a/src/targets/gpu/gemm_impl.cpp b/src/targets/gpu/gemm_impl.cpp index b4f0881f8d3..4495e21ecac 100644 --- a/src/targets/gpu/gemm_impl.cpp +++ b/src/targets/gpu/gemm_impl.cpp @@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type) case shape::uint8_type: return rocblas_datatype_u8_r; case shape::int32_type: return rocblas_datatype_i32_r; case shape::uint32_type: return rocblas_datatype_u32_r; + case shape::fp8e4m3fnuz_type: case shape::tuple_type: case shape::bool_type: case shape::uint16_type: diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3fc9ea4fac7..33aca123217 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -150,6 +150,7 @@ function(test_headers PREFIX) list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp) endif() + list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp) foreach(HEADER ${HEADERS}) file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER}) string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME) diff --git a/test/float_equal.cpp b/test/float_equal.cpp index 102ee4faf67..847a929437c 100644 --- a/test/float_equal.cpp +++ b/test/float_equal.cpp @@ -22,6 +22,7 @@ * THE SOFTWARE. */ #include +#include #include #include "test.hpp" @@ -53,7 +54,7 @@ auto test_float_equal(T x, U y) template void test_equality() { - auto x1 = T(0.1); + auto x1 = T(0.125); auto x2 = U(0.0); auto x3 = U(1.0); EXPECT(test_float_equal(x1, x1)); @@ -71,8 +72,12 @@ void test_equality() TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); +TEST_CASE_REGISTER(test_equality); template void test_limits() @@ -110,8 +115,13 @@ void test_limits() TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); +TEST_CASE_REGISTER(test_limits); + #ifndef _WIN32 // On Windows, types int and long have the same min and max values. TEST_CASE_REGISTER(test_limits); diff --git a/test/fp8e4m3fn.cpp b/test/fp8e4m3fn.cpp new file mode 100644 index 00000000000..0fc0ca90c9d --- /dev/null +++ b/test/fp8e4m3fn.cpp @@ -0,0 +1,291 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e4m3fn_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0, 0.001953125, 0.00390625, 0.005859375, + 0.0078125, 0.009765625, 0.01171875, 0.013671875, + 0.015625, 0.017578125, 0.01953125, 0.021484375, + 0.0234375, 0.025390625, 0.02734375, 0.029296875, + 0.03125, 0.03515625, 0.0390625, 0.04296875, + 0.046875, 0.05078125, 0.0546875, 0.05859375, + 0.0625, 0.0703125, 0.078125, 0.0859375, + 0.09375, 0.1015625, 0.109375, 0.1171875, + 0.125, 0.140625, 0.15625, 0.171875, + 0.1875, 0.203125, 0.21875, 0.234375, + 0.25, 0.28125, 0.3125, 0.34375, + 0.375, 0.40625, 0.4375, 0.46875, + 0.5, 0.5625, 0.625, 0.6875, + 0.75, 0.8125, 0.875, 0.9375, + 1.0, 1.125, 1.25, 1.375, + 1.5, 1.625, 1.75, 1.875, + 2.0, 2.25, 2.5, 2.75, + 3.0, 3.25, 3.5, 3.75, + 4.0, 4.5, 5.0, 5.5, + 6.0, 6.5, 7.0, 7.5, + 8.0, 9.0, 10.0, 11.0, + 12.0, 13.0, 14.0, 15.0, + 16.0, 18.0, 20.0, 22.0, + 24.0, 26.0, 28.0, 30.0, + 32.0, 36.0, 40.0, 44.0, + 48.0, 52.0, 56.0, 60.0, + 64.0, 72.0, 80.0, 88.0, + 96.0, 104.0, 112.0, 120.0, + 128.0, 144.0, 160.0, 176.0, + 192.0, 208.0, 224.0, 240.0, + 256.0, 288.0, 320.0, 352.0, + 384.0, 416.0, 448.0, std::numeric_limits::quiet_NaN(), + -0.0, -0.001953125, -0.00390625, -0.005859375, + -0.0078125, -0.009765625, -0.01171875, -0.013671875, + -0.015625, -0.017578125, -0.01953125, -0.021484375, + -0.0234375, -0.025390625, -0.02734375, -0.029296875, + -0.03125, -0.03515625, -0.0390625, -0.04296875, + -0.046875, -0.05078125, -0.0546875, -0.05859375, + -0.0625, -0.0703125, -0.078125, -0.0859375, + -0.09375, -0.1015625, -0.109375, -0.1171875, + -0.125, -0.140625, -0.15625, -0.171875, + -0.1875, -0.203125, -0.21875, -0.234375, + -0.25, -0.28125, -0.3125, -0.34375, + -0.375, -0.40625, -0.4375, -0.46875, + -0.5, -0.5625, -0.625, -0.6875, + -0.75, -0.8125, -0.875, -0.9375, + -1.0, -1.125, -1.25, -1.375, + -1.5, -1.625, -1.75, -1.875, + -2.0, -2.25, -2.5, -2.75, + -3.0, -3.25, -3.5, -3.75, + -4.0, -4.5, -5.0, -5.5, + -6.0, -6.5, -7.0, -7.5, + -8.0, -9.0, -10.0, -11.0, + -12.0, -13.0, -14.0, -15.0, + -16.0, -18.0, -20.0, -22.0, + -24.0, -26.0, -28.0, -30.0, + -32.0, -36.0, -40.0, -44.0, + -48.0, -52.0, -56.0, -60.0, + -64.0, -72.0, -80.0, -88.0, + -96.0, -104.0, -112.0, -120.0, + -128.0, -144.0, -160.0, -176.0, + -192.0, -208.0, -224.0, -240.0, + -256.0, -288.0, -320.0, -352.0, + -384.0, -416.0, -448.0, std::numeric_limits::quiet_NaN(), + + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e4m3fn_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = { + {{512, 0x7e}, {-512, 0xfe}, {448, 0x7e}, {-448, 0xfe}, + {256, 0x78}, {-256, 0xf8}, {240, 0x77}, {-240, 0xf7}, + {1e-07, 0x0}, {1e+07, 0x7e}, {1, 0x38}, {-1, 0xb8}, + {0.1, 0x1d}, {0.11, 0x1e}, {0.111, 0x1e}, {0.1111, 0x1e}, + {-0.1, 0x9d}, {-0.11, 0x9e}, {-0.111, 0x9e}, {-0.1111, 0x9e}, + {0.2, 0x25}, {2, 0x40}, {20, 0x5a}, {200, 0x74}, + {-0.2, 0xa5}, {-2, 0xc0}, {-20, 0xda}, {-200, 0xf4}, + {0.5, 0x30}, {-0.5, 0xb0}, {1.17549e-38, 0x0}, {1.4013e-45, 0x0}, + {0.0078125, 0x4}, {-0.0078125, 0x84}, {0.000976562, 0x0}, {-0.000976562, 0x80}, + {0.000488281, 0x0}, {-0.000488281, 0x80}}}; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e4m3fn(sample.first), + migraphx::fp8::fp8e4m3fn(sample.second, migraphx::fp8::fp8e4m3fn::from_bits())); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx::fp8::fp8e4m3fn fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + migraphx::fp8::fp8e4m3fn fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero is preserved for fp8e4m3fn + EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); +} + +TEST_CASE(test_pos_zero_eq_neg_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx::fp8::fp8e5m2 fp8_nzero(nzero); + migraphx::fp8::fp8e5m2 fp8_pzero(pzero); + EXPECT(fp8_nzero == fp8_pzero); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e4m3fn fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fn::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + float finf = std::numeric_limits::infinity(); + // no inf in fp8e4m3fn, it gets clipped to max() + migraphx::fp8::fp8e4m3fn fp8_max(finf); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e4m3fn, it gets clipped to lowest + migraphx::fp8::fp8e4m3fn fp8_lowest(finf); + EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx::fp8::fp8e4m3fn fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e4m3fn fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx::fp8::fp8e4m3fn fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e4m3fn fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e4m3fn(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_no_infinity) +{ + EXPECT(not bool{std::numeric_limits::has_infinity}); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e4m3fn(-1.0); + auto b = migraphx::fp8::fp8e4m3fn(1.0); + auto c = migraphx::fp8::fp8e4m3fn(0.0); + auto d = migraphx::fp8::fp8e4m3fn(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e4m3fn(10.0); + auto f = migraphx::fp8::fp8e4m3fn(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool{f <= e}); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e4m3fn(-1.0); + auto b = migraphx::fp8::fp8e4m3fn(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e4m3fn(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e4m3fnuz.cpp b/test/fp8e4m3fnuz.cpp new file mode 100644 index 00000000000..e86cf8d76a1 --- /dev/null +++ b/test/fp8e4m3fnuz.cpp @@ -0,0 +1,313 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e4m3fnuz_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0f, 0.0009765625f, 0.001953125f, + 0.0029296875f, 0.00390625f, 0.0048828125f, + 0.005859375f, 0.0068359375f, 0.0078125f, + 0.0087890625f, 0.009765625f, 0.0107421875f, + 0.01171875f, 0.0126953125f, 0.013671875f, + 0.0146484375f, 0.015625f, 0.017578125f, + 0.01953125f, 0.021484375f, 0.0234375f, + 0.025390625f, 0.02734375f, 0.029296875f, + 0.03125f, 0.03515625f, 0.0390625f, + 0.04296875f, 0.046875f, 0.05078125f, + 0.0546875f, 0.05859375f, 0.0625f, + 0.0703125f, 0.078125f, 0.0859375f, + 0.09375f, 0.1015625f, 0.109375f, + 0.1171875f, 0.125f, 0.140625f, + 0.15625f, 0.171875f, 0.1875f, + 0.203125f, 0.21875f, 0.234375f, + 0.25f, 0.28125f, 0.3125f, + 0.34375f, 0.375f, 0.40625f, + 0.4375f, 0.46875f, 0.5f, + 0.5625f, 0.625f, 0.6875f, + 0.75f, 0.8125f, 0.875f, + 0.9375f, 1.0f, 1.125f, + 1.25f, 1.375f, 1.5f, + 1.625f, 1.75f, 1.875f, + 2.0f, 2.25f, 2.5f, + 2.75f, 3.0f, 3.25f, + 3.5f, 3.75f, 4.0f, + 4.5f, 5.0f, 5.5f, + 6.0f, 6.5f, 7.0f, + 7.5f, 8.0f, 9.0f, + 10.0f, 11.0f, 12.0f, + 13.0f, 14.0f, 15.0f, + 16.0f, 18.0f, 20.0f, + 22.0f, 24.0f, 26.0f, + 28.0f, 30.0f, 32.0f, + 36.0f, 40.0f, 44.0f, + 48.0f, 52.0f, 56.0f, + 60.0f, 64.0f, 72.0f, + 80.0f, 88.0f, 96.0f, + 104.0f, 112.0f, 120.0f, + 128.0f, 144.0f, 160.0f, + 176.0f, 192.0f, 208.0f, + 224.0f, 240.0f, std::numeric_limits::quiet_NaN(), + -0.0009765625f, -0.001953125f, -0.0029296875f, + -0.00390625f, -0.0048828125f, -0.005859375f, + -0.0068359375f, -0.0078125f, -0.0087890625f, + -0.009765625f, -0.0107421875f, -0.01171875f, + -0.0126953125f, -0.013671875f, -0.0146484375f, + -0.015625f, -0.017578125f, -0.01953125f, + -0.021484375f, -0.0234375f, -0.025390625f, + -0.02734375f, -0.029296875f, -0.03125f, + -0.03515625f, -0.0390625f, -0.04296875f, + -0.046875f, -0.05078125f, -0.0546875f, + -0.05859375f, -0.0625f, -0.0703125f, + -0.078125f, -0.0859375f, -0.09375f, + -0.1015625f, -0.109375f, -0.1171875f, + -0.125f, -0.140625f, -0.15625f, + -0.171875f, -0.1875f, -0.203125f, + -0.21875f, -0.234375f, -0.25f, + -0.28125f, -0.3125f, -0.34375f, + -0.375f, -0.40625f, -0.4375f, + -0.46875f, -0.5f, -0.5625f, + -0.625f, -0.6875f, -0.75f, + -0.8125f, -0.875f, -0.9375f, + -1.0f, -1.125f, -1.25f, + -1.375f, -1.5f, -1.625f, + -1.75f, -1.875f, -2.0f, + -2.25f, -2.5f, -2.75f, + -3.0f, -3.25f, -3.5f, + -3.75f, -4.0f, -4.5f, + -5.0f, -5.5f, -6.0f, + -6.5f, -7.0f, -7.5f, + -8.0f, -9.0f, -10.0f, + -11.0f, -12.0f, -13.0f, + -14.0f, -15.0f, -16.0f, + -18.0f, -20.0f, -22.0f, + -24.0f, -26.0f, -28.0f, + -30.0f, -32.0f, -36.0f, + -40.0f, -44.0f, -48.0f, + -52.0f, -56.0f, -60.0f, + -64.0f, -72.0f, -80.0f, + -88.0f, -96.0f, -104.0f, + -112.0f, -120.0f, -128.0f, + -144.0f, -160.0f, -176.0f, + -192.0f, -208.0f, -224.0f, + -240.0f, + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx::fp8::fp8e4m3fnuz fp8_val(bit_val, migraphx::fp8::fp8e4m3fnuz::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fnuz_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e4m3fnuz_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = {{256, 0x7f}, {-256, 0xff}, + {240, 0x7f}, {-240, 0xff}, + {1e-07, 0x0}, {1e+07, 0x7f}, + {1, 0x40}, {-1, 0xc0}, + {0.1, 0x25}, {0.11, 0x26}, + {0.111, 0x26}, {0.1111, 0x26}, + {-0.1, 0xa5}, {-0.11, 0xa6}, + {-0.111, 0xa6}, {-0.1111, 0xa6}, + {0.2, 0x2d}, {2, 0x48}, + {20, 0x62}, {200, 0x7c}, + {-0.2, 0xad}, {-2, 0xc8}, + {-20, 0xe2}, {-200, 0xfc}, + {0.5, 0x38}, {-0.5, 0xb8}, + {1.17549e-38, 0x0}, {1.4013e-45, 0x0}, + {0.00390625, 0x4}, {-0.00390625, 0x84}, + {0.00195312, 0x2}, {-0.00195312, 0x82}, + {0.000976562, 0x1}, {-0.000976562, 0x81}, + {0.000488281, 0x0}, {-0.000488281, 0x0}}; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e4m3fnuz(sample.first), + migraphx::fp8::fp8e4m3fnuz(sample.second, migraphx::fp8::fp8e4m3fnuz::from_bits())); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx::fp8::fp8e4m3fnuz fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx::fp8::fp8e4m3fnuz fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero gets converted to positive zero + EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e4m3fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fnuz::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + float finf = std::numeric_limits::infinity(); + // no inf in fp8e4m3fnuz it gets clipped to Nans + migraphx::fp8::fp8e4m3fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e4m3fnuz it gets clipped to NaNs + migraphx::fp8::fp8e4m3fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx::fp8::fp8e4m3fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e4m3fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx::fp8::fp8e4m3fnuz fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e4m3fnuz fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fnuz(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e4m3fnuz(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_no_infinity) +{ + EXPECT(not bool{std::numeric_limits::has_infinity}); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e4m3fnuz(-1.0); + auto b = migraphx::fp8::fp8e4m3fnuz(1.0); + auto c = migraphx::fp8::fp8e4m3fnuz(0.0); + auto d = migraphx::fp8::fp8e4m3fnuz(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e4m3fnuz(10.0); + auto f = migraphx::fp8::fp8e4m3fnuz(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool{f <= e}); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e4m3fnuz(-1.0); + auto b = migraphx::fp8::fp8e4m3fnuz(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e4m3fnuz(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2.cpp b/test/fp8e5m2.cpp new file mode 100644 index 00000000000..966aeb63d5c --- /dev/null +++ b/test/fp8e5m2.cpp @@ -0,0 +1,518 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include +#include + +float fp8e5m2_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0, + 1.52587890625e-05, + 3.0517578125e-05, + 4.57763671875e-05, + 6.103515625e-05, + 7.62939453125e-05, + 9.1552734375e-05, + 0.0001068115234375, + 0.0001220703125, + 0.000152587890625, + 0.00018310546875, + 0.000213623046875, + 0.000244140625, + 0.00030517578125, + 0.0003662109375, + 0.00042724609375, + 0.00048828125, + 0.0006103515625, + 0.000732421875, + 0.0008544921875, + 0.0009765625, + 0.001220703125, + 0.00146484375, + 0.001708984375, + 0.001953125, + 0.00244140625, + 0.0029296875, + 0.00341796875, + 0.00390625, + 0.0048828125, + 0.005859375, + 0.0068359375, + 0.0078125, + 0.009765625, + 0.01171875, + 0.013671875, + 0.015625, + 0.01953125, + 0.0234375, + 0.02734375, + 0.03125, + 0.0390625, + 0.046875, + 0.0546875, + 0.0625, + 0.078125, + 0.09375, + 0.109375, + 0.125, + 0.15625, + 0.1875, + 0.21875, + 0.25, + 0.3125, + 0.375, + 0.4375, + 0.5, + 0.625, + 0.75, + 0.875, + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 2.5, + 3.0, + 3.5, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 10.0, + 12.0, + 14.0, + 16.0, + 20.0, + 24.0, + 28.0, + 32.0, + 40.0, + 48.0, + 56.0, + 64.0, + 80.0, + 96.0, + 112.0, + 128.0, + 160.0, + 192.0, + 224.0, + 256.0, + 320.0, + 384.0, + 448.0, + 512.0, + 640.0, + 768.0, + 896.0, + 1024.0, + 1280.0, + 1536.0, + 1792.0, + 2048.0, + 2560.0, + 3072.0, + 3584.0, + 4096.0, + 5120.0, + 6144.0, + 7168.0, + 8192.0, + 10240.0, + 12288.0, + 14336.0, + 16384.0, + 20480.0, + 24576.0, + 28672.0, + 32768.0, + 40960.0, + 49152.0, + 57344.0, + std::numeric_limits::infinity(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + -0.0, + -1.52587890625e-05, + -3.0517578125e-05, + -4.57763671875e-05, + -6.103515625e-05, + -7.62939453125e-05, + -9.1552734375e-05, + -0.0001068115234375, + -0.0001220703125, + -0.000152587890625, + -0.00018310546875, + -0.000213623046875, + -0.000244140625, + -0.00030517578125, + -0.0003662109375, + -0.00042724609375, + -0.00048828125, + -0.0006103515625, + -0.000732421875, + -0.0008544921875, + -0.0009765625, + -0.001220703125, + -0.00146484375, + -0.001708984375, + -0.001953125, + -0.00244140625, + -0.0029296875, + -0.00341796875, + -0.00390625, + -0.0048828125, + -0.005859375, + -0.0068359375, + -0.0078125, + -0.009765625, + -0.01171875, + -0.013671875, + -0.015625, + -0.01953125, + -0.0234375, + -0.02734375, + -0.03125, + -0.0390625, + -0.046875, + -0.0546875, + -0.0625, + -0.078125, + -0.09375, + -0.109375, + -0.125, + -0.15625, + -0.1875, + -0.21875, + -0.25, + -0.3125, + -0.375, + -0.4375, + -0.5, + -0.625, + -0.75, + -0.875, + -1.0, + -1.25, + -1.5, + -1.75, + -2.0, + -2.5, + -3.0, + -3.5, + -4.0, + -5.0, + -6.0, + -7.0, + -8.0, + -10.0, + -12.0, + -14.0, + -16.0, + -20.0, + -24.0, + -28.0, + -32.0, + -40.0, + -48.0, + -56.0, + -64.0, + -80.0, + -96.0, + -112.0, + -128.0, + -160.0, + -192.0, + -224.0, + -256.0, + -320.0, + -384.0, + -448.0, + -512.0, + -640.0, + -768.0, + -896.0, + -1024.0, + -1280.0, + -1536.0, + -1792.0, + -2048.0, + -2560.0, + -3072.0, + -3584.0, + -4096.0, + -5120.0, + -6144.0, + -7168.0, + -8192.0, + -10240.0, + -12288.0, + -14336.0, + -16384.0, + -20480.0, + -24576.0, + -28672.0, + -32768.0, + -40960.0, + -49152.0, + -57344.0, + -1.0f * std::numeric_limits::infinity(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx::fp8::fp8e5m2 fp8_val(bit_val, migraphx::fp8::fp8e5m2::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2_to_fp32_value(bit_val))) + { + return true; + } + else if(std::isinf(float(fp8_val)) and std::isinf(fp8e5m2_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e5m2_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = { + {-60000, 0xfb}, + {-57344, 0xfb}, + {-448, 0xdf}, + {-256, 0xdc}, + {-240, 0xdc}, + {-200, 0xda}, + {-20, 0xcd}, + {-2, 0xc0}, + {-1, 0xbc}, + {-0.5, 0xb8}, + {-0.2, 0xb2}, + {-0.1111, 0xaf}, + {-0.111, 0xaf}, + {-0.11, 0xaf}, + {-0.1, 0xae}, + {6.10351e-05, 0x4}, + {-6.10351e-05, 0x84}, + {3.05176e-05, 0x2}, + {-3.05176e-05, 0x82}, + {1.52588e-05, 0x1}, + {-1.52588e-05, 0x81}, + {7.62939e-06, 0x0}, + {-7.62939e-06, 0x80}, + {0.1, 0x2e}, + {0.11, 0x2f}, + {0.111, 0x2f}, + {0.1111, 0x2f}, + {0.2, 0x32}, + {0.5, 0x38}, + {1, 0x3c}, + {2, 0x40}, + {20, 0x4d}, + {200, 0x5a}, + {240, 0x5c}, + {256, 0x5c}, + {448, 0x5f}, + {57344, 0x7b}, + {60000, 0x7b}, + {1e+07, 0x7b}, + }; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e5m2(sample.first), + migraphx::fp8::fp8e5m2(sample.second, migraphx::fp8::fp8e5m2::from_bits())); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx::fp8::fp8e5m2 fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + migraphx::fp8::fp8e5m2 fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero is preserved for fp8e5m2 + EXPECT(migraphx::float_equal(nzero, float(fp8_nzero))); +} + +TEST_CASE(test_pos_zero_eq_neg_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx::fp8::fp8e5m2 fp8_nzero(nzero); + migraphx::fp8::fp8e5m2 fp8_pzero(pzero); + EXPECT(fp8_nzero == fp8_pzero); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e5m2 fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e5m2 fp8_nan(fnan.data, migraphx::fp8::fp8e5m2::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + // float infinity should get clipped to max + float finf = std::numeric_limits::infinity(); + migraphx::fp8::fp8e5m2 fp8_max(finf); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e5m2, it gets clipped to lowest + migraphx::fp8::fp8e5m2 fp8_lowest(finf); + EXPECT(bool{fp8_lowest == std::numeric_limits::lowest()}); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx::fp8::fp8e5m2 fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e5m2 fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx::fp8::fp8e5m2 fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e5m2 fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e5m2(std::numeric_limits::quiet_NaN()))); + EXPECT(not std::isfinite(std::numeric_limits::infinity())); + // -1.0 * inf is float(-inf) which with clipping/saturation gets converted into fp8::lowest() + EXPECT(std::isfinite( + migraphx::fp8::fp8e5m2(-1.0 * std::numeric_limits::infinity()))); + EXPECT(not std::isfinite(migraphx::fp8::fp8e5m2(0xFC, migraphx::fp8::fp8e5m2::from_bits()))); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + auto c = migraphx::fp8::fp8e5m2(0.0); + auto d = migraphx::fp8::fp8e5m2(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e5m2(10.0); + auto f = migraphx::fp8::fp8e5m2(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool{f <= e}); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + auto b = migraphx::fp8::fp8e5m2(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e5m2(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/fp8e5m2fnuz.cpp b/test/fp8e5m2fnuz.cpp new file mode 100644 index 00000000000..14be8bc80d7 --- /dev/null +++ b/test/fp8e5m2fnuz.cpp @@ -0,0 +1,477 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ +#include +#include +#include +#include +#include +#include "test.hpp" + +#include + +float fp8e5m2fnuz_to_fp32_value(uint8_t input) +{ + constexpr std::array e4m3fnuz_lut = { + 0.0, + 7.62939453125e-06, + 1.52587890625e-05, + 2.288818359375e-05, + 3.0517578125e-05, + 3.814697265625e-05, + 4.57763671875e-05, + 5.340576171875e-05, + 6.103515625e-05, + 7.62939453125e-05, + 9.1552734375e-05, + 0.0001068115234375, + 0.0001220703125, + 0.000152587890625, + 0.00018310546875, + 0.000213623046875, + 0.000244140625, + 0.00030517578125, + 0.0003662109375, + 0.00042724609375, + 0.00048828125, + 0.0006103515625, + 0.000732421875, + 0.0008544921875, + 0.0009765625, + 0.001220703125, + 0.00146484375, + 0.001708984375, + 0.001953125, + 0.00244140625, + 0.0029296875, + 0.00341796875, + 0.00390625, + 0.0048828125, + 0.005859375, + 0.0068359375, + 0.0078125, + 0.009765625, + 0.01171875, + 0.013671875, + 0.015625, + 0.01953125, + 0.0234375, + 0.02734375, + 0.03125, + 0.0390625, + 0.046875, + 0.0546875, + 0.0625, + 0.078125, + 0.09375, + 0.109375, + 0.125, + 0.15625, + 0.1875, + 0.21875, + 0.25, + 0.3125, + 0.375, + 0.4375, + 0.5, + 0.625, + 0.75, + 0.875, + 1.0, + 1.25, + 1.5, + 1.75, + 2.0, + 2.5, + 3.0, + 3.5, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 10.0, + 12.0, + 14.0, + 16.0, + 20.0, + 24.0, + 28.0, + 32.0, + 40.0, + 48.0, + 56.0, + 64.0, + 80.0, + 96.0, + 112.0, + 128.0, + 160.0, + 192.0, + 224.0, + 256.0, + 320.0, + 384.0, + 448.0, + 512.0, + 640.0, + 768.0, + 896.0, + 1024.0, + 1280.0, + 1536.0, + 1792.0, + 2048.0, + 2560.0, + 3072.0, + 3584.0, + 4096.0, + 5120.0, + 6144.0, + 7168.0, + 8192.0, + 10240.0, + 12288.0, + 14336.0, + 16384.0, + 20480.0, + 24576.0, + 28672.0, + 32768.0, + 40960.0, + 49152.0, + 57344.0, + std::numeric_limits::quiet_NaN(), + -7.62939453125e-06, + -1.52587890625e-05, + -2.288818359375e-05, + -3.0517578125e-05, + -3.814697265625e-05, + -4.57763671875e-05, + -5.340576171875e-05, + -6.103515625e-05, + -7.62939453125e-05, + -9.1552734375e-05, + -0.0001068115234375, + -0.0001220703125, + -0.000152587890625, + -0.00018310546875, + -0.000213623046875, + -0.000244140625, + -0.00030517578125, + -0.0003662109375, + -0.00042724609375, + -0.00048828125, + -0.0006103515625, + -0.000732421875, + -0.0008544921875, + -0.0009765625, + -0.001220703125, + -0.00146484375, + -0.001708984375, + -0.001953125, + -0.00244140625, + -0.0029296875, + -0.00341796875, + -0.00390625, + -0.0048828125, + -0.005859375, + -0.0068359375, + -0.0078125, + -0.009765625, + -0.01171875, + -0.013671875, + -0.015625, + -0.01953125, + -0.0234375, + -0.02734375, + -0.03125, + -0.0390625, + -0.046875, + -0.0546875, + -0.0625, + -0.078125, + -0.09375, + -0.109375, + -0.125, + -0.15625, + -0.1875, + -0.21875, + -0.25, + -0.3125, + -0.375, + -0.4375, + -0.5, + -0.625, + -0.75, + -0.875, + -1.0, + -1.25, + -1.5, + -1.75, + -2.0, + -2.5, + -3.0, + -3.5, + -4.0, + -5.0, + -6.0, + -7.0, + -8.0, + -10.0, + -12.0, + -14.0, + -16.0, + -20.0, + -24.0, + -28.0, + -32.0, + -40.0, + -48.0, + -56.0, + -64.0, + -80.0, + -96.0, + -112.0, + -128.0, + -160.0, + -192.0, + -224.0, + -256.0, + -320.0, + -384.0, + -448.0, + -512.0, + -640.0, + -768.0, + -896.0, + -1024.0, + -1280.0, + -1536.0, + -1792.0, + -2048.0, + -2560.0, + -3072.0, + -3584.0, + -4096.0, + -5120.0, + -6144.0, + -7168.0, + -8192.0, + -10240.0, + -12288.0, + -14336.0, + -16384.0, + -20480.0, + -24576.0, + -28672.0, + -32768.0, + -40960.0, + -49152.0, + -57344.0, + }; + + return e4m3fnuz_lut[input]; +} + +TEST_CASE(test_fp8_cast_to_float) +{ + std::vector bit_vals(256); + std::iota(bit_vals.begin(), bit_vals.end(), 0); + EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) { + migraphx::fp8::fp8e5m2fnuz fp8_val(bit_val, migraphx::fp8::fp8e5m2fnuz::from_bits()); + if(std::isnan(float(fp8_val)) and std::isnan(fp8e5m2fnuz_to_fp32_value(bit_val))) + { + return true; + } + return migraphx::float_equal(float(fp8_val), fp8e5m2fnuz_to_fp32_value(bit_val)); + })}); +} + +TEST_CASE(test_fp8_cast_from_float) +{ + std::unordered_map test_vals = { + {57344, 0x7f}, {-57344, 0xff}, {60000, 0x7f}, {-60000, 0xff}, + {448, 0x63}, {-448, 0xe3}, {256, 0x60}, {-256, 0xe0}, + {240, 0x60}, {-240, 0xe0}, {3.05176e-05, 0x4}, {-3.05176e-05, 0x84}, + {1.52588e-05, 0x2}, {-1.52588e-05, 0x82}, {7.62939e-06, 0x1}, {-7.62939e-06, 0x81}, + {3.81469e-06, 0x0}, {-3.81469e-06, 0x0}, {1e+07, 0x7f}, {1, 0x40}, + {-1, 0xc0}, {0.1, 0x32}, {0.11, 0x33}, {0.111, 0x33}, + {0.1111, 0x33}, {-0.1, 0xb2}, {-0.11, 0xb3}, {-0.111, 0xb3}, + {-0.1111, 0xb3}, {0.2, 0x36}, {2, 0x44}, {20, 0x51}, + {200, 0x5e}, {-0.2, 0xb6}, {-2, 0xc4}, {-20, 0xd1}, + {-200, 0xde}, {0.5, 0x3c}, {-0.5, 0xbc}, {1.17549e-38, 0x0}, + {1.4013e-45, 0x0}, + }; + + EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) { + return migraphx::float_equal( + migraphx::fp8::fp8e5m2fnuz(sample.first), + migraphx::fp8::fp8e5m2fnuz(sample.second, migraphx::fp8::fp8e5m2fnuz::from_bits())); + })}); +} + +TEST_CASE(test_positive_zero) +{ + float zero = 0.0; + migraphx::fp8::fp8e5m2fnuz fp8_zero(zero); + EXPECT(fp8_zero.is_zero()); + EXPECT(migraphx::float_equal(zero, float(fp8_zero))); +} + +TEST_CASE(test_negative_zero) +{ + float nzero = -0.0; + float pzero = 0.0; + migraphx::fp8::fp8e5m2fnuz fp8_nzero(nzero); + EXPECT(fp8_nzero.is_zero()); + // negative zero gets converted to positive zero + EXPECT(migraphx::float_equal(pzero, float(fp8_nzero))); +} + +TEST_CASE(test_nan_1) +{ + float fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); +} + +TEST_CASE(test_nan_2) +{ + auto fnan = std::numeric_limits::quiet_NaN(); + migraphx::fp8::fp8e5m2fnuz fp8_nan(fnan.data, migraphx::fp8::fp8e5m2fnuz::from_bits()); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(fp8_nan)); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_1) +{ + float finf = std::numeric_limits::infinity(); + // no inf in fp8e5m2fnuz it gets clipped to Nans + migraphx::fp8::fp8e5m2fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_infinity_2) +{ + // neg inf + float finf = -1.0 * std::numeric_limits::infinity(); + // no inf in fp8e5m2fnuz it gets clipped to NaNs + migraphx::fp8::fp8e5m2fnuz fp8_nan(finf); + EXPECT(fp8_nan.is_nan()); + EXPECT(std::isnan(float(fp8_nan))); +} + +TEST_CASE(test_numeric_max_1) +{ + float fmax = std::numeric_limits::max(); + migraphx::fp8::fp8e5m2fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_max_2) +{ + // gets clipped to max + float fmax = 2 * std::numeric_limits::max(); + migraphx::fp8::fp8e5m2fnuz fp8_max(fmax); + EXPECT(fp8_max == std::numeric_limits::max()); +} + +TEST_CASE(test_numeric_lowest_1) +{ + float flowest = std::numeric_limits::lowest(); + migraphx::fp8::fp8e5m2fnuz fp8_lowest(flowest); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_numeric_lowest_2) +{ + // gets clipped to lowest + float fmin = 2.0 * std::numeric_limits::lowest(); + migraphx::fp8::fp8e5m2fnuz fp8_lowest(fmin); + EXPECT(fp8_lowest == std::numeric_limits::lowest()); +} + +TEST_CASE(test_max_eq_lowest) +{ + EXPECT(migraphx::float_equal(std::numeric_limits::lowest(), + -1 * std::numeric_limits::max())); +} + +TEST_CASE(test_isfinite) +{ + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2fnuz(0.0))); + EXPECT(std::isfinite(migraphx::fp8::fp8e5m2fnuz(-0.0))); + EXPECT(not std::isfinite( + migraphx::fp8::fp8e5m2fnuz(std::numeric_limits::quiet_NaN()))); +} + +TEST_CASE(test_no_infinity) +{ + EXPECT(not bool{std::numeric_limits::has_infinity}); +} + +TEST_CASE(test_binary_ops) +{ + auto a = migraphx::fp8::fp8e5m2fnuz(-1.0); + auto b = migraphx::fp8::fp8e5m2fnuz(1.0); + auto c = migraphx::fp8::fp8e5m2fnuz(0.0); + auto d = migraphx::fp8::fp8e5m2fnuz(-0.0); + EXPECT(migraphx::float_equal((c + d), c)); + EXPECT(migraphx::float_equal((c + d), d)); + EXPECT(migraphx::float_equal((a + b), c)); + EXPECT(migraphx::float_equal((a + b), d)); + + auto e = migraphx::fp8::fp8e5m2fnuz(10.0); + auto f = migraphx::fp8::fp8e5m2fnuz(-10.0); + EXPECT(bool{e > f}); + EXPECT(bool{f < e}); + EXPECT(bool{f <= e}); + EXPECT(bool{e >= f}); + EXPECT(bool{e <= e}); + EXPECT(bool{f >= f}); + EXPECT(not migraphx::float_equal(f, e)); +} + +TEST_CASE(test_fabs) +{ + auto a = migraphx::fp8::fp8e5m2fnuz(-1.0); + auto b = migraphx::fp8::fp8e5m2fnuz(1.0); + EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a))); +} + +TEST_CASE(test_stream_op) +{ + auto a = migraphx::fp8::fp8e5m2fnuz(-1.0); + std::stringstream ss; + ss << a; + EXPECT(std::string("-1") == ss.str()); + ss = std::stringstream(); + auto b = std::numeric_limits::quiet_NaN(); + ss << b; + EXPECT(std::string("nan") == ss.str()); +} +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/gpu/jit.cpp b/test/gpu/jit.cpp index b92f1419310..2b407178681 100644 --- a/test/gpu/jit.cpp +++ b/test/gpu/jit.cpp @@ -237,12 +237,12 @@ TEST_CASE(code_object_hip) std::vector expected_inputs = {input, input}; auto co = migraphx::make_op("gpu::code_object", - {{"code_object", migraphx::value::binary{binaries.front()}}, - {"symbol_name", "add_2"}, - {"global", input.elements()}, - {"local", 1024}, - {"expected_inputs", migraphx::to_value(expected_inputs)}, - {"output", migraphx::to_value(input)}}); + {{"code_object", migraphx::value::binary{binaries.front()}}, + {"symbol_name", "add_2"}, + {"global", input.elements()}, + {"local", 1024}, + {"expected_inputs", migraphx::to_value(expected_inputs)}, + {"output", migraphx::to_value(input)}}); migraphx::program p; auto* mm = p.get_main_module(); @@ -348,7 +348,10 @@ TEST_CASE(compile_math) auto vec_sizes = {2, 4, 6}; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) + if(contains({migraphx::shape::bool_type, + migraphx::shape::fp8e4m3fnuz_type, + migraphx::shape::tuple_type}, + t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == migraphx::shape::half_type) @@ -396,7 +399,10 @@ TEST_CASE(assert_type_min_max) migraphx::gpu::hip_compile_options options; for(auto&& t : migraphx::shape::types()) { - if(contains({migraphx::shape::bool_type, migraphx::shape::tuple_type}, t)) + if(contains({migraphx::shape::bool_type, + migraphx::shape::fp8e4m3fnuz_type, + migraphx::shape::tuple_type}, + t)) continue; auto name = migraphx::shape::cpp_type(t); if(t == migraphx::shape::half_type) diff --git a/tools/api/migraphx.h b/tools/api/migraphx.h index 8179cfffd52..57441279b18 100644 --- a/tools/api/migraphx.h +++ b/tools/api/migraphx.h @@ -44,7 +44,8 @@ m(int32_type, int32_t) \ m(int64_type, int64_t) \ m(uint32_type, uint32_t) \ - m(uint64_type, uint64_t) + m(uint64_type, uint64_t) \ + m(fp8e4m3fnuz_type, migraphx::fp8::fp8e4m3fnuz) // clang-format on #ifdef __cplusplus @@ -70,7 +71,9 @@ typedef enum } migraphx_shape_datatype_t; #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES -<% generate_c_header() %> +<% + generate_c_header() +%> #ifdef __cplusplus }