Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #16585: Add support for float8_e4m3 and float8_e3m4 types #17774

Merged
merged 1 commit into from
Oct 2, 2024

Conversation

copybara-service[bot]
Copy link

PR #16585: Add support for float8_e4m3 and float8_e3m4 types

Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

f8E4M3 type follows IEEE 754 convention.

f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 17 =6
- Precision specifies the total number of bits used for the significand (mantisa), 
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)

f8E3M4 type follows IEEE 754 convention

f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 13 =2
- Precision specifies the total number of bits used for the significand (mantissa), 
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)

Testing:

bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test

Related PRs:

  • LLVM PR-97179 [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-97118 [MLIR] Add f8E4M3 IEEE 754 type (Merged)
  • LLVM PR-99698 [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
  • LLVM PR-101230 [MLIR] Add f8E3M4 IEEE 754 type (Merged)
  • StableHLO PR-2486 [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
  • StableHLO PR-2482 Add f8E4M3 and f8E3M4 types support (Merged)
  • ml_dtypes PR-161 Add float8_e4m3 (Merged)
  • ml_dtypes PR-171 Add float8_e3m4 (Merged)
  • XLA PR-17075 [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
  • XLA PR-3200 Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
  • JAX PR-23585 Add float8_e4m3 type support (in Review)
    Copybara import of the project:

--
ec1c723 by Alexander Pivovarov [email protected]:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723

@copybara-service copybara-service bot force-pushed the test_680651037 branch 5 times, most recently from 2bfed5e to d90cf75 Compare October 2, 2024 18:27
Imported from GitHub PR #16585

This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).

### `f8E4M3` type follows IEEE 754 convention.

```c
f8E4M3 (IEEE 754)
- Exponent bias: 7
- Maximum stored exponent value: 14 (binary 1110)
- Maximum unbiased exponent value: 14 - 7 = 7
- Minimum stored exponent value: 1 (binary 0001)
- Minimum unbiased exponent value: 1 − 7 = −6
- Precision specifies the total number of bits used for the significand (mantisa),
    including implicit leading integer bit = 3 + 1 = 4
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 7
- Min exp (unbiased): -6
- Infinities (+/-): S.1111.000
- Zeros (+/-): S.0000.000
- NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
- Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
- Min normal number: S.0001.000 = +/-2^(-6)
- Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
- Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
```

### `f8E3M4` type  follows IEEE 754 convention

```c
f8E3M4 (IEEE 754)
- Exponent bias: 3
- Maximum stored exponent value: 6 (binary 110)
- Maximum unbiased exponent value: 6 - 3 = 3
- Minimum stored exponent value: 1 (binary 001)
- Minimum unbiased exponent value: 1 − 3 = −2
- Precision specifies the total number of bits used for the significand (mantissa),
    including implicit leading integer bit = 4 + 1 = 5
- Follows IEEE 754 conventions for representation of special values
- Has Positive and Negative zero
- Has Positive and Negative infinity
- Has NaNs

Additional details:
- Max exp (unbiased): 3
- Min exp (unbiased): -2
- Infinities (+/-): S.111.0000
- Zeros (+/-): S.000.0000
- NaNs: S.111.{0,1}⁴ except S.111.0000
- Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
- Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
- Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
- Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 =  +/-2^(-2) x 2^(-4) = +/-2^(-6)
```

### Testing:
```
bazel test \
//xla:array2d_test \
//xla:fp_util_test \
//xla:literal_comparison_test \
//xla:literal_test \
//xla/mlir/utils:type_util_test \
//xla:primitive_util_test \
//xla/python/ifrt:dtype_test \
//xla/python:xla_client_test \
//xla/service:elemental_ir_emitter_test \
//xla/service:float_normalization_test \
//xla/service/gpu/tests:float_conversions_test \
//xla/tests:array_elementwise_ops_test \
//xla/tests:constants_test \
//xla/tests:convert_test \
//xla/tests:float8_test \
//xla:util_test

bazel test \
//xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
//xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
//xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
```

### Related PRs:
- LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
- LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
-  LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
- StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
- StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
- ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
- ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
- XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
- XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
- JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
Copybara import of the project:

--
ec1c723 by Alexander Pivovarov <[email protected]>:

Add support for float8_e4m3 and float8_e3m4 types

Merging this change closes #16585

COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
PiperOrigin-RevId: 681551979
@copybara-service copybara-service bot merged commit 693ee2e into main Oct 2, 2024
1 check passed
@copybara-service copybara-service bot deleted the test_680651037 branch October 2, 2024 19:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant