Skip to content

Commit

Permalink
PR #16775: Add test for EmitReducePrecisionIR
Browse files Browse the repository at this point in the history
Imported from GitHub PR #16775

I noticed that the `EmitReducePrecisionIR` function from `xla/service/elemental_ir_emitter.h` is not covered by unit tests.

Given its non-trivial logic, I believe it should be thoroughly tested, particularly for corner cases.

Changes in this PR:
- Declare `EmitReducePrecisionIR` function in `xla/service/elemental_ir_emitter.h`
- Add `EmitReducePrecisionIR_F16ToF8e5m2` test
- Add `EmitReducePrecisionIR_F16ToF8e4m3fn` test

Related PR:
- [PR-16585](#16585) Add support for float8_e4m3

Copybara import of the project:

--
5972205 by Alexander Pivovarov <[email protected]>:

Add test for EmitReducePrecisionIR

Merging this change closes #16775

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16775 from apivovarov:elemental_ir_emitter_test 5972205
PiperOrigin-RevId: 696730664
  • Loading branch information
apivovarov authored and Google-ML-Automation committed Nov 15, 2024
1 parent 3be0220 commit 4ec7e7d
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 2 deletions.
1 change: 1 addition & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4287,6 +4287,7 @@ xla_test(
"//xla:types",
"//xla/hlo/ir:hlo",
"//xla/service/llvm_ir:ir_array",
"//xla/service/llvm_ir:llvm_util",
"//xla/tests:hlo_test_base",
"//xla/tests:test_macros_header",
"//xla/tests:xla_internal_test_main",
Expand Down
4 changes: 2 additions & 2 deletions xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ using llvm_ir::SetToFirstInsertPoint;
using xla::float8_fnuz_ir_emitter::EmitF8fnuzToFloating;
using xla::float8_fnuz_ir_emitter::EmitFloatingToF8fnuz;

namespace {

absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits,
int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilderBase* b) {
Expand Down Expand Up @@ -231,6 +229,8 @@ absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
return result;
}

namespace {

template <int f8_exponent_bits>
llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits,
llvm::Value* f8_bits,
Expand Down
4 changes: 4 additions & 0 deletions xla/service/elemental_ir_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,10 @@ class ElementalIrEmitterForTests : public ElementalIrEmitter {
HloToElementGeneratorMap generator_map_;
};

absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits,
int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilderBase* b);

} // namespace xla

#endif // XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_
192 changes: 192 additions & 0 deletions xla/service/elemental_ir_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/llvm_ir/ir_array.h"
#include "xla/service/llvm_ir/llvm_util.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/test_macros.h"
Expand All @@ -48,6 +49,11 @@ namespace {

using std::nullopt;

struct EmitReducePrecisionIrTestCase {
float input;
std::string expected_res;
};

class ElementalIrEmitterExecutionTest : public HloTestBase {
protected:
void RunTest(const std::string& hlo_text, absl::Span<Literal* const> args) {
Expand Down Expand Up @@ -123,6 +129,192 @@ ENTRY main {
RunTest(hlo_text, {&lhs, &rhs});
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e5m2) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();
float qnan = std::numeric_limits<float>::quiet_NaN();
float snan = std::numeric_limits<float>::signaling_NaN();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-14, "half 0xH0400"},
{0.250, "half 0xH3400"},
{1.0, "half 0xH3C00"},
{0x1.2p0, "half 0xH3C00"},
{0x1.Cp15, "half 0xH7B00"},
{-0x1.Cp15, "half 0xHFB00"},
{0x1.Dp15, "half 0xH7B00"},
{0x1.Ep15, "half 0xH7C00"},
{0x1.0p16, "half 0xH7C00"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
{qnan, "half 0xH7E00"},
{-qnan, "half 0xHFE00"},
{snan, "half 0xH7F00"},
{-snan, "half 0xHFF00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/primitive_util::ExponentWidth(F8E5M2),
/*dest_mantissa_bits=*/primitive_util::SignificandWidth(F8E5M2) - 1,
/*quiet_nans=*/true, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e4m3) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();
float qnan = std::numeric_limits<float>::quiet_NaN();
float snan = std::numeric_limits<float>::signaling_NaN();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-6, "half 0xH2400"},
{0.125, "half 0xH3000"},
{1.0, "half 0xH3C00"},
{0x1.1p0, "half 0xH3C00"},
{0x1.Ep7, "half 0xH5B80"},
{-0x1.Ep7, "half 0xHDB80"},
{0x1.E8p7, "half 0xH5B80"},
{0x1.Fp7, "half 0xH7C00"},
{0x1.0p8, "half 0xH7C00"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
{qnan, "half 0xH7E00"},
{-qnan, "half 0xHFE00"},
{snan, "half 0xH7E00"},
{-snan, "half 0xHFE00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/4,
/*dest_mantissa_bits=*/3,
/*quiet_nans=*/true, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e3m4) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();
float qnan = std::numeric_limits<float>::quiet_NaN();
float snan = std::numeric_limits<float>::signaling_NaN();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-2, "half 0xH3400"},
{0.5, "half 0xH3800"},
{1.0, "half 0xH3C00"},
{0x1.08p0, "half 0xH3C00"},
{0x1.Fp3, "half 0xH4BC0"},
{-0x1.Fp3, "half 0xHCBC0"},
{0x1.F4p3, "half 0xH4BC0"},
{0x1.F8p3, "half 0xH7C00"},
{0x1.0p4, "half 0xH7C00"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
{qnan, "half 0xH7E00"},
{-qnan, "half 0xHFE00"},
{snan, "half 0xH7E00"},
{-snan, "half 0xHFE00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/3,
/*dest_mantissa_bits=*/4,
/*quiet_nans=*/true, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest,
EmitReducePrecisionIR_F16ToF8e4m3fn) {
llvm::LLVMContext llvm_context;
llvm::IRBuilder<> builder(llvm_context);
llvm::IRBuilderBase* b = &builder;
llvm::Type* f16_type = b->getHalfTy();

float inf = std::numeric_limits<float>::infinity();

EmitReducePrecisionIrTestCase test_cases[] = {
// clang-format off
{0.0, "half 0xH0000"},
{0x1.0p-6, "half 0xH2400"},
{0.125, "half 0xH3000"},
{1.0, "half 0xH3C00"},
{0x1.1p0, "half 0xH3C00"},
{0x1.Cp8, "half 0xH5F00"},
{-0x1.Cp8, "half 0xHDF00"},
{0x1.Dp8, "half 0xH5F00"},
{0x1.Ep8, "half 0xH5F80"},
{0x1.0p9, "half 0xH6000"},
{inf, "half 0xH7C00"},
{-inf, "half 0xHFC00"},
// clang-format on
};

for (auto tc : test_cases) {
llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input);

// Truncate the mantissa to 3 bits. ReducePrecision cannot deal with
// f8E4M3FN's NaN representations, so don't use ReducePrecision to handle
// exponent reduction.
absl::StatusOr<llvm::Value*> f16_reduced_statusor = EmitReducePrecisionIR(
/*src_ty=*/F16, c0,
/*dest_exponent_bits=*/5,
/*dest_mantissa_bits=*/3,
/*quiet_nans=*/false, b);
CHECK(f16_reduced_statusor.ok());
llvm::Value* f16_reduced = f16_reduced_statusor.value();

std::string res = llvm_ir::DumpToString(f16_reduced);
EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input;
}
}

XLA_TEST_F(ElementalIrEmitterExecutionTest, ScalarDotFusion) {
const char* hlo_text = R"(
HloModule ScalarDotFusion
Expand Down

0 comments on commit 4ec7e7d

Please sign in to comment.