From f48e24385b33f1f79945941ffd8fa66ac2771aac Mon Sep 17 00:00:00 2001 From: Gregory Pataky Date: Wed, 14 Aug 2024 16:04:11 -0700 Subject: [PATCH] Add FP8 support to the exhaustive tests Adds new tests for two FP8 variants to both unary and binary exhaustive tests. PiperOrigin-RevId: 663087125 --- xla/tests/exhaustive/BUILD | 1 + .../exhaustive_binary_test_definitions.inc | 114 ++++++++++++++- ...ary_test_f16_and_smaller_instantiation.inc | 18 ++- ...haustive_binary_test_f32_instantiation.inc | 4 + ...haustive_binary_test_f64_instantiation.inc | 4 + .../exhaustive_binary_test_functions.cc | 24 +++- .../exhaustive/exhaustive_op_test_base.cc | 5 + .../exhaustive/exhaustive_op_test_utils.cc | 4 + .../exhaustive/exhaustive_op_test_utils.h | 33 ++++- .../exhaustive_unary_test_definitions.inc | 32 ++++- ...ary_test_f32_and_smaller_instantiation.inc | 16 ++- ...xhaustive_unary_test_f64_instantiation.inc | 4 + .../exhaustive_unary_test_functions.cc | 135 ++++++++++++++++-- 13 files changed, 366 insertions(+), 28 deletions(-) diff --git a/xla/tests/exhaustive/BUILD b/xla/tests/exhaustive/BUILD index a6b268e326e7d8..eeca4dc2ebe3aa 100644 --- a/xla/tests/exhaustive/BUILD +++ b/xla/tests/exhaustive/BUILD @@ -142,6 +142,7 @@ exhaustive_xla_test( ":exhaustive_op_test_utils", ":exhaustive_unary_test_textual_hdrs", "//xla:literal", + "//xla:types", "//xla/client:xla_builder", "//xla/client/lib:constants", "//xla/client/lib:math", diff --git a/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc b/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc index 8fe0a71d45277d..e4feb10c9918cd 100644 --- a/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc +++ b/xla/tests/exhaustive/exhaustive_binary_test_definitions.inc @@ -13,6 +13,92 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +// Exhaustive test for binary operations for 8-bit floating point types, +// including float8_e4m3fn and float8_e5m2. +// +// Test parameter is a pair of (begin, end) for range under test. +template +class Exhaustive8BitBinaryTest + : public ExhaustiveBinaryTest, + public ::testing::WithParamInterface> { + public: + int64_t GetInputSize() override { + int64_t begin, end; + std::tie(begin, end) = GetParam(); + return end - begin; + } + + // Given a range of uint64_t representation, uses bits 7..0 and bits 15..8 + // for the values of src0 and src1 (see below for ordering) for the 8-bit + // binary operation being tested, and generate the cartesian product of the + // two sets as the two inputs for the test. + // + // If `kLeftToRightPacking == true`, then bits 15..8 are interpreted as src0 + // and bits 7..0 are interpreted as src1. If `kLeftToRightPacking == false`, + // then bits 15..8 are interpreted as src1 and 7..0 are interpreted as src0. + void FillInput(std::array* input_literals) override { + int64_t input_size = GetInputSize(); + CHECK_EQ(input_size, (*input_literals)[0].element_count()); + CHECK_EQ(input_size, (*input_literals)[1].element_count()); + + int64_t begin, end; + std::tie(begin, end) = GetParam(); + + if (VLOG_IS_ON(2)) { + uint8_t left_begin, left_end, right_begin, right_end; + if constexpr (kLeftToRightPacking) { + left_begin = std::bit_cast(static_cast(begin >> 8)); + left_end = std::bit_cast(static_cast(end >> 8)); + right_begin = std::bit_cast(static_cast(begin)); + right_end = std::bit_cast(static_cast(end)); + } else { + left_begin = std::bit_cast(static_cast(begin)); + left_end = std::bit_cast(static_cast(end)); + right_begin = std::bit_cast(static_cast(begin >> 8)); + right_end = std::bit_cast(static_cast(end >> 8)); + } + + LOG(INFO) << this->SuiteName() << this->TestName() << " Range:"; + // N.B.: Cast to u32 to avoid printing values as char.
+ LOG(INFO) << "\tfrom=(" << static_cast(left_begin) << ", " + << static_cast(right_begin) << "); hex=(" << std::hex + << static_cast(left_begin) << ", " + << static_cast(right_begin) << "); float=(" + << std::bit_cast(left_begin) << ", " + << std::bit_cast(right_begin) + << ") (inclusive)"; + LOG(INFO) << "\tto=(" << static_cast(left_end) << ", " + << static_cast(right_end) << "); hex=(" << std::hex + << static_cast(left_end) << ", " + << static_cast(right_end) << "); float=(" + << std::bit_cast(left_end) << ", " + << std::bit_cast(right_end) + << ") (exclusive)"; + LOG(INFO) << "\ttotal values to test=" << (end - begin); + } + + absl::Span input_arr_0 = (*input_literals)[0].data(); + absl::Span input_arr_1 = (*input_literals)[1].data(); + for (int64_t i = 0; i < input_size; i++) { + uint32_t input_val = i + begin; + // Convert the packed bits to a pair of NativeT and replace known + // incorrect input values with 0. + // + // In either case, we only use 16 bits out of the 64 bits possible. + if constexpr (kLeftToRightPacking) { + input_arr_0[i] = this->ConvertValue(input_val >> 8); + input_arr_1[i] = this->ConvertValue(input_val); + } else { + input_arr_0[i] = this->ConvertValue(input_val); + input_arr_1[i] = this->ConvertValue(input_val >> 8); + } + } + } + + protected: + using typename ExhaustiveBinaryTest::NativeT; +}; + // Exhaustive test for binary operations for 16 bit floating point types, // including float16 and bfloat. // @@ -147,11 +233,29 @@ class Exhaustive32BitOrMoreBinaryTest } }; +using ExhaustiveF8E4M3FNBinaryTest = Exhaustive8BitBinaryTest; +using ExhaustiveF8E5M2BinaryTest = Exhaustive8BitBinaryTest; using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest; using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest; using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest; using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +#define BINARY_TEST_F8E4M3FN(test_name, ...)
\ + XLA_TEST_P(ExhaustiveF8E4M3FNBinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_F8E4M3FN(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +#define BINARY_TEST_F8E5M2(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E5M2BinaryTest, test_name) \ + __VA_ARGS__ +#else +#define BINARY_TEST_F8E5M2(test_name, ...) +#endif + #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) #define BINARY_TEST_F16(test_name, ...) \ XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \ @@ -180,10 +284,12 @@ using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest; #define BINARY_TEST_F64(test_name, ...) #endif -#define BINARY_TEST(test_name, ...) \ - BINARY_TEST_F16(test_name, __VA_ARGS__) \ - BINARY_TEST_BF16(test_name, __VA_ARGS__) \ - BINARY_TEST_F32(test_name, __VA_ARGS__) \ +#define BINARY_TEST(test_name, ...) \ + BINARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \ + BINARY_TEST_F8E5M2(test_name, __VA_ARGS__) \ + BINARY_TEST_F16(test_name, __VA_ARGS__) \ + BINARY_TEST_BF16(test_name, __VA_ARGS__) \ + BINARY_TEST_F32(test_name, __VA_ARGS__) \ BINARY_TEST_F64(test_name, __VA_ARGS__) #define BINARY_TEST_COMPLEX(test_name, ...) \ diff --git a/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc b/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc index 1e88061028a65d..04456e6f3a8eaa 100644 --- a/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc +++ b/xla/tests/exhaustive/exhaustive_binary_test_f16_and_smaller_instantiation.inc @@ -13,16 +13,30 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNBinaryTest, + ::testing::ValuesIn(CreateExhaustiveU16Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2BinaryTest, + ::testing::ValuesIn(CreateExhaustiveU16Ranges())); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); +#endif + #if defined(XLA_BACKEND_SUPPORTS_BFLOAT16) INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); #else GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); #endif #if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16) INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); #else GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); #endif diff --git a/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc b/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc index 1c8e97d1d5d41e..ba62061d7437fc 100644 --- a/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc +++ b/xla/tests/exhaustive/exhaustive_binary_test_f32_instantiation.inc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); diff --git a/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc b/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc index 0de1e1242d6b7c..a91f93ee155d45 100644 --- a/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc +++ b/xla/tests/exhaustive/exhaustive_binary_test_f64_instantiation.inc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest); diff --git a/xla/tests/exhaustive/exhaustive_binary_test_functions.cc b/xla/tests/exhaustive/exhaustive_binary_test_functions.cc index 07485354cf2427..217c16ab7af9c8 100644 --- a/xla/tests/exhaustive/exhaustive_binary_test_functions.cc +++ b/xla/tests/exhaustive/exhaustive_binary_test_functions.cc @@ -299,7 +299,13 @@ bool PowCpuGpuF16Skip(NativeT left, NativeT right) { BINARY_TEST(Pow, { PowOp(this) .CpuError(+[](NativeT left, NativeT right) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } else if constexpr (std::is_same_v) { return ErrorSpec::Builder() .strict_signed_zeros() .skip_comparison(PowCpuGpuF16Skip(left, right)) @@ -357,7 
+363,14 @@ bool Atan2CpuBf16F32Skip(NativeT left, NativeT right) { BINARY_TEST(Atan2, { Atan2Op(this) .CpuError([](NativeT left, NativeT right) { - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + + } else if constexpr (std::is_same_v) { return ErrorSpec::Builder() .abs_err( Atan2CpuBf16F32F64AbsErr(left, right)) @@ -383,6 +396,13 @@ BINARY_TEST(Atan2, { return ErrorSpec::Builder().strict_signed_zeros().build(); }) .GpuError(+[](NativeT, NativeT) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder() + .distance_err(1) + .strict_signed_zeros() + .build(); + } if constexpr (std::is_same_v || std::is_same_v) { return ErrorSpec::Builder() diff --git a/xla/tests/exhaustive/exhaustive_op_test_base.cc b/xla/tests/exhaustive/exhaustive_op_test_base.cc index 1e393f20078e32..2d243ad7f90993 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_base.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_base.cc @@ -43,6 +43,7 @@ limitations under the License. 
#include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" #include "xla/executable_run_options.h" +#include "xla/fp_util.h" #include "xla/literal.h" #include "xla/literal_util.h" #include "xla/service/shaped_buffer.h" @@ -864,11 +865,15 @@ template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; +template class ExhaustiveOpTestBase; } // namespace exhaustive_op_test } // namespace xla diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.cc b/xla/tests/exhaustive/exhaustive_op_test_utils.cc index c6dafe39b1f2ff..ae339f82c70b2c 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.cc +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.cc @@ -44,11 +44,15 @@ template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; +template class ExhaustiveOpTestTraits; bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); } diff --git a/xla/tests/exhaustive/exhaustive_op_test_utils.h b/xla/tests/exhaustive/exhaustive_op_test_utils.h index 62bf6786330ec4..1f4985a4c9e8c4 100644 --- a/xla/tests/exhaustive/exhaustive_op_test_utils.h +++ b/xla/tests/exhaustive/exhaustive_op_test_utils.h @@ -39,7 +39,6 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "absl/types/span.h" #include "xla/client/xla_builder.h" -#include "xla/fp_util.h" #include "xla/literal.h" #include "xla/primitive_util.h" #include "xla/tests/exhaustive/error_spec.h" @@ -51,7 +50,7 @@ namespace exhaustive_op_test { // The primitive type used to compute the reference output. constexpr PrimitiveType Ref(PrimitiveType T) { - return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32; + return (!primitive_util::IsFloatingPointType(T) || T == F64) ? T : F32; } // The primitive type of the component of T. If T is not complex, then @@ -195,6 +194,16 @@ inline ErrorSpec DefaultSpecGenerator(xla::bfloat16) { return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e4m3fn) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e5m2) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + template <> inline ErrorSpec DefaultSpecGenerator(double, double) { double atol = kDefaultAbsoluteToleranceSlackFactor * @@ -231,6 +240,18 @@ inline ErrorSpec DefaultSpecGenerator(bfloat16, bfloat16) { return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build(); } +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e4m3fn, + tsl::float8_e4m3fn) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + +template <> +inline ErrorSpec DefaultSpecGenerator(tsl::float8_e5m2, + tsl::float8_e5m2) { + return ErrorSpec::Builder().strict_signed_zeros().build(); +} + template typename ExhaustiveOpTestTraits::ErrorSpecGen GetDefaultSpecGenerator() { // Select overload by casting to fn ptr type. 
@@ -782,7 +803,13 @@ CreateSubnormalExhaustiveRanges() { return ret; } -inline std::vector> CreateExhaustiveF32Ranges() { +inline std::vector> CreateExhaustiveU16Ranges() { + // The entire U16 range is small enough that we don't need to do any + // partitioning. `end` is exclusive, so use 1 << 16 to include 0xFFFF. + return {{0, 1 << 16}}; +} + +inline std::vector> CreateExhaustiveU32Ranges() { // We break up the 2^32-element space into small-ish chunks to keep peak // memory usage low. std::vector> result; diff --git a/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc b/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc index 64f491fbb7bcb7..dc160bac741954 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc +++ b/xla/tests/exhaustive/exhaustive_unary_test_definitions.inc @@ -67,9 +67,11 @@ class Exhaustive32BitOrLessUnaryTest } }; -using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; -using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF8E4M3FNUnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF8E5M2UnaryTest = Exhaustive32BitOrLessUnaryTest; using ExhaustiveBF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF16UnaryTest = Exhaustive32BitOrLessUnaryTest; +using ExhaustiveF32UnaryTest = Exhaustive32BitOrLessUnaryTest; // Exhaustive test for unary operations for double. // @@ -105,6 +107,22 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, } }; +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +#define UNARY_TEST_F8E4M3FN(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E4M3FNUnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F8E4M3FN(test_name, ...) +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +#define UNARY_TEST_F8E5M2(test_name, ...) \ + XLA_TEST_P(ExhaustiveF8E5M2UnaryTest, test_name) \ + __VA_ARGS__ +#else +#define UNARY_TEST_F8E5M2(test_name, ...) +#endif + #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 #define UNARY_TEST_BF16(test_name, ...)
\ XLA_TEST_P(ExhaustiveBF16UnaryTest, test_name) \ @@ -133,8 +151,10 @@ class ExhaustiveF64UnaryTest : public ExhaustiveUnaryTest, #define UNARY_TEST_F64(test_name, ...) #endif -#define UNARY_TEST(test_name, ...) \ - UNARY_TEST_BF16(test_name, __VA_ARGS__) \ - UNARY_TEST_F16(test_name, __VA_ARGS__) \ - UNARY_TEST_F32(test_name, __VA_ARGS__) \ +#define UNARY_TEST(test_name, ...) \ + UNARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \ + UNARY_TEST_F8E5M2(test_name, __VA_ARGS__) \ + UNARY_TEST_BF16(test_name, __VA_ARGS__) \ + UNARY_TEST_F16(test_name, __VA_ARGS__) \ + UNARY_TEST_F32(test_name, __VA_ARGS__) \ UNARY_TEST_F64(test_name, __VA_ARGS__) diff --git a/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc b/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc index a958e2bbc88c74..b0c1a087b9283b 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc +++ b/xla/tests/exhaustive/exhaustive_unary_test_f32_and_smaller_instantiation.inc @@ -13,6 +13,20 @@ See the License for the specific language governing permissions and limitations under the License.
==============================================================================*/ +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN) +INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNUnaryTest, + ::testing::Values(std::make_pair(0, 1 << 8))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNUnaryTest); +#endif + +#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2) +INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2UnaryTest, + ::testing::Values(std::make_pair(0, 1 << 8))); +#else +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); +#endif + #ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16UnaryTest, ::testing::Values(std::make_pair(0, 1 << 16))); @@ -28,6 +42,6 @@ GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); #endif INSTANTIATE_TEST_SUITE_P(F32, ExhaustiveF32UnaryTest, - ::testing::ValuesIn(CreateExhaustiveF32Ranges())); + ::testing::ValuesIn(CreateExhaustiveU32Ranges())); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF64UnaryTest); diff --git a/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc b/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc index b558fb85f3f8e8..a2e67ff4f8fb0c 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc +++ b/xla/tests/exhaustive/exhaustive_unary_test_f64_instantiation.inc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNUnaryTest); + +GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2UnaryTest); + GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16UnaryTest); GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16UnaryTest); diff --git a/xla/tests/exhaustive/exhaustive_unary_test_functions.cc b/xla/tests/exhaustive/exhaustive_unary_test_functions.cc index 5baa12f15ca455..a8af6bcc915bc4 100644 --- a/xla/tests/exhaustive/exhaustive_unary_test_functions.cc +++ b/xla/tests/exhaustive/exhaustive_unary_test_functions.cc @@ -29,6 +29,7 @@ limitations under the License. #include "xla/tests/exhaustive/exhaustive_op_test_utils.h" #include "xla/tests/exhaustive/exhaustive_unary_test_definitions.h" #include "xla/tests/exhaustive/test_op.h" // IWYU pragma: keep, exhaustive_unary_test_ops.inc +#include "xla/types.h" #ifdef __FAST_MATH__ #error "Can't be compiled with fast math on" @@ -41,9 +42,39 @@ namespace { #include "xla/tests/exhaustive/exhaustive_unary_test_ops.inc" UNARY_TEST(Log, { LogOp(this).Error(GetDefaultSpecGenerator()).Run(); }) -UNARY_TEST(Log1p, { Log1pOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Log1p, { + Log1pOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) -UNARY_TEST(Exp, { ExpOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Exp, { + ExpOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + 
return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) UNARY_TEST(Expm1, { Expm1Op(this).Error(GetDefaultSpecGenerator()).Run(); }) UNARY_TEST(Logistic, { @@ -54,8 +85,20 @@ UNARY_TEST(Logistic, { } return std::abs(out) <= 1.0f; }) - // FIXME(rmlarsen): Break into region around zero and everything else. - .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + // FIXME(rmlarsen): Break into region around zero and everything else. + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + // FIXME(rmlarsen): Break into region around zero and everything else. + return GetDefaultSpecGenerator()(x); + }) .Run(); }) @@ -135,13 +178,46 @@ UNARY_TEST(Acosh, { }) .Run(); }) -UNARY_TEST(Asinh, { AsinhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) -UNARY_TEST(Atanh, { AtanhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Asinh, { + AsinhOp(this) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) +UNARY_TEST(Atanh, { + AtanhOp(this) + .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) // Tests for inverse trigonometric functions. 
UNARY_TEST(Acos, { AcosOp(this) - .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) .GpuError(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(1e-6).rel_err(10 * eps).build(); @@ -175,12 +251,44 @@ UNARY_TEST(Atan, { UNARY_TEST(Cosh, { CoshOp(this) - .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(3).build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(4).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v || + std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) .OutputRangeCheck( +[](NativeInputs in, NativeT actual) { return !(actual < 1); }) .Run(); }) -UNARY_TEST(Sinh, { SinhOp(this).Error(GetDefaultSpecGenerator()).Run(); }) +UNARY_TEST(Sinh, { + SinhOp(this) + .Error(GetDefaultSpecGenerator()) + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(3).build(); + } else if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(4).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .GpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + return GetDefaultSpecGenerator()(x); + }) + .Run(); +}) UNARY_TEST(Tanh, { TanhOp(this) .Error(GetDefaultSpecGenerator()) @@ -275,7 +383,14 @@ UNARY_TEST(ErfInv, { UNARY_TEST(Digamma, { DigammaOp(this) - .Error(+[](NativeT x) { + .CpuError(+[](NativeT x) { + if constexpr (std::is_same_v) { + return ErrorSpec::Builder().distance_err(1).build(); + } + NativeT eps = std::numeric_limits::epsilon(); + return 
ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); + }) + .GpuError(+[](NativeT x) { NativeT eps = std::numeric_limits::epsilon(); return ErrorSpec::Builder().abs_err(2e-5).rel_err(10 * eps).build(); })