Skip to content

Commit

Permalink
Add FP8 support to the exhaustive tests
Browse files Browse the repository at this point in the history
Adds new tests for two FP8 variants to both unary and binary exhaustive tests.

PiperOrigin-RevId: 663087125
  • Loading branch information
Gregory Pataky authored and Google-ML-Automation committed Sep 27, 2024
1 parent 0983168 commit f48e243
Show file tree
Hide file tree
Showing 13 changed files with 366 additions and 28 deletions.
1 change: 1 addition & 0 deletions xla/tests/exhaustive/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ exhaustive_xla_test(
":exhaustive_op_test_utils",
":exhaustive_unary_test_textual_hdrs",
"//xla:literal",
"//xla:types",
"//xla/client:xla_builder",
"//xla/client/lib:constants",
"//xla/client/lib:math",
Expand Down
114 changes: 110 additions & 4 deletions xla/tests/exhaustive/exhaustive_binary_test_definitions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,92 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Exhaustive test for binary operations on 8-bit floating point types
// (F8E4M3FN and F8E5M2).
//
// Both operands are packed into one 16-bit value, one byte per operand. The
// test parameter is a pair of (begin, end) describing the half-open range
// [begin, end) of packed values under test.
template <PrimitiveType T, bool kLeftToRightPacking = false>
class Exhaustive8BitBinaryTest
    : public ExhaustiveBinaryTest<T>,
      public ::testing::WithParamInterface<std::pair<int64_t, int64_t>> {
 public:
  // Returns the number of operand pairs to generate, i.e. the width of the
  // [begin, end) range supplied as the test parameter.
  int64_t GetInputSize() override {
    int64_t begin, end;
    std::tie(begin, end) = GetParam();
    return end - begin;
  }

  // Given a range of uint64_t representation, uses bits 7..0 and bits 15..8
  // for the values of src0 and src1 (see below for ordering) for the 8-bit
  // binary operation being tested, and generate the cartesian product of the
  // two sets as the two inputs for the test.
  //
  // If `kLeftToRightPacking == true`, then bits 15..8 are interpreted as src0
  // and bits 7..0 are interpreted as src1. If `kLeftToRightPacking == false`,
  // then bits 15..8 are interpreted as src1 and 7..0 are interpreted as src0.
  //
  // `input_literals` must already hold GetInputSize() elements per operand;
  // this is CHECKed before anything is written.
  void FillInput(std::array<Literal, 2>* input_literals) override {
    int64_t input_size = GetInputSize();
    CHECK_EQ(input_size, (*input_literals)[0].element_count());
    CHECK_EQ(input_size, (*input_literals)[1].element_count());

    int64_t begin, end;
    std::tie(begin, end) = GetParam();

    if (VLOG_IS_ON(2)) {
      // Extract the raw operand bytes of the range bounds. Truncating to
      // uint8_t keeps exactly the byte of interest.
      const auto low_byte = [](int64_t packed) {
        return static_cast<uint8_t>(packed);
      };
      const auto high_byte = [](int64_t packed) {
        return static_cast<uint8_t>(packed >> 8);
      };

      uint8_t left_begin, left_end, right_begin, right_end;
      if constexpr (kLeftToRightPacking) {
        left_begin = high_byte(begin);
        left_end = high_byte(end);
        right_begin = low_byte(begin);
        right_end = low_byte(end);
      } else {
        left_begin = low_byte(begin);
        left_end = low_byte(end);
        right_begin = high_byte(begin);
        right_end = high_byte(end);
      }

      LOG(INFO) << this->SuiteName() << this->TestName() << " Range:";
      // N.B.: Cast to u32 to avoid printing values as char.
      // The float rendering reinterprets each byte as NativeT so that the
      // logged value matches the type actually under test (not always E5M2).
      LOG(INFO) << "\tfrom=(" << static_cast<uint32_t>(left_begin) << ", "
                << static_cast<uint32_t>(right_begin) << "); hex=(" << std::hex
                << static_cast<uint32_t>(left_begin) << ", "
                << static_cast<uint32_t>(right_begin) << "); float=("
                << std::bit_cast<NativeT>(left_begin) << ", "
                << std::bit_cast<NativeT>(right_begin) << ") (inclusive)";
      LOG(INFO) << "\tto=(" << static_cast<uint32_t>(left_end) << ", "
                << static_cast<uint32_t>(right_end) << "); hex=(" << std::hex
                << static_cast<uint32_t>(left_end) << ", "
                << static_cast<uint32_t>(right_end) << "); float=("
                << std::bit_cast<NativeT>(left_end) << ", "
                << std::bit_cast<NativeT>(right_end) << ") (exclusive)";
      LOG(INFO) << "\ttotal values to test=" << (end - begin);
    }

    absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
    absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
    for (int64_t i = 0; i < input_size; i++) {
      const uint32_t input_val = static_cast<uint32_t>(i + begin);
      // Convert the packed bits to a pair of NativeT and replace known
      // incorrect input values with 0.
      //
      // In either case, we only use 16 bits out of the 64 bits possible.
      if constexpr (kLeftToRightPacking) {
        input_arr_0[i] = this->ConvertValue(input_val >> 8);
        input_arr_1[i] = this->ConvertValue(input_val);
      } else {
        input_arr_0[i] = this->ConvertValue(input_val);
        input_arr_1[i] = this->ConvertValue(input_val >> 8);
      }
    }
  }

 protected:
  using typename ExhaustiveBinaryTest<T>::NativeT;
};

// Exhaustive test for binary operations for 16 bit floating point types,
// including float16 and bfloat.
//
Expand Down Expand Up @@ -147,11 +233,29 @@ class Exhaustive32BitOrMoreBinaryTest
}
};

// Concrete test fixtures, one per floating point primitive type under test.
// The 8-bit and 16-bit fixtures use the default kLeftToRightPacking (false).
using ExhaustiveF8E4M3FNBinaryTest = Exhaustive8BitBinaryTest<F8E4M3FN>;
using ExhaustiveF8E5M2BinaryTest = Exhaustive8BitBinaryTest<F8E5M2>;
using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;

// Declares `test_name` as a parameterized test on ExhaustiveF8E4M3FNBinaryTest
// with the trailing arguments as the test body. When the backend defines
// XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN the macro must still exist (expanding
// to nothing) because BINARY_TEST invokes it unconditionally.
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN)
#define BINARY_TEST_F8E4M3FN(test_name, ...)          \
  XLA_TEST_P(ExhaustiveF8E4M3FNBinaryTest, test_name) \
  __VA_ARGS__
#else
#define BINARY_TEST_F8E4M3FN(test_name, ...)
#endif

// Declares `test_name` as a parameterized test on ExhaustiveF8E5M2BinaryTest
// with the trailing arguments as the test body. When the backend defines
// XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2 the macro must still exist (expanding to
// nothing) because BINARY_TEST invokes it unconditionally.
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2)
#define BINARY_TEST_F8E5M2(test_name, ...)          \
  XLA_TEST_P(ExhaustiveF8E5M2BinaryTest, test_name) \
  __VA_ARGS__
#else
#define BINARY_TEST_F8E5M2(test_name, ...)
#endif

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
#define BINARY_TEST_F16(test_name, ...) \
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
Expand Down Expand Up @@ -180,10 +284,12 @@ using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
#define BINARY_TEST_F64(test_name, ...)
#endif

#define BINARY_TEST(test_name, ...) \
BINARY_TEST_F16(test_name, __VA_ARGS__) \
BINARY_TEST_BF16(test_name, __VA_ARGS__) \
BINARY_TEST_F32(test_name, __VA_ARGS__) \
#define BINARY_TEST(test_name, ...) \
BINARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \
BINARY_TEST_F8E5M2(test_name, __VA_ARGS__) \
BINARY_TEST_F16(test_name, __VA_ARGS__) \
BINARY_TEST_BF16(test_name, __VA_ARGS__) \
BINARY_TEST_F32(test_name, __VA_ARGS__) \
BINARY_TEST_F64(test_name, __VA_ARGS__)

#define BINARY_TEST_COMPLEX(test_name, ...) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,30 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Instantiate the F8E4M3FN binary suite over the packed 16-bit operand ranges,
// unless the backend opts out; in that case tell GoogleTest that leaving the
// parameterized suite uninstantiated is intentional.
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN)
INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNBinaryTest,
::testing::ValuesIn(CreateExhaustiveU16Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest);
#endif

// Instantiate the F8E5M2 binary suite over the packed 16-bit operand ranges,
// unless the backend opts out; in that case tell GoogleTest that leaving the
// parameterized suite uninstantiated is intentional.
#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2)
INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2BinaryTest,
::testing::ValuesIn(CreateExhaustiveU16Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest);
#endif

#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
::testing::ValuesIn(CreateExhaustiveU32Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest);
#endif

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
::testing::ValuesIn(CreateExhaustiveU32Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest);
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest);
Expand Down
24 changes: 22 additions & 2 deletions xla/tests/exhaustive/exhaustive_binary_test_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,13 @@ bool PowCpuGpuF16Skip(NativeT left, NativeT right) {
BINARY_TEST(Pow, {
PowOp<kT>(this)
.CpuError(+[](NativeT left, NativeT right) {
if constexpr (std::is_same_v<NativeT, xla::half>) {
if constexpr (std::is_same_v<NativeT, tsl::float8_e4m3fn> ||
std::is_same_v<NativeT, tsl::float8_e5m2>) {
return ErrorSpec::Builder()
.distance_err(1)
.strict_signed_zeros()
.build();
} else if constexpr (std::is_same_v<NativeT, xla::half>) {
return ErrorSpec::Builder()
.strict_signed_zeros()
.skip_comparison(PowCpuGpuF16Skip(left, right))
Expand Down Expand Up @@ -357,7 +363,14 @@ bool Atan2CpuBf16F32Skip(NativeT left, NativeT right) {
BINARY_TEST(Atan2, {
Atan2Op<kT>(this)
.CpuError([](NativeT left, NativeT right) {
if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
if constexpr (std::is_same_v<NativeT, tsl::float8_e4m3fn> ||
std::is_same_v<NativeT, tsl::float8_e5m2>) {
return ErrorSpec::Builder()
.distance_err(1)
.strict_signed_zeros()
.build();

} else if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
return ErrorSpec::Builder()
.abs_err(
Atan2CpuBf16F32F64AbsErr<NativeT, NativeRefT>(left, right))
Expand All @@ -383,6 +396,13 @@ BINARY_TEST(Atan2, {
return ErrorSpec::Builder().strict_signed_zeros().build();
})
.GpuError(+[](NativeT, NativeT) {
if constexpr (std::is_same_v<NativeT, tsl::float8_e4m3fn> ||
std::is_same_v<NativeT, tsl::float8_e5m2>) {
return ErrorSpec::Builder()
.distance_err(1)
.strict_signed_zeros()
.build();
}
if constexpr (std::is_same_v<NativeT, xla::half> ||
std::is_same_v<NativeT, xla::bfloat16>) {
return ErrorSpec::Builder()
Expand Down
5 changes: 5 additions & 0 deletions xla/tests/exhaustive/exhaustive_op_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/executable_run_options.h"
#include "xla/fp_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/shaped_buffer.h"
Expand Down Expand Up @@ -864,11 +865,15 @@ template class ExhaustiveOpTestBase<F64, 1>;
template class ExhaustiveOpTestBase<F32, 1>;
template class ExhaustiveOpTestBase<F16, 1>;
template class ExhaustiveOpTestBase<BF16, 1>;
template class ExhaustiveOpTestBase<F8E5M2, 1>;
template class ExhaustiveOpTestBase<F8E4M3FN, 1>;

template class ExhaustiveOpTestBase<F64, 2>;
template class ExhaustiveOpTestBase<F32, 2>;
template class ExhaustiveOpTestBase<F16, 2>;
template class ExhaustiveOpTestBase<BF16, 2>;
template class ExhaustiveOpTestBase<F8E5M2, 2>;
template class ExhaustiveOpTestBase<F8E4M3FN, 2>;

} // namespace exhaustive_op_test
} // namespace xla
4 changes: 4 additions & 0 deletions xla/tests/exhaustive/exhaustive_op_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ template class ExhaustiveOpTestTraits<F64, 1>;
template class ExhaustiveOpTestTraits<F32, 1>;
template class ExhaustiveOpTestTraits<F16, 1>;
template class ExhaustiveOpTestTraits<BF16, 1>;
template class ExhaustiveOpTestTraits<F8E5M2, 1>;
template class ExhaustiveOpTestTraits<F8E4M3FN, 1>;

template class ExhaustiveOpTestTraits<F64, 2>;
template class ExhaustiveOpTestTraits<F32, 2>;
template class ExhaustiveOpTestTraits<F16, 2>;
template class ExhaustiveOpTestTraits<BF16, 2>;
template class ExhaustiveOpTestTraits<F8E5M2, 2>;
template class ExhaustiveOpTestTraits<F8E4M3FN, 2>;

bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); }

Expand Down
33 changes: 30 additions & 3 deletions xla/tests/exhaustive/exhaustive_op_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/client/xla_builder.h"
#include "xla/fp_util.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/tests/exhaustive/error_spec.h"
Expand All @@ -51,7 +50,7 @@ namespace exhaustive_op_test {

// The primitive type used to compute the reference output.
constexpr PrimitiveType Ref(PrimitiveType T) {
return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32;
return (!primitive_util::IsFloatingPointType(T) || T == F64) ? T : F32;
}

// The primitive type of the component of T. If T is not complex, then
Expand Down Expand Up @@ -195,6 +194,16 @@ inline ErrorSpec DefaultSpecGenerator<BF16, 1>(xla::bfloat16) {
return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}

// Default error spec for unary F8E4M3FN ops. The input value is ignored; the
// spec only enables strict_signed_zeros() and leaves every other setting at
// whatever ErrorSpec::Builder defaults to.
template <>
inline ErrorSpec DefaultSpecGenerator<F8E4M3FN, 1>(tsl::float8_e4m3fn) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

// Default error spec for unary F8E5M2 ops. The input value is ignored; the
// spec only enables strict_signed_zeros() and leaves every other setting at
// whatever ErrorSpec::Builder defaults to.
template <>
inline ErrorSpec DefaultSpecGenerator<F8E5M2, 1>(tsl::float8_e5m2) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
double atol = kDefaultAbsoluteToleranceSlackFactor *
Expand Down Expand Up @@ -231,6 +240,18 @@ inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}

// Default error spec for binary F8E4M3FN ops. Both operands are ignored; the
// spec only enables strict_signed_zeros() and leaves every other setting at
// whatever ErrorSpec::Builder defaults to.
template <>
inline ErrorSpec DefaultSpecGenerator<F8E4M3FN, 2>(tsl::float8_e4m3fn,
tsl::float8_e4m3fn) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

// Default error spec for binary F8E5M2 ops. Both operands are ignored; the
// spec only enables strict_signed_zeros() and leaves every other setting at
// whatever ErrorSpec::Builder defaults to.
template <>
inline ErrorSpec DefaultSpecGenerator<F8E5M2, 2>(tsl::float8_e5m2,
tsl::float8_e5m2) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

template <PrimitiveType T, size_t N>
typename ExhaustiveOpTestTraits<T, N>::ErrorSpecGen GetDefaultSpecGenerator() {
// Select overload by casting to fn ptr type.
Expand Down Expand Up @@ -782,7 +803,13 @@ CreateSubnormalExhaustiveRanges() {
return ret;
}

inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveF32Ranges() {
// Returns the single half-open range [0, 2^16) covering every 16-bit packed
// operand pair.
//
// The entire U16 range is small enough that we don't need to do any
// partitioning. The end bound is exclusive (consumers size their input as
// `end - begin`), so it must be one past the largest uint16_t value; otherwise
// the final packed value 0xFFFF would never be tested.
inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveU16Ranges() {
  constexpr int64_t kEndExclusive =
      int64_t{std::numeric_limits<uint16_t>::max()} + 1;
  return {{0, kEndExclusive}};
}

inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveU32Ranges() {
// We break up the 2^32-element space into small-ish chunks to keep peak
// memory usage low.
std::vector<std::pair<int64_t, int64_t>> result;
Expand Down
Loading

0 comments on commit f48e243

Please sign in to comment.