Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FP8 support to the exhaustive tests #17720

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/tests/exhaustive/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ exhaustive_xla_test(
":exhaustive_op_test_utils",
":exhaustive_unary_test_textual_hdrs",
"//xla:literal",
"//xla:types",
"//xla/client:xla_builder",
"//xla/client/lib:constants",
"//xla/client/lib:math",
Expand Down
114 changes: 110 additions & 4 deletions xla/tests/exhaustive/exhaustive_binary_test_definitions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,92 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Exhaustive test for binary operations for 8-bit floating point types,
// including float16 and bfloat.
//
// Test parameter is a pair of (begin, end) for range under test.
template <PrimitiveType T, bool kLeftToRightPacking = false>
class Exhaustive8BitBinaryTest
: public ExhaustiveBinaryTest<T>,
public ::testing::WithParamInterface<std::pair<int64_t, int64_t>> {
public:
int64_t GetInputSize() override {
int64_t begin, end;
std::tie(begin, end) = GetParam();
return end - begin;
}

// Given a range of uint64_t representation, uses bits 7..0 and bits 15..8
// for the values of src0 and src1 (see below for ordering) for the 8-bit
// binary operation being tested, and generate the cartesian product of the
// two sets as the two inputs for the test.
//
// If `kLeftToRightPacking == true`, then bits 15..8 are interpreted as src0
// and bits 7..0 are interpreted as src1. If `kLeftToRightPacking == false`,
// then bits 15..8 are interpreted as src1 and 7..0 are interpreted as src0.
void FillInput(std::array<Literal, 2>* input_literals) override {
int64_t input_size = GetInputSize();
CHECK_EQ(input_size, (*input_literals)[0].element_count());
CHECK_EQ(input_size, (*input_literals)[1].element_count());

int64_t begin, end;
std::tie(begin, end) = GetParam();

if (VLOG_IS_ON(2)) {
uint8_t left_begin, left_end, right_begin, right_end;
if constexpr (kLeftToRightPacking) {
left_begin = std::bit_cast<uint8_t>(static_cast<int8_t>(begin >> 8));
left_end = std::bit_cast<uint8_t>(static_cast<int8_t>(end >> 8));
right_begin = std::bit_cast<uint8_t>(static_cast<int8_t>(begin));
right_end = std::bit_cast<uint8_t>(static_cast<int8_t>(end));
} else {
left_begin = std::bit_cast<uint8_t>(static_cast<int8_t>(begin));
left_end = std::bit_cast<uint8_t>(static_cast<int8_t>(end));
right_begin = std::bit_cast<uint8_t>(static_cast<int8_t>(begin >> 8));
right_end = std::bit_cast<uint8_t>(static_cast<int8_t>(end >> 8));
}

LOG(INFO) << this->SuiteName() << this->TestName() << " Range:";
// N.B.: Cast to u32 to avoid printing values as char.
LOG(INFO) << "\tfrom=(" << static_cast<uint32_t>(left_begin) << ", "
<< static_cast<uint32_t>(right_begin) << "); hex=(" << std::hex
<< static_cast<uint32_t>(left_begin) << ", "
<< static_cast<uint32_t>(right_begin) << "); float=("
<< std::bit_cast<tsl::float8_e5m2>(left_begin) << ", "
<< std::bit_cast<tsl::float8_e5m2>(right_begin)
<< ") (inclusive)";
LOG(INFO) << "\tto=(" << static_cast<uint32_t>(left_end) << ", "
<< static_cast<uint32_t>(right_end) << "); hex=(" << std::hex
<< static_cast<uint32_t>(left_end) << ", "
<< static_cast<uint32_t>(right_end) << "); float=("
<< std::bit_cast<tsl::float8_e5m2>(left_end) << ", "
<< std::bit_cast<tsl::float8_e5m2>(right_end)
<< ") (exclusive)";
LOG(INFO) << "\ttotal values to test=" << (end - begin);
}

absl::Span<NativeT> input_arr_0 = (*input_literals)[0].data<NativeT>();
absl::Span<NativeT> input_arr_1 = (*input_literals)[1].data<NativeT>();
for (int64_t i = 0; i < input_size; i++) {
uint32_t input_val = i + begin;
// Convert the packed bits to a pair of NativeT and replace known
// incorrect input values with 0.
//
// In either case, we only use 16 bits out of the 64 bits possible.
if constexpr (kLeftToRightPacking) {
input_arr_0[i] = this->ConvertValue(input_val >> 8);
input_arr_1[i] = this->ConvertValue(input_val);
} else {
input_arr_0[i] = this->ConvertValue(input_val);
input_arr_1[i] = this->ConvertValue(input_val >> 8);
}
}
}

protected:
using typename ExhaustiveBinaryTest<T>::NativeT;
};

// Exhaustive test for binary operations for 16 bit floating point types,
// including float16 and bfloat.
//
Expand Down Expand Up @@ -147,11 +233,29 @@ class Exhaustive32BitOrMoreBinaryTest
}
};

using ExhaustiveF8E4M3FNBinaryTest = Exhaustive8BitBinaryTest<F8E4M3FN>;
using ExhaustiveF8E5M2BinaryTest = Exhaustive8BitBinaryTest<F8E5M2>;
using ExhaustiveF16BinaryTest = Exhaustive16BitBinaryTest<F16>;
using ExhaustiveBF16BinaryTest = Exhaustive16BitBinaryTest<BF16>;
using ExhaustiveF32BinaryTest = Exhaustive32BitOrMoreBinaryTest<F32>;
using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN)
#define BINARY_TEST_F8E4M3FN(test_name, ...) \
XLA_TEST_P(ExhaustiveF8E4M3FNBinaryTest, test_name) \
__VA_ARGS__
#else
#define BINARY_TEST_E4M3FN(test_name, ...)
#endif

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2)
#define BINARY_TEST_F8E5M2(test_name, ...) \
XLA_TEST_P(ExhaustiveF8E5M2BinaryTest, test_name) \
__VA_ARGS__
#else
#define BINARY_TEST_E5M2(test_name, ...)
#endif

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
#define BINARY_TEST_F16(test_name, ...) \
XLA_TEST_P(ExhaustiveF16BinaryTest, test_name) \
Expand Down Expand Up @@ -180,10 +284,12 @@ using ExhaustiveF64BinaryTest = Exhaustive32BitOrMoreBinaryTest<F64>;
#define BINARY_TEST_F64(test_name, ...)
#endif

#define BINARY_TEST(test_name, ...) \
BINARY_TEST_F16(test_name, __VA_ARGS__) \
BINARY_TEST_BF16(test_name, __VA_ARGS__) \
BINARY_TEST_F32(test_name, __VA_ARGS__) \
#define BINARY_TEST(test_name, ...) \
BINARY_TEST_F8E4M3FN(test_name, __VA_ARGS__) \
BINARY_TEST_F8E5M2(test_name, __VA_ARGS__) \
BINARY_TEST_F16(test_name, __VA_ARGS__) \
BINARY_TEST_BF16(test_name, __VA_ARGS__) \
BINARY_TEST_F32(test_name, __VA_ARGS__) \
BINARY_TEST_F64(test_name, __VA_ARGS__)

#define BINARY_TEST_COMPLEX(test_name, ...) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,30 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E4M3FN)
INSTANTIATE_TEST_SUITE_P(F8E4M3FN, ExhaustiveF8E4M3FNBinaryTest,
::testing::ValuesIn(CreateExhaustiveU16Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest);
#endif

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_F8E5M2)
INSTANTIATE_TEST_SUITE_P(F8E5M2, ExhaustiveF8E5M2BinaryTest,
::testing::ValuesIn(CreateExhaustiveU16Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest);
#endif

#if defined(XLA_BACKEND_SUPPORTS_BFLOAT16)
INSTANTIATE_TEST_SUITE_P(BF16, ExhaustiveBF16BinaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
::testing::ValuesIn(CreateExhaustiveU32Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest);
#endif

#if !defined(XLA_BACKEND_DOES_NOT_SUPPORT_FLOAT16)
INSTANTIATE_TEST_SUITE_P(F16, ExhaustiveF16BinaryTest,
::testing::ValuesIn(CreateExhaustiveF32Ranges()));
::testing::ValuesIn(CreateExhaustiveU32Ranges()));
#else
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest);
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E4M3FNBinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF8E5M2BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveBF16BinaryTest);

GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(ExhaustiveF16BinaryTest);
Expand Down
24 changes: 22 additions & 2 deletions xla/tests/exhaustive/exhaustive_binary_test_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,13 @@ bool PowCpuGpuF16Skip(NativeT left, NativeT right) {
BINARY_TEST(Pow, {
PowOp<kT>(this)
.CpuError(+[](NativeT left, NativeT right) {
if constexpr (std::is_same_v<NativeT, xla::half>) {
if constexpr (std::is_same_v<NativeT, tsl::float8_e4m3fn> ||
std::is_same_v<NativeT, tsl::float8_e5m2>) {
return ErrorSpec::Builder()
.distance_err(1)
.strict_signed_zeros()
.build();
} else if constexpr (std::is_same_v<NativeT, xla::half>) {
return ErrorSpec::Builder()
.strict_signed_zeros()
.skip_comparison(PowCpuGpuF16Skip(left, right))
Expand Down Expand Up @@ -357,7 +363,14 @@ bool Atan2CpuBf16F32Skip(NativeT left, NativeT right) {
BINARY_TEST(Atan2, {
Atan2Op<kT>(this)
.CpuError([](NativeT left, NativeT right) {
if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
if constexpr (std::is_same_v<NativeT, tsl::float8_e4m3fn> ||
std::is_same_v<NativeT, tsl::float8_e5m2>) {
return ErrorSpec::Builder()
.distance_err(1)
.strict_signed_zeros()
.build();

} else if constexpr (std::is_same_v<NativeT, xla::bfloat16>) {
return ErrorSpec::Builder()
.abs_err(
Atan2CpuBf16F32F64AbsErr<NativeT, NativeRefT>(left, right))
Expand All @@ -383,6 +396,13 @@ BINARY_TEST(Atan2, {
return ErrorSpec::Builder().strict_signed_zeros().build();
})
.GpuError(+[](NativeT, NativeT) {
if constexpr (std::is_same_v<NativeT, tsl::float8_e4m3fn> ||
std::is_same_v<NativeT, tsl::float8_e5m2>) {
return ErrorSpec::Builder()
.distance_err(1)
.strict_signed_zeros()
.build();
}
if constexpr (std::is_same_v<NativeT, xla::half> ||
std::is_same_v<NativeT, xla::bfloat16>) {
return ErrorSpec::Builder()
Expand Down
5 changes: 5 additions & 0 deletions xla/tests/exhaustive/exhaustive_op_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limitations under the License.
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/executable_run_options.h"
#include "xla/fp_util.h"
#include "xla/literal.h"
#include "xla/literal_util.h"
#include "xla/service/shaped_buffer.h"
Expand Down Expand Up @@ -864,11 +865,15 @@ template class ExhaustiveOpTestBase<F64, 1>;
template class ExhaustiveOpTestBase<F32, 1>;
template class ExhaustiveOpTestBase<F16, 1>;
template class ExhaustiveOpTestBase<BF16, 1>;
template class ExhaustiveOpTestBase<F8E5M2, 1>;
template class ExhaustiveOpTestBase<F8E4M3FN, 1>;

template class ExhaustiveOpTestBase<F64, 2>;
template class ExhaustiveOpTestBase<F32, 2>;
template class ExhaustiveOpTestBase<F16, 2>;
template class ExhaustiveOpTestBase<BF16, 2>;
template class ExhaustiveOpTestBase<F8E5M2, 2>;
template class ExhaustiveOpTestBase<F8E4M3FN, 2>;

} // namespace exhaustive_op_test
} // namespace xla
4 changes: 4 additions & 0 deletions xla/tests/exhaustive/exhaustive_op_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,15 @@ template class ExhaustiveOpTestTraits<F64, 1>;
template class ExhaustiveOpTestTraits<F32, 1>;
template class ExhaustiveOpTestTraits<F16, 1>;
template class ExhaustiveOpTestTraits<BF16, 1>;
template class ExhaustiveOpTestTraits<F8E5M2, 1>;
template class ExhaustiveOpTestTraits<F8E4M3FN, 1>;

template class ExhaustiveOpTestTraits<F64, 2>;
template class ExhaustiveOpTestTraits<F32, 2>;
template class ExhaustiveOpTestTraits<F16, 2>;
template class ExhaustiveOpTestTraits<BF16, 2>;
template class ExhaustiveOpTestTraits<F8E5M2, 2>;
template class ExhaustiveOpTestTraits<F8E4M3FN, 2>;

bool IsSubnormalReal(xla::complex64 value) { return IsSubnormal(value.real()); }

Expand Down
33 changes: 30 additions & 3 deletions xla/tests/exhaustive/exhaustive_op_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/client/xla_builder.h"
#include "xla/fp_util.h"
#include "xla/literal.h"
#include "xla/primitive_util.h"
#include "xla/tests/exhaustive/error_spec.h"
Expand All @@ -51,7 +50,7 @@ namespace exhaustive_op_test {

// The primitive type used to compute the reference output.
constexpr PrimitiveType Ref(PrimitiveType T) {
return !primitive_util::IsFloatingPointType(T) || T == F64 ? T : F32;
return (!primitive_util::IsFloatingPointType(T) || T == F64) ? T : F32;
}

// The primitive type of the component of T. If T is not complex, then
Expand Down Expand Up @@ -195,6 +194,16 @@ inline ErrorSpec DefaultSpecGenerator<BF16, 1>(xla::bfloat16) {
return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}

template <>
inline ErrorSpec DefaultSpecGenerator<F8E4M3FN, 1>(tsl::float8_e4m3fn) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

template <>
inline ErrorSpec DefaultSpecGenerator<F8E5M2, 1>(tsl::float8_e5m2) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

template <>
inline ErrorSpec DefaultSpecGenerator<F64, 2>(double, double) {
double atol = kDefaultAbsoluteToleranceSlackFactor *
Expand Down Expand Up @@ -231,6 +240,18 @@ inline ErrorSpec DefaultSpecGenerator<BF16, 2>(bfloat16, bfloat16) {
return ErrorSpec::Builder().abs_err(atol).rel_err(rtol).build();
}

template <>
inline ErrorSpec DefaultSpecGenerator<F8E4M3FN, 2>(tsl::float8_e4m3fn,
tsl::float8_e4m3fn) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

template <>
inline ErrorSpec DefaultSpecGenerator<F8E5M2, 2>(tsl::float8_e5m2,
tsl::float8_e5m2) {
return ErrorSpec::Builder().strict_signed_zeros().build();
}

template <PrimitiveType T, size_t N>
typename ExhaustiveOpTestTraits<T, N>::ErrorSpecGen GetDefaultSpecGenerator() {
// Select overload by casting to fn ptr type.
Expand Down Expand Up @@ -782,7 +803,13 @@ CreateSubnormalExhaustiveRanges() {
return ret;
}

inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveF32Ranges() {
inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveU16Ranges() {
// The entire U16 range is small enough that we don't need to do any
// partitioning.
return {{0, std::numeric_limits<uint16_t>::max()}};
}

inline std::vector<std::pair<int64_t, int64_t>> CreateExhaustiveU32Ranges() {
// We break up the 2^32-element space into small-ish chunks to keep peak
// memory usage low.
std::vector<std::pair<int64_t, int64_t>> result;
Expand Down
Loading
Loading