Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for float8_e4m3 and float8_e3m4 types #16585

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

apivovarov
Copy link
Contributor

@apivovarov apivovarov commented Aug 28, 2024

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

f8E4M3 type follows IEEE 754 convention.

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 17 =6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)

f8E3M4 type follows IEEE 754 convention

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 13 =2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)

Testing:

bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test

Related PRs:

  • LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
  • LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
  • StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
  • StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
  • ml_dtypes PR-161 Add float8_e4m3 (Merged)
  • ml_dtypes PR-171 Add float8_e3m4 (Merged)
  • XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
  • XLA PR-3200 Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
  • JAX PR-23585 Add float8_e4m3 type support (in Review)

# LINT.ThenChange(Google-internal path)

tf_http_archive(
name = "stablehlo",
sha256 = STABLEHLO_SHA256,
strip_prefix = "stablehlo-{commit}".format(commit = STABLEHLO_COMMIT),
urls = tf_mirror_urls("https://github.com/openxla/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
urls = tf_mirror_urls("https://github.com/apivovarov/stablehlo/archive/{commit}.zip".format(commit = STABLEHLO_COMMIT)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not intended to fetch it from your repo.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's intended that you also update the deps in the same PR, could you split it in separate PRs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR is pending the merge of StableHLO openxla/stablehlo#2482 Add f8E4M3 and f8E3M4 types support (in Review).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Alexander, Integrate StableHLO at openxla/stablehlo@4f31b2e7 was werged to XLA main today. It includes float8_e4m3 type support. My temporary change in third_party/stablehlo/workspace.bzl was removed from this PR. @mooskagh

GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 3, 2024
### Summary
This is a proposal to add `Float8E4M3` and `Float8E3M4` floating point
types to StableHLO.
Feedback welcome, see [RFC: Float8E4M3 and
Float8E3M4](https://github.com/apivovarov/stablehlo/blob/rfc_f8E4M3_f8E3M4/rfcs/20240808-f8E4M3_f8E3M4.md)
for more details.

### References and Links
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- [RFC: FP8 in
StableHLO](https://github.com/openxla/stablehlo/blob/main/rfcs/20221031-fp8.md)
- [RFC: Float8E4M3FNUZ and
Float8E5M2FNUZ](https://github.com/openxla/stablehlo/blob/main/rfcs/20230321-fp8_fnuz.md)
- StableHLO [PR-2482](#2482)
Add f8E4M3 and f8E3M4 types support
- [Amazon EC2 Trn1
Instances](https://aws.amazon.com/ec2/instance-types/trn1/)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Sep 4, 2024
This PR adds f8E4M3 and f8E3M4 types support.

f8E4M3 and f8E3M4 types follow IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179)
[APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118)
[MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698)
[APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
- LLVM [PR-101230](llvm/llvm-project#101230)
[MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](#2486)
[RFC] Add f8E4M3 and f8E3M4 types support
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add
float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add
float8_e3m4 (Merged)
- XLA [PR-16585](openxla/xla#16585) Add support
for float8_e4m3
@apivovarov
Copy link
Contributor Author

Reed, David, could you please help review this PR? @reedwm @ddunl

@ddunl
Copy link
Member

ddunl commented Sep 5, 2024

I think Reed is the best person to review, I think this will require a patch on our end due to tensorflow/third_party, let me know when you approve and I can take care of the patch

@loislo loislo removed their request for review September 9, 2024 11:14
@apivovarov
Copy link
Contributor Author

Hi Reed,

This PR introduces support for the new f8E4M3 type, which adheres to the IEEE-754 convention. I've already added this type to the LLVM, MLIR, ml_dtypes, and StableHLO projects. This PR extends the support to XLA and includes a reference implementation for the CPU compiler.

Could you please help review this PR?

@reedwm

Copy link
Member

@reedwm reedwm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change! Sorry for the delay in reviewing.

Normally, it's better to minimize the size of PRs, but I would prefer if E3M4 is also added in the same PR, since it touches most of the same files in the exact same way as E4M3, so it makes it easy to batch review both dtypes at once.

But if adding E3M4 to the same PR is inconvenient with you, I'm fine with this being done as a separate, future PR.

Comment on lines 320 to 321
{BF16, F16, F8E5M2, F8E4M3, F8E4M3FN,
F8E4M3B11FNUZ, F8E5M2FNUZ, F8E4M3FNUZ},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the type list here is growing, I would avoid hardcoding the list of FP8 types. One way to avoid this is to have DoWithUpcastToF32 take a should_upcast bool instead of the existing upcast_types list. Then you can pass something like should_upcast = BitWidth(b.GetShape(x).element_type) <= 16.

There are a lot of places where we list out all FP8 types, but every place we can remove listing these out will help when more types are added :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened PR #17130 - Add default upcasting behavior to DoWithUpcastToF32

@@ -111,6 +111,59 @@ INSTANTIATE_TEST_SUITE_P(DoublePrecisionInputs, FixedValueTest,
0x1.fffffffffffffp-127,
0x1.aaaaaaaaaaaaap-127));

TEST(FPDistanceTest, F8E4M3Distance) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this test is almost identical to F8E4M3FNDistance, can you merge them to avoid duplication?

One way to to create a type-parameterized test with TYPED_TEST_P. Another way would be to have a for-loop over primtiives types F8E4M3 and F8E4M3FN, and in the body use primitive_util::PrimitiveTypeSwitch with a lambda that does the CalculateDistanceInFloats calls. See here for an example of how to use PrimitiveTypeSwitch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened PR #17135 - Add TypeParam to FP8E4M3DistanceTest

} else if constexpr (std::is_integral_v<ElementwiseT>) {
if constexpr (std::is_signed_v<ElementwiseT>) {
if (rhs_el < static_cast<ElementwiseT>(0)) {
ElementWiseBinaryOp(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't change the formatting here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

restored

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

github workflow runs pipx to check clang formatting. Opened PR #17234 Format hlo_evaluator_typed_visitor.h

@@ -25,6 +25,63 @@ limitations under the License.
namespace xla {
namespace {

TEST(LiteralComparisonTest, F8E4M3CompareNear_Equal) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all the FP8 tests are duplicated for each FP8 type. Can you use TYPED_TEST_P to deduplicate them?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #17133 - Dedup LiteralComparisonTests

@@ -644,15 +648,19 @@ TEST_F(LiteralUtilTest, IsAll) {
// 9 rounds to 8 in E5M2 but is not equal to 8, so this should be false
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e5m2>({q16}).IsAll(9));

tsl::float8_e4m3fn r16(9); // Exactly representable in e4m3
tsl::float8_e4m3 e4m3(9); // Exactly representable in e4m3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why there is a convention of using subsequent single letters to name the FP8 values (q16, r16, s16, etc) but you should follow it or change the convention. Either name this q16, renaming the above e5m2 to p16, or rename all the other FP8 variable names to something more descriptive, as you did for this one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed to q16

Comment on lines 283 to 284
case xla::F8E4M3:
return absl::UnimplementedError("F8E4M3 not implemented");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid having to modify this every time a new FP8 type is added, remove all these FP8 cases and check if IsF8Type(literal.shape().element_type() before the switch statement.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened PR #17170 Code dedup in execution_trace_utils LiteralToValue

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just FYI, this was an intentional choice. Missing switch cases are a compiler error, so having a switch without a default case is preferable when possible. No big deal though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using single-case switch statements can make it easier for the compiler to detect potential errors in the code.
Opened PR #17279 - Use switch case without default in LiteralToValue

@@ -500,6 +500,36 @@ TEST_F(FloatNormalizationTest, DoNotChangeBitcastConvert) {
EXPECT_EQ(root->operand(0)->shape().element_type(), U16);
}

TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e4m3) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merge this with the existing test ResolveIfUnsupportedF8e5m2, either by looping over values (F8E4M3, F8E5M2) or by using a value-parameterized test with TEST_P.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened PR #17177 - Parametrize FloatNormalizationF8Test ResolveIfUnsupportedF8


XlaBuilder builder(TestName());
auto c = ConstantR1<tsl::float8_e4m3>(&builder, constant);
// F8 outputs are not yet supported so convert to F32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is actually no longer true. We should change the other tests with this comment as well. The test OneCellF8e5m2fnuz does have an FP8 output, so you can use that as an example in modifying this test.

If you want, you can also change the two existing tests with the comment to have FP8 outputs as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened PR #17182 - Parametrize ConstantsFloatTest OneCellFloat

Comment on lines 355 to 389
f16_reduced =
b->CreateOr(b->CreateAnd(f16_reduced, i16_const(0x9FFF)),
b->CreateLShr(b->CreateAnd(f16_reduced, i16_const(0x4000)),
i16_const(1)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is effectively subtracting 8 from the exponent I think, as the difference in exponent bias is 8. Why not do that subtraction directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In contrast to the EmitF16ToF8e4m3fn function, the EmitF16ToF8e4m3 function does not include special code to handle Inf and NaN cases. (code around constexpr int max_finite_value = 0x5F7F;)

If I use -8 approach then several tests in //xla/tests:convert_test_cpu FAILED.
e.g.

  • inf -> -1.0
  • nan -> -1.5

Example:

input is inf
EmitReducePrecisionIR returns
x = 0.11111.0000000000 (0x7C00)

Option1: minus 8
x -= 0.01000.0000000000
// x is 0.10111.0000000000
// Shift to convert to F8: x is 1.0111.000
// f8e4m3 Result is -1.0 (Wrong)

Option2: Right shift E5 exponent's leftmost bit
x = (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1)
// x is 0.01111.0000000000
// Shift to convert to F8: x is 0.1111.000
// f8e4m3 Result is inf (Correct)


// Set output exponent to 11111 if input exponent is 1111 (Inf or NaN)
// 0.1111.000 is 0x78
// 0.11111.000000000000 is 0x7C00
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0.11111.000000000000 has 12 zeros at the end, when it should have 10.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@@ -220,6 +220,59 @@ absl::StatusOr<llvm::Value*> EmitReducePrecisionIR(
return result;
}

llvm::Value* handle_halfway_points_F16ToF8(llvm::Value* f16_abs_bits,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I think you might have to modify expand_float_ops.cc, which is used by the new MLIR emitters which replace the existing emitters on GPUs. But I'm not very familiar with these new emitters. @jreiffers, can you advice on what needs to be done here

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have generic support for all possible float conversions, but the emitted code might not be optimal, so it should be considered a fallback. I didn't look at these conversion routines here in detail, but if they're better, it would make sense to port them to the MLIR pipeline.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated xla/service/gpu/fusions/transforms/expand_float_ops.cc and added f8E4M3 cases to:

  • IsInf()
  • IsNaN()
  • RewriteF8Cst::matchAndRewrite() // If we're comparing to +-0, compare the absolute values.

expand_float_ops.cc includes a specialized function for the f8e5m2 type - EmitF16ToF8e5m2(). This is because F16 is technically f16e5m10. The two types are similar, with the primary difference being that the mantissa in f8e5m2 is truncated to 2 bits.

f8E4M3 has a different number of exponent and mantissa bits. The conversion can be efficiently managed using the "generic support for all possible float conversions".

Tested xla on CUDA:

//xla/tests/...    799 tests: 799 tests pass
//xla/service/...  865 tests: 865 tests pass
//xla/client/...    77 tests: 77 tests pass
//xla/runtime/...    1 test: 1 test passes
//xla/ffi/...        6 tests: 6 tests pass
//xla/hlo/...       12 tests: 12 tests pass
//xla/mlir/...     141 tests: 141 tests pass
//xla/mlir_hlo/...  98 tests: 98 tests pass
//xla/pjrt/...      26 tests: 26 tests pass
//xla/tools/...     41 tests: 41 tests pass
//xla/translate/... 61 tests: 61 tests pass

@jreiffers
Copy link
Member

Apologies for the delay, I'm OOO this week. Will take a look on Monday.

copybara-service bot pushed a commit that referenced this pull request Sep 13, 2024
ml_dtypes Updates:
Add float8_e4m3 and float8_e3m4 types support
Fix float divmod with zero denominator
Add int2 and uint2 types
ml_dtypes/commits

Related PRs
ml_dtypes PR Add float8_e4m3 jax-ml/ml_dtypes#161 Add float8_e4m3 (Merged)
XLA PR Add support for float8_e4m3 #16585 (In Review)

This closes #17075

PiperOrigin-RevId: 674396944
copybara-service bot pushed a commit that referenced this pull request Sep 13, 2024
ml_dtypes Updates:
Add float8_e4m3 and float8_e3m4 types support
Fix float divmod with zero denominator
Add int2 and uint2 types
ml_dtypes/commits

Related PRs
ml_dtypes PR Add float8_e4m3 jax-ml/ml_dtypes#161 Add float8_e4m3 (Merged)
XLA PR Add support for float8_e4m3 #16585 (In Review)

This closes #17075

PiperOrigin-RevId: 674396944
@apivovarov
Copy link
Contributor Author

Do you plan on adding add f8e3m4 support in this PR? If so, I'd prefer to wait until you add that before re-reviewing, as it allows me to batch my review of both dtypes in one go.

Yes, I plan to focus on adding f8e3m4 support for the rest of the week.

@apivovarov apivovarov changed the title Add support for float8_e4m3 Add support for float8_e4m3 and float8_e3m4 types Sep 23, 2024
@apivovarov
Copy link
Contributor Author

Added float8_e3m4 type support to this PR

@apivovarov apivovarov force-pushed the float8_e4m3 branch 2 times, most recently from 0118994 to 24eb476 Compare September 23, 2024 19:36
@apivovarov
Copy link
Contributor Author

apivovarov commented Sep 23, 2024

Templated the following functions with <f8_exponent_bits> parameter:

  • handle_halfway_points_F16ToF8<f8_exponent_bits>()
  • EmitF16ToF8e<f8_exponent_bits>()
  • EmitToF16F8e<f8_exponent_bits>()

{0x1p8, inf}, // Overflow
{0x1p-6, 0x1p-6}, // Smallest normal
{0x0.2p-6, 0x0.2p-6}, // Smallest denormal
{0x0.Ep-6, 0x0.Ep-6}, // Largest denormal
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update test to match recent changes in xla/tests/convert_test.cc #17437

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the tests (ConvertF16 + Roundtrip tests).

One additional test was added recently DISABLED_ON_CPU(ConvertF32F8e5m2Roundtrip). Need to add similar test for e4m3 and e3m4

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added F32 test:

  • DISABLED_ON_CPU(ConvertF32F8e4m3Roundtrip)
  • DISABLED_ON_CPU(ConvertF32F8e3m4Roundtrip)

@apivovarov apivovarov force-pushed the float8_e4m3 branch 3 times, most recently from 5109f0a to b3b684a Compare September 24, 2024 20:12
@apivovarov
Copy link
Contributor Author

Reed, could you please review this PR again? I've implemented both types and addressed all technical debt.
@reedwm

@apivovarov apivovarov force-pushed the float8_e4m3 branch 4 times, most recently from d999e3b to fa92c93 Compare September 26, 2024 04:13
@apivovarov
Copy link
Contributor Author

apivovarov commented Sep 26, 2024

I made float8_e4m3 and float8_e3m4 optional in xla/python - similar to int2 and uint2. It is needed for JAX because it still uses ml_dtypes-0.4.0 and the update to 0.5.0 is not possible in a near future because of TF.

{0x1.18p0, 0x1.2p0}, // Round-to-even up
{0x1.Fp3, 0x1.Fp3}, // Max value
{0x1.F7Cp3, 0x1.Fp3}, // Largest number that doesn't overflow
{0x1.F8p7, inf}, // Smallest number that overflows
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

0x1.F8p7 should be 0x1.F8p3. And similarly for ConvertF32F8e3m4Roundtrip.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in both places

{0x0.87p-2, 0x0.8p-2}, // Round-to-nearest down
{0x0.89p-2, 0x0.9p-2}, // Round-to-nearest up
{0x0.08p-2, 0}, // Largest number that underflows
{0x0.084p-2, 0x0.1p-2}, // Smallest number that doesn't underflow
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 0x0.084p-2 should be 0x0.0802p-2. There should be 9 zeros between the first binary 1 and the second binary 1 to get the next highest FP16 number above 0x0.08p-2.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Because we only have m10 bits - better to switch to 0x1. notation for F16 input number here - 0x1.004p-10. Updated // Largest number that underflows to 0x1p-10 too, since they are related.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed "Smallest number that underflows" in ConvertF32F8e4m3Roundtrip and ConvertF32F8e3m4Roundtrip too - 0x1.000002p-10 and 0x1.000002p-7.

f16_sign = b->CreateLShr(f16_sign, i16_const(8));
Value* f8_sign = b->CreateTrunc(f16_sign, i8_type);

// Truncate the mantissa to f8 mantissa bits.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment should also mention the exponent is made smaller. Maybe:

"Truncate the mantissa to f8 mantissa bits and exponent to f8 exponent bits.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated the comment as suggested

Comment on lines 381 to 385
// Right shift E5 exponent's leftmost bit to convert from E5 to E4 format.
// For example, 10011 becomes 01011, another example, 01011 becomes 00011
// (x & 0b1001'1111'1111'1111) | ((x & 0b0100'0000'0000'0000) >> 1)
from_e5_convert_mask = 0x9FFF;
from_e5_convert_shift = 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works but is difficult to understand. IIUC, there are four cases depending on what the first two bits of E5's exponent are

  • The first two bits of the exponent are 00. In this case, we will underflow to zero, which will be handled by the denormal case, so it doesn't matter what we do here.
  • The first two bits of the exponent are 01. In this case, we turn them into 00, effectively subtracting 8 from the exponent, which is what we want since the exponent bias difference from F16 to F8 is also 8.
  • The first two bits of the exponent are 10. In this case, we turn them into 01, also effectively subtracting 8 from the exponent.
  • The first two bits of the exponent are 11. In this case, the exponent is high enough to guarentee overflow to Inf, or is already Inf, and so EmitReducePrecisionIR already set all the other exponent bits to 1 as well, resulting in an Inf value. If the value is NaN, EmitReducePrecisionIR will have set the most-significant bit of the mantissa to 1 due to passing /*quiet_nans=*/true, ensuring the value will remain nan even as the least-significant bits of the manitssa are truncated. So this case also works.

It took me a long term to understand and verify this. Instead, can we just subtract 8 from the exponent if the value is finite, and 0 otherwise? I.e., create a Select instruction that results in either 8 if the exponent bits are not all 1s, and 0 otherwise, and subtract that from the exponent.

I haven't fully checked the E3M4 case yet but I imagine it's similar.

Copy link
Contributor Author

@apivovarov apivovarov Sep 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recent changes:

  1. Updated exponent Adjustment code to utilize "subtracting the difference in exponent bias" approach
1. For non-11111 exponent - subtract (exponent_bias_difference << f16_mantissa_bits)
Example:
       0x1p2 -> 0 10001 0000 0000 00  // f16_reduced
e4m3:  e - 8 -> 0 01001 0000 0000 00
e3m4: e - 12 -> 0 00101 0000 0000 00


2. For 11111 exponent - subtract (exponent_bias_difference << (f16_mantissa_bits + 1))
Example:
         inf -> 0 11111 0000 0000 00
e4m3: e - 16 -> 0 01111 0000 0000 00
e3m4: e - 24 -> 0 00111 0000 0000 00

11111 exponent handling is needed to make sure that f8_bits sign bit is 0 after "// Shift to convert to F8." step.

The code below F8 conversion assumes that f8_bits sign bit is 0.

  1. min_normal_value is calculated as (exponent_bias_difference + 1) << f16_mantissa_bits;

@apivovarov apivovarov force-pushed the float8_e4m3 branch 5 times, most recently from 83b9546 to 864c530 Compare September 27, 2024 23:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants