diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index b04e4e554a8a8e..686a9bf5d0efd9 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -75,8 +75,6 @@ using llvm_ir::SetToFirstInsertPoint; using xla::float8_fnuz_ir_emitter::EmitF8fnuzToFloating; using xla::float8_fnuz_ir_emitter::EmitFloatingToF8fnuz; -namespace { - absl::StatusOr EmitReducePrecisionIR( PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits, int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b) { @@ -220,6 +218,8 @@ absl::StatusOr EmitReducePrecisionIR( return result; } +namespace { + template llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits, llvm::Value* f8_bits, diff --git a/xla/service/elemental_ir_emitter.h b/xla/service/elemental_ir_emitter.h index fe33977297572d..f15c60d93ccc74 100644 --- a/xla/service/elemental_ir_emitter.h +++ b/xla/service/elemental_ir_emitter.h @@ -351,6 +351,10 @@ class ElementalIrEmitterForTests : public ElementalIrEmitter { HloToElementGeneratorMap generator_map_; }; +absl::StatusOr EmitReducePrecisionIR( + PrimitiveType src_ty, llvm::Value* x, int64_t dest_exponent_bits, + int64_t dest_mantissa_bits, bool quiet_nans, llvm::IRBuilder<>* b); + } // namespace xla #endif // XLA_SERVICE_ELEMENTAL_IR_EMITTER_H_ diff --git a/xla/service/elemental_ir_emitter_test.cc b/xla/service/elemental_ir_emitter_test.cc index 60c4535909d158..b684878a4b2e98 100644 --- a/xla/service/elemental_ir_emitter_test.cc +++ b/xla/service/elemental_ir_emitter_test.cc @@ -36,6 +36,7 @@ limitations under the License. #include "xla/literal_util.h" #include "xla/service/hlo_module_config.h" #include "xla/service/llvm_ir/ir_array.h" +#include "xla/service/llvm_ir/llvm_util.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_macros.h" @@ -123,6 +124,204 @@ ENTRY main { RunTest(hlo_text, {&lhs, &rhs}); } +XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e5m2) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilder<>* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + float qnan = std::numeric_limits::quiet_NaN(); + float snan = std::numeric_limits::signaling_NaN(); + + struct TestCase { + float input; + std::string expected_res; + } test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-14, "half 0xH0400"}, + {0.250, "half 0xH3400"}, + {1.0, "half 0xH3C00"}, + {0x1.2p0, "half 0xH3C00"}, + {0x1.Cp15, "half 0xH7B00"}, + {-0x1.Cp15, "half 0xHFB00"}, + {0x1.Dp15, "half 0xH7B00"}, + {0x1.Ep15, "half 0xH7C00"}, + {0x1.0p16, "half 0xH7C00"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + {qnan, "half 0xH7E00"}, + {-qnan, "half 0xHFE00"}, + {snan, "half 0xH7F00"}, + {-snan, "half 0xHFF00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/primitive_util::ExponentWidth(F8E5M2), + /*dest_mantissa_bits=*/primitive_util::SignificandWidth(F8E5M2) - 1, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e4m3) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilder<>* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + float qnan = std::numeric_limits::quiet_NaN(); + float snan = std::numeric_limits::signaling_NaN(); + + struct TestCase { + float input; + std::string expected_res; + } test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-6, "half 0xH2400"}, + {0.125, "half 0xH3000"}, + {1.0, "half 0xH3C00"}, + {0x1.1p0, "half 0xH3C00"}, + {0x1.Ep7, "half 0xH5B80"}, + {-0x1.Ep7, "half 0xHDB80"}, + {0x1.E8p7, "half 0xH5B80"}, + {0x1.Fp7, "half 0xH7C00"}, + {0x1.0p8, "half 0xH7C00"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + {qnan, "half 0xH7E00"}, + {-qnan, "half 0xHFE00"}, + {snan, "half 0xH7E00"}, + {-snan, "half 0xHFE00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/4, + /*dest_mantissa_bits=*/3, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, EmitReducePrecisionIR_F16ToF8e3m4) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilder<>* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + float qnan = std::numeric_limits::quiet_NaN(); + float snan = std::numeric_limits::signaling_NaN(); + + struct TestCase { + float input; + std::string expected_res; + } test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-2, "half 0xH3400"}, + {0.5, "half 0xH3800"}, + {1.0, "half 0xH3C00"}, + {0x1.08p0, "half 0xH3C00"}, + {0x1.Fp3, "half 0xH4BC0"}, + {-0x1.Fp3, "half 0xHCBC0"}, + {0x1.F4p3, "half 0xH4BC0"}, + {0x1.F8p3, "half 0xH7C00"}, + {0x1.0p4, "half 0xH7C00"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + {qnan, "half 0xH7E00"}, + {-qnan, "half 0xHFE00"}, + {snan, "half 0xH7E00"}, + {-snan, "half 0xHFE00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/3, + /*dest_mantissa_bits=*/4, + /*quiet_nans=*/true, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + +XLA_TEST_F(ElementalIrEmitterExecutionTest, + EmitReducePrecisionIR_F16ToF8e4m3fn) { + llvm::LLVMContext llvm_context; + llvm::IRBuilder<> builder(llvm_context); + llvm::IRBuilder<>* b = &builder; + llvm::Type* f16_type = b->getHalfTy(); + + float inf = std::numeric_limits::infinity(); + + struct TestCase { + float input; + std::string expected_res; + } test_cases[] = { + // clang-format off + {0.0, "half 0xH0000"}, + {0x1.0p-6, "half 0xH2400"}, + {0.125, "half 0xH3000"}, + {1.0, "half 0xH3C00"}, + {0x1.1p0, "half 0xH3C00"}, + {0x1.Cp8, "half 0xH5F00"}, + {-0x1.Cp8, "half 0xHDF00"}, + {0x1.Dp8, "half 0xH5F00"}, + {0x1.Ep8, "half 0xH5F80"}, + {0x1.0p9, "half 0xH6000"}, + {inf, "half 0xH7C00"}, + {-inf, "half 0xHFC00"}, + // clang-format on + }; + + for (auto tc : test_cases) { + llvm::Value* c0 = llvm::ConstantFP::get(f16_type, tc.input); + + // Truncate the mantissa to 3 bits. ReducePrecision cannot deal with + // f8E4M3FN's NaN representations, so don't use ReducePrecision to handle + // exponent reduction. + absl::StatusOr f16_reduced_statusor = EmitReducePrecisionIR( + /*src_ty=*/F16, c0, + /*dest_exponent_bits=*/5, + /*dest_mantissa_bits=*/3, + /*quiet_nans=*/false, b); + CHECK(f16_reduced_statusor.ok()); + llvm::Value* f16_reduced = f16_reduced_statusor.value(); + + std::string res = llvm_ir::DumpToString(f16_reduced); + EXPECT_EQ(res, tc.expected_res) << "Wrong result for input " << tc.input; + } +} + XLA_TEST_F(ElementalIrEmitterExecutionTest, ScalarDotFusion) { const char* hlo_text = R"( HloModule ScalarDotFusion