PR #16585: Add support for float8_e4m3 and float8_e3m4 types #17774

Merged: 1 commit, Oct 2, 2024

Commits on Oct 2, 2024

  1. PR #16585: Add support for float8_e4m3 and float8_e3m4 types

    Imported from GitHub PR #16585
    
    This PR adds f8E4M3 and f8E3M4 types support to XLA (mainly to cpu_compiler).
    
    ### `f8E4M3` type follows IEEE 754 convention
    
    ```c
    f8E4M3 (IEEE 754)
    - Exponent bias: 7
    - Maximum stored exponent value: 14 (binary 1110)
    - Maximum unbiased exponent value: 14 - 7 = 7
    - Minimum stored exponent value: 1 (binary 0001)
    - Minimum unbiased exponent value: 1 - 7 = -6
    - Precision specifies the total number of bits used for the significand (mantissa),
        including implicit leading integer bit = 3 + 1 = 4
    - Follows IEEE 754 conventions for representation of special values
    - Has Positive and Negative zero
    - Has Positive and Negative infinity
    - Has NaNs
    
    Additional details:
    - Max exp (unbiased): 7
    - Min exp (unbiased): -6
    - Infinities (+/-): S.1111.000
    - Zeros (+/-): S.0000.000
    - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111}
    - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240
    - Min normal number: S.0001.000 = +/-2^(-6)
    - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7
    - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9)
    ```
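The bit layout above can be checked with a small decoder. The following is a minimal sketch in plain Python, not code from this PR: the helper name `decode_f8` and its parameters are hypothetical, and it assumes only the 1 sign bit / 4 exponent bits / 3 mantissa bits layout and the IEEE 754 special-value rules listed above.

```python
import math

def decode_f8(byte, exp_bits=4, man_bits=3):
    """Decode an 8-bit IEEE-754-style pattern (hypothetical helper, not an XLA API)."""
    assert exp_bits + man_bits == 7 and 0 <= byte <= 0xFF
    bias = (1 << (exp_bits - 1)) - 1              # 7 for f8E4M3
    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    if exp == (1 << exp_bits) - 1:                # all-ones exponent: inf or NaN
        return sign * math.inf if man == 0 else math.nan
    if exp == 0:                                  # zero or subnormal: no implicit leading 1
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

print(decode_f8(0b0_1110_111))   # max normal: 240.0
print(decode_f8(0b0_0000_001))   # min subnormal: 2^(-9) = 0.001953125
```

Passing `exp_bits=3, man_bits=4` reuses the same logic for the `f8E3M4` layout.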
    
    ### `f8E3M4` type follows IEEE 754 convention
    
    ```c
    f8E3M4 (IEEE 754)
    - Exponent bias: 3
    - Maximum stored exponent value: 6 (binary 110)
    - Maximum unbiased exponent value: 6 - 3 = 3
    - Minimum stored exponent value: 1 (binary 001)
    - Minimum unbiased exponent value: 1 - 3 = -2
    - Precision specifies the total number of bits used for the significand (mantissa),
        including implicit leading integer bit = 4 + 1 = 5
    - Follows IEEE 754 conventions for representation of special values
    - Has Positive and Negative zero
    - Has Positive and Negative infinity
    - Has NaNs
    
    Additional details:
    - Max exp (unbiased): 3
    - Min exp (unbiased): -2
    - Infinities (+/-): S.111.0000
    - Zeros (+/-): S.000.0000
    - NaNs: S.111.{0,1}⁴ except S.111.0000
    - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5
    - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2)
    - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6)
    - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6)
    ```
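The limits in both tables follow from the exponent and mantissa widths alone. As a hedged cross-check (a sketch, not code from this PR; the function name `f8_limits` is hypothetical), the closed-form expressions can be evaluated exactly with `fractions.Fraction`, assuming the IEEE 754 conventions stated above: bias 2^(e-1) - 1, and the all-ones exponent reserved for infinities and NaNs.

```python
from fractions import Fraction

def f8_limits(exp_bits, man_bits):
    """Derive the range limits of an IEEE-754-style float from its field widths."""
    bias = 2 ** (exp_bits - 1) - 1
    max_stored = 2 ** exp_bits - 2        # all-ones exponent is reserved for inf/NaN
    two = Fraction(2)
    return {
        "bias": bias,
        "max_exp": max_stored - bias,
        "min_exp": 1 - bias,
        "max_normal": two ** (max_stored - bias) * (2 - two ** -man_bits),
        "min_normal": two ** (1 - bias),
        "max_subnormal": two ** (1 - bias) * (1 - two ** -man_bits),
        "min_subnormal": two ** (1 - bias - man_bits),
        "nan_patterns": 2 * (2 ** man_bits - 1),   # both signs, any nonzero mantissa
    }

e3m4 = f8_limits(3, 4)
print(float(e3m4["max_normal"]))   # 15.5
print(e3m4["min_subnormal"])       # 1/64, i.e. 2^(-6)
```

Calling `f8_limits(4, 3)` gives the `f8E4M3` values in the same way (for example, a maximum normal value of 240).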
    
    ### Testing:
    ```
    bazel test \
    //xla:array2d_test \
    //xla:fp_util_test \
    //xla:literal_comparison_test \
    //xla:literal_test \
    //xla/mlir/utils:type_util_test \
    //xla:primitive_util_test \
    //xla/python/ifrt:dtype_test \
    //xla/python:xla_client_test \
    //xla/service:elemental_ir_emitter_test \
    //xla/service:float_normalization_test \
    //xla/service/gpu/tests:float_conversions_test \
    //xla/tests:array_elementwise_ops_test \
    //xla/tests:constants_test \
    //xla/tests:convert_test \
    //xla/tests:float8_test \
    //xla:util_test
    
    bazel test \
    //xla/hlo/translate/hlo_to_mhlo/tests:import.hlo.test \
    //xla/hlo/translate/mhlo_to_hlo/tests:export.mlir.test \
    //xla/mlir_hlo/tests:Dialect/mhlo/hlo-legalize-to-stablehlo.mlir.test \
    //xla/mlir_hlo/tests:Dialect/mhlo/ops.mlir.test \
    //xla/mlir_hlo/tests:Dialect/mhlo/stablehlo-legalize-to-hlo.mlir.test
    ```
    
    ### Related PRs:
    - LLVM [PR-97179](llvm/llvm-project#97179) [APFloat] Add support for f8E4M3 IEEE 754 type (Merged)
    - LLVM [PR-97118](llvm/llvm-project#97118) [MLIR] Add f8E4M3 IEEE 754 type (Merged)
    - LLVM [PR-99698](llvm/llvm-project#99698) [APFloat] Add support for f8E3M4 IEEE 754 type (Merged)
    - LLVM [PR-101230](llvm/llvm-project#101230) [MLIR] Add f8E3M4 IEEE 754 type (Merged)
    - StableHLO [PR-2486](openxla/stablehlo#2486) [RFC] Add f8E4M3 and f8E3M4 types support (Merged)
    - StableHLO [PR-2482](openxla/stablehlo#2482) Add f8E4M3 and f8E3M4 types support (Merged)
    - ml_dtypes [PR-161](jax-ml/ml_dtypes#161) Add float8_e4m3 (Merged)
    - ml_dtypes [PR-171](jax-ml/ml_dtypes#171) Add float8_e3m4 (Merged)
    - XLA [PR-17075](#17075) [TSL] Bump ml_dtypes. Add float8_e4m3, float8_e3m4 (Approved)
    - XLA [PR-3200](#3200) Add support for float8_e4m3fnuz and float8_e5m2fnuz (Template)
    - JAX [PR-23585](jax-ml/jax#23585) Add float8_e4m3 type support (in Review)
    
    Copybara import of the project:
    
    --
    ec1c723 by Alexander Pivovarov <[email protected]>:
    
    Add support for float8_e4m3 and float8_e3m4 types
    
    Merging this change closes #16585
    
    COPYBARA_INTEGRATE_REVIEW=#16585 from apivovarov:float8_e4m3 ec1c723
    PiperOrigin-RevId: 681551979
    apivovarov authored and Google-ML-Automation committed Oct 2, 2024
    Commit 693ee2e