Add support for float8_e4m3 and float8_e3m4 types #16585
base: main
third_party/stablehlo/workspace.bzl
Outdated
# LINT.ThenChange(Google-internal path)

tf_http_archive(
    name = "stablehlo",
    sha256 = STABLEHLO_SHA256,
    strip_prefix = "stablehlo-{commit}".format(commit = STABLEHLO_COMMIT),
-   urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
+   urls = tf_mirror_urls("https://github.com/apivovarov/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
Probably not intended to fetch it from your repo.
If it's intended that you also update the deps in the same PR, could you split it in separate PRs?
This PR is pending the merge of StableHLO openxla/stablehlo#2482 Add f8E4M3 and f8E3M4 types support (in Review).
Hi Alexander, "Integrate StableHLO at openxla/stablehlo@4f31b2e7" was merged to XLA main today. It includes float8_e4m3 type support. My temporary change in third_party/stablehlo/workspace.bzl was removed from this PR. @mooskagh
### Summary

This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point types to StableHLO. Feedback welcome; see [RFC: Float8E4M3 and Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md) for more details.

### References and Links

- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- [RFC: FP8 in StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
- StableHLO [PR-2482](#2482) Add f8E4M3 and f8E3M4 types support
- [Amazon EC2 Trn1 Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support for float8_e4m3
This PR adds f8E4M3 and f8E3M4 types support. Both types follow the IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantissa),
  including the implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has positive and negative zero
- Has positive and negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
  including the implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has positive and negative zero
- Has positive and negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](#2486) [RFC] Add f8E4M3 and f8E3M4 types support
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support for float8_e4m3
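The tables above can be spot-checked with a small pure-Python decoder for IEEE-style 8-bit floats (illustrative only; this is not the ml_dtypes implementation):

```python
def decode_fp8(byte, exp_bits, mant_bits):
    """Decode an IEEE-754-style 8-bit float (with Inf/NaN) to a Python float."""
    bias = (1 << (exp_bits - 1)) - 1
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> mant_bits) & ((1 << exp_bits) - 1)
    mant = byte & ((1 << mant_bits) - 1)
    if exp == (1 << exp_bits) - 1:            # all-ones exponent: Inf or NaN
        return sign * float("inf") if mant == 0 else float("nan")
    if exp == 0:                              # subnormal: no implicit leading 1
        return sign * mant * 2.0 ** (1 - bias - mant_bits)
    return sign * (1 + mant / (1 << mant_bits)) * 2.0 ** (exp - bias)

# f8E4M3: S.1110.111 is the max normal number, 240
assert decode_fp8(0b0_1110_111, exp_bits=4, mant_bits=3) == 240.0
# f8E3M4: S.110.1111 is the max normal number, 15.5
assert decode_fp8(0b0_110_1111, exp_bits=3, mant_bits=4) == 15.5
# f8E4M3 min subnormal: S.0000.001 = 2^-9
assert decode_fp8(0b0_0000_001, exp_bits=4, mant_bits=3) == 2.0 ** -9
```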
I think Reed is the best person to review; I think this will require a patch on our end due to
Hi Reed, this PR introduces support for the new f8E4M3 type, which adheres to the IEEE 754 convention. I've already added this type to the LLVM, MLIR, ml_dtypes, and StableHLO projects. This PR extends the support to XLA and includes a reference implementation for the CPU compiler. Could you please help review this PR?
Thanks for the change! Sorry for the delay in reviewing.
Normally it's better to minimize the size of PRs, but I would prefer E3M4 also be added in this PR: it touches most of the same files in exactly the same way as E4M3, which makes it easy to review both dtypes at once.
But if adding E3M4 to the same PR is inconvenient for you, I'm fine with this being done as a separate, future PR.
xla/client/lib/math.cc
Outdated
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
 F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
Since the type list here is growing, I would avoid hardcoding the list of FP8 types. One way to avoid this is to have DoWithUpcastToF32 take a should_upcast bool instead of the existing upcast_types list. Then you can pass something like should_upcast = BitWidth(b.GetShape(x).element_type) <= 16.
There are a lot of places where we list out all FP8 types, but every place we can remove listing these out will help when more types are added :)
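The idea can be sketched outside of XLA: replace membership in a hand-maintained type list with a predicate on bit width (the type names and widths below are illustrative stand-ins, not XLA's PrimitiveType API):

```python
FP8_TYPES = {"F8E5M2", "F8E4M3", "F8E4M3FN", "F8E4M3B11FNUZ", "F8E5M2FNUZ", "F8E4M3FNUZ"}
BIT_WIDTH = {"BF16": 16, "F16": 16, "F32": 32, **{t: 8 for t in FP8_TYPES}}

def should_upcast_by_list(ty):
    # Fragile: must be edited every time a new FP8 type is added.
    return ty in {"BF16", "F16"} | FP8_TYPES

def should_upcast_by_width(ty):
    # Robust: any current or future sub-32-bit type qualifies automatically.
    return BIT_WIDTH[ty] <= 16

# The two predicates agree on every known type, but only the second one
# keeps working when a new narrow float type is introduced.
assert all(should_upcast_by_list(t) == should_upcast_by_width(t) for t in BIT_WIDTH)
```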
Opened PR #17130 - Add default upcasting behavior to DoWithUpcastToF32
xla/fp_util_test.cc
Outdated
@@ -111,6 +111,59 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
                         0x1.fffffffffffffp-127,
                         0x1.aaaaaaaaaaaaap-127));

TEST(FPDistanceTest, F8E4M3Distance) {
Since this test is almost identical to F8E4M3FNDistance, can you merge them to avoid duplication?
One way is to create a type-parameterized test with TYPED_TEST_P. Another way would be to have a for-loop over the primitive types F8E4M3 and F8E4M3FN, and in the body use primitive_util::PrimitiveTypeSwitch with a lambda that does the CalculateDistanceInFloats calls. See here for an example of how to use PrimitiveTypeSwitch.
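In Python terms, the suggested dedup amounts to collapsing two near-identical tests into one parameterized body (unittest's subTest standing in for googletest's TYPED_TEST_P; the distance function and per-dtype values here are simplified stand-ins):

```python
import unittest

def distance_in_floats(a, b, ulp):
    """Stand-in for CalculateDistanceInFloats: number of ulp-sized steps from a to b."""
    return round(abs(b - a) / ulp)

# Hypothetical per-dtype data: the spacing near zero happens to be the same
# for f8e4m3 and f8e4m3fn (both have 3 mantissa bits and min exponent -6).
ULP_NEAR_ZERO = {"f8e4m3": 2.0 ** -9, "f8e4m3fn": 2.0 ** -9}

class FP8DistanceTest(unittest.TestCase):
    def test_zero_to_second_subnormal(self):
        # One parameterized body instead of a near-identical copy per dtype.
        for dtype, ulp in ULP_NEAR_ZERO.items():
            with self.subTest(dtype=dtype):
                self.assertEqual(distance_in_floats(0.0, 2 * ulp, ulp), 2)
```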
Opened PR #17135 - Add TypeParam to FP8E4M3DistanceTest
} else if constexpr (std::is_integral_v<ElementwiseT>) {
  if constexpr (std::is_signed_v<ElementwiseT>) {
    if (rhs_el < static_cast<ElementwiseT>(0)) {
      ElementWiseBinaryOp(
Don't change the formatting here.
restored
The GitHub workflow runs pipx to check clang formatting. Opened PR #17234 - Format hlo_evaluator_typed_visitor.h
xla/literal_comparison_test.cc
Outdated
@@ -25,6 +25,63 @@ limitations under the License.
namespace xla {
namespace {

TEST(LiteralComparisonTest, F8E4M3CompareNear_Equal) {
I think all the FP8 tests are duplicated for each FP8 type. Can you use TYPED_TEST_P to deduplicate them?
Opened #17133 - Dedup LiteralComparisonTests
xla/literal_test.cc
Outdated
@@ -644,15 +648,19 @@ TEST_F(LiteralUtilTest, IsAll) {
  // 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false
  EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e5m2>({q16}).IsAll(9));

- tsl::float8_e4m3fn r16(9);  // Exactly representable in e4m3
+ tsl::float8_e4m3 e4m3(9);   // Exactly representable in e4m3
Not sure why there is a convention of using subsequent single letters to name the FP8 values (q16, r16, s16, etc) but you should follow it or change the convention. Either name this q16, renaming the above e5m2 to p16, or rename all the other FP8 variable names to something more descriptive, as you did for this one.
renamed to q16
case xla::F8E4M3:
  return absl::UnimplementedError("F8E4M3 not implemented");
To avoid having to modify this every time a new FP8 type is added, remove all these FP8 cases and check IsF8Type(literal.shape().element_type()) before the switch statement.
Opened PR #17170 Code dedup in execution_trace_utils LiteralToValue
Just FYI, this was an intentional choice. Missing switch cases are a compiler error, so having a switch without a default case is preferable when possible. No big deal though.
Using single-case switch statements can make it easier for the compiler to detect potential errors in the code.
Opened PR #17279 - Use switch case without default in LiteralToValue
@@ -500,6 +500,36 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) {
  EXPECT_EQ(root->operand(0)->shape().element_type(), U16);
}

TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e4m3) {
Merge this with the existing test ResolveIfUnsupportedF8e5m2, either by looping over values (F8E4M3, F8E5M2) or by using a value-parameterized test with TEST_P.
Opened PR #17177 - Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8
xla/tests/constants_test.cc
Outdated
XlaBuilder builder(TestName());
auto c = ConstantR1<tsl::float8_e4m3>(&builder, constant);
// F8 outputs are not yet supported so convert to F32
This comment is actually no longer true. We should change the other tests with this comment as well. The test OneCellF8e5m2fnuz does have an FP8 output, so you can use that as an example in modifying this test.
If you want, you can also change the two existing tests with the comment to have FP8 outputs as well.
Opened PR #17182 - Parametrize ConstantsFloatTest OneCellFloat
xla/service/elemental_ir_emitter.cc
Outdated
f16_reduced =
    b->CreateOr(b->CreateAnd(f16_reduced, i16_const(0x9FFF)),
                b->CreateLShr(b->CreateAnd(f16_reduced, i16_const(0x4000)),
                              i16_const(1)));
This is effectively subtracting 8 from the exponent I think, as the difference in exponent bias is 8. Why not do that subtraction directly?
In contrast to the EmitF16ToF8e4m3fn function, the EmitF16ToF8e4m3 function does not include special code to handle Inf and NaN cases (the code around constexpr int max_finite_value = 0x5F7F;).
If I use the minus-8 approach, several tests in //xla/tests:convert_test_cpu FAIL, e.g.:
- inf -> -1.0
- nan -> -1.5
Example: input is inf. EmitReducePrecisionIR returns
x = 0.11111.0000000000 (0x7C00)

Option 1: minus 8
x -= 0.01000.0000000000
// x is 0.10111.0000000000
// Shift to convert to F8: x is 1.0111.000
// f8e4m3 result is -1.0 (Wrong)

Option 2: right-shift the E5 exponent's leftmost bit
x = (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1)
// x is 0.01111.0000000000
// Shift to convert to F8: x is 0.1111.000
// f8e4m3 result is inf (Correct)
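The two options can be compared on raw F16 bit patterns with a short Python sketch (plain bit arithmetic only; the helper names are mine, and the final shift is a simplification of the emitter's conversion step):

```python
F16_MANT_BITS = 10

def exp_minus_8(x):
    """Option 1: subtract the bias difference (8) from the stored F16 exponent."""
    return x - (8 << F16_MANT_BITS)

def fold_msb(x):
    """Option 2: right-shift the E5 exponent's leftmost bit into the next position."""
    return (x & 0x9FFF) | ((x & 0x4000) >> 1)

def to_f8_bits(x):
    """Drop the 7 low mantissa bits to land on the 8-bit S.EEEE.MMM layout."""
    return (x >> 7) & 0xFF

F16_INF = 0x7C00       # 0.11111.0000000000
F8E4M3_INF = 0x78      # 0.1111.000

# Option 2 preserves Inf; option 1 corrupts it into 1.0111.000 (the -1.0 pattern).
assert to_f8_bits(fold_msb(F16_INF)) == F8E4M3_INF
assert to_f8_bits(exp_minus_8(F16_INF)) == 0xB8
# For an ordinary finite value like 4.0 (F16 bits 0x4400) the two options agree.
assert to_f8_bits(exp_minus_8(0x4400)) == to_f8_bits(fold_msb(0x4400))
```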
xla/service/elemental_ir_emitter.cc
Outdated
// Set output exponent to 11111 if input exponent is 1111 (Inf or NaN)
// 0.1111.000 is 0x78
// 0.11111.000000000000 is 0x7C00
0.11111.000000000000 has 12 zeros at the end, when it should have 10.
fixed
@@ -220,6 +220,59 @@ absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
  return result;
}

llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits,
Also, I think you might have to modify expand_float_ops.cc, which is used by the new MLIR emitters that replace the existing emitters on GPUs. But I'm not very familiar with these new emitters. @jreiffers, can you advise on what needs to be done here?
We have generic support for all possible float conversions, but the emitted code might not be optimal, so it should be considered a fallback. I didn't look at these conversion routines here in detail, but if they're better, it would make sense to port them to the MLIR pipeline.
I updated xla/service/gpu/fusions/transforms/expand_float_ops.cc and added f8E4M3 cases to:
- IsInf()
- IsNaN()
- RewriteF8Cst::matchAndRewrite() // If we're comparing to +-0, compare the absolute values.

expand_float_ops.cc includes a specialized function for the f8e5m2 type, EmitF16ToF8e5m2(). This is because F16 is technically f16e5m10; the two types are similar, with the primary difference being that the mantissa in f8e5m2 is truncated to 2 bits.
f8E4M3 has a different number of exponent and mantissa bits, so its conversion can be efficiently handled by the generic support for all possible float conversions.
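The f8e5m2/F16 relationship can be shown directly on bit patterns: taking the top byte of an F16 encoding yields the f8e5m2 encoding of the same sign and exponent. A Python sketch (round-toward-zero truncation for clarity; the real conversion rounds to nearest-even):

```python
import struct

def f16_bits(x):
    """IEEE binary16 bit pattern of a Python float, via struct's half-float format."""
    return struct.unpack("<H", struct.pack("<e", x))[0]

def f16_to_f8e5m2_truncate(x):
    # Same sign bit and 5-bit exponent as F16; keep only the top 2 of 10 mantissa bits.
    return f16_bits(x) >> 8

# 1.75 = 0.01111.1100000000 in F16; both mantissa bits survive truncation: 0x3F
assert f16_to_f8e5m2_truncate(1.75) == 0x3F
# 1.0 = 0.01111.0000000000 in F16 -> 0.01111.00 = 0x3C
assert f16_to_f8e5m2_truncate(1.0) == 0x3C
```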
Tested xla on CUDA:
//xla/tests/... 799 tests: 799 tests pass
//xla/service/... 865 tests: 865 tests pass
//xla/client/... 77 tests: 77 tests pass
//xla/runtime/... 1 test: 1 test passes
//xla/ffi/... 6 tests: 6 tests pass
//xla/hlo/... 12 tests: 12 tests pass
//xla/mlir/... 141 tests: 141 tests pass
//xla/mlir_hlo/... 98 tests: 98 tests pass
//xla/pjrt/... 26 tests: 26 tests pass
//xla/tools/... 41 tests: 41 tests pass
//xla/translate/... 61 tests: 61 tests pass
Apologies for the delay, I'm OOO this week. Will take a look on Monday.
ml_dtypes updates (ml_dtypes/commits):
- Add float8_e4m3 and float8_e3m4 types support
- Fix float divmod with zero denominator
- Add int2 and uint2 types

Related PRs:
- ml_dtypes PR jax-ml/ml_dtypes#161 Add float8_e4m3 (Merged)
- XLA PR #16585 Add support for float8_e4m3 (In Review)

This closes #17075

PiperOrigin-RevId: 674396944
Yes, I plan to focus on adding f8e3m4 support for the rest of the week.
Added
Templated the following functions with
xla/tests/convert_test.cc
Outdated
{0x1p8, inf},          // Overflow
{0x1p-6, 0x1p-6},      // Smallest normal
{0x0.2p-6, 0x0.2p-6},  // Smallest denormal
{0x0.Ep-6, 0x0.Ep-6},  // Largest denormal
Update test to match recent changes in xla/tests/convert_test.cc #17437
Updated the tests (ConvertF16 + Roundtrip tests).
One additional test was added recently: DISABLED_ON_CPU(ConvertF32F8e5m2Roundtrip). Need to add similar tests for e4m3 and e3m4.
Added F32 tests:
- DISABLED_ON_CPU(ConvertF32F8e4m3Roundtrip)
- DISABLED_ON_CPU(ConvertF32F8e3m4Roundtrip)
Reed, could you please review this PR again? I've implemented both types and addressed all technical debt.
I made
xla/tests/convert_test.cc
Outdated
{0x1.18p0, 0x1.2p0},   // Round-to-even up
{0x1.Fp3, 0x1.Fp3},    // Max value
{0x1.F7Cp3, 0x1.Fp3},  // Largest number that doesn't overflow
{0x1.F8p7, inf},       // Smallest number that overflows
0x1.F8p7 should be 0x1.F8p3. And similarly for ConvertF32F8e3m4Roundtrip.
fixed in both places
xla/tests/convert_test.cc
Outdated
{0x0.87p-2, 0x0.8p-2},   // Round-to-nearest down
{0x0.89p-2, 0x0.9p-2},   // Round-to-nearest up
{0x0.08p-2, 0},          // Largest number that underflows
{0x0.084p-2, 0x0.1p-2},  // Smallest number that doesn't underflow
I think 0x0.084p-2 should be 0x0.0802p-2. There should be 9 zeros between the first binary 1 and the second binary 1 to get the next highest FP16 number above 0x0.08p-2.
Agree. Because we only have m10 bits, it's better to switch to the 0x1. notation for the F16 input number here: 0x1.004p-10. Updated "// Largest number that underflows" to 0x1p-10 too, since they are related.
Fixed "Smallest number that underflows" in ConvertF32F8e4m3Roundtrip and ConvertF32F8e3m4Roundtrip too: 0x1.000002p-10 and 0x1.000002p-7.
xla/service/elemental_ir_emitter.cc
Outdated
f16_sign = b->CreateLShr(f16_sign, i16_const(8));
Value* f8_sign = b->CreateTrunc(f16_sign, i8_type);

// Truncate the mantissa to f8 mantissa bits.
The comment should also mention the exponent is made smaller. Maybe: "Truncate the mantissa to f8 mantissa bits and exponent to f8 exponent bits."
updated the comment as suggested
xla/service/elemental_ir_emitter.cc
Outdated
// Right shift E5 exponent's leftmost bit to convert from E5 to E4 format.
// For example, 10011 becomes 01011; another example, 01011 becomes 00011.
// (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1)
from_e5_convert_mask = 0x9FFF;
from_e5_convert_shift = 1;
This works but is difficult to understand. IIUC, there are four cases depending on what the first two bits of E5's exponent are:
- The first two bits of the exponent are 00. In this case, we will underflow to zero, which will be handled by the denormal case, so it doesn't matter what we do here.
- The first two bits of the exponent are 01. In this case, we turn them into 00, effectively subtracting 8 from the exponent, which is what we want since the exponent bias difference from F16 to F8 is also 8.
- The first two bits of the exponent are 10. In this case, we turn them into 01, also effectively subtracting 8 from the exponent.
- The first two bits of the exponent are 11. In this case, the exponent is high enough to guarantee overflow to Inf, or is already Inf, and so EmitReducePrecisionIR already set all the other exponent bits to 1 as well, resulting in an Inf value. If the value is NaN, EmitReducePrecisionIR will have set the most-significant bit of the mantissa to 1 due to passing /*quiet_nans=*/true, ensuring the value will remain NaN even as the least-significant bits of the mantissa are truncated. So this case also works.

It took me a long time to understand and verify this. Instead, can we just subtract 8 from the exponent if the value is finite, and 0 otherwise? I.e., create a Select instruction that results in either 8 if the exponent bits are not all 1s, and 0 otherwise, and subtract that from the exponent.
I haven't fully checked the E3M4 case yet but I imagine it's similar.
Recent changes: updated the exponent adjustment code to use the "subtract the difference in exponent bias" approach.

1. For a non-11111 exponent, subtract (exponent_bias_difference << f16_mantissa_bits)
Example:
0x1p2 -> 0 10001 0000 0000 00 // f16_reduced
e4m3: e - 8 -> 0 01001 0000 0000 00
e3m4: e - 12 -> 0 00101 0000 0000 00
2. For an 11111 exponent, subtract (exponent_bias_difference << (f16_mantissa_bits + 1))
Example:
inf -> 0 11111 0000 0000 00
e4m3: e - 16 -> 0 01111 0000 0000 00
e3m4: e - 24 -> 0 00111 0000 0000 00

The 11111-exponent handling is needed to make sure that the f8_bits sign bit is 0 after the "// Shift to convert to F8." step; the code below the F8 conversion assumes the f8_bits sign bit is 0.
min_normal_value is calculated as (exponent_bias_difference + 1) << f16_mantissa_bits;
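The adjusted scheme can be sanity-checked in Python against the two worked examples above (f8e4m3 path only; the constants mirror the comment, not the actual emitter code):

```python
F16_MANT_BITS = 10
BIAS_DIFF = 8                        # f16 bias (15) minus f8e4m3 bias (7)
EXP_MASK = 0x1F << F16_MANT_BITS     # 0.11111.0000000000, the F16 exponent field

def adjust_exponent(x):
    """Subtract the bias difference; use a wider subtrahend for Inf/NaN exponents."""
    if (x & EXP_MASK) == EXP_MASK:   # 11111 exponent: subtract diff << (mant_bits + 1)
        return x - (BIAS_DIFF << (F16_MANT_BITS + 1))
    return x - (BIAS_DIFF << F16_MANT_BITS)

# 0x1p2: 0 10001 0000000000 -> 0 01001 0000000000
assert adjust_exponent(0x4400) == 0x2400
# inf:   0 11111 0000000000 -> 0 01111 0000000000 (top bit cleared, so the
# later "Shift to convert to F8" step leaves the f8 sign bit at 0)
assert adjust_exponent(0x7C00) == 0x3C00
```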
This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to the cpu_compiler):
- f8E4M3 type follows the IEEE 754 convention
- f8E3M4 type follows the IEEE 754 convention

Testing:

Related PRs: