[RFC] Microscaling (MX) types in XLA #18085
Does this just mean that there will be no MX datatypes in HLO since they'll be represented by the underlying math? Or would this proposal include specifying arithmetic ops on tuple types? I.e. at what point is the tuple needed?
Would like to hear more on this / see an example of an MX datatype program. Tuples have very little use in StableHLO today and are mostly deprecated as a relic of HLO. Also cc @sdasgup3 to see if this numeric format should have any ties to the quantized type; it seems plausible that they could. If XLA will require only the underlying math, we could handle this similarly to existing quantization, where StableHLO has a more static representation and a pass to decompose to the underlying Q/DQ math.
RFC: Microscaling (MX) types in XLA
Overview
The Open Compute Project (OCP) proposed the Microscaling Formats (MX) Specification v1.0 in September 2023. It defines floating-point formats such as MXFP8, MXFP6 and MXFP4.
This RFC proposes to add new primitive types that will allow implementing MX formats in XLA (using tuple types).
Summary
MX floating point formats
The MX specification defines a way to represent block-scaled data using three components: private elements, a shared scaling factor per block, and the block size.
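In the notation of the MX specification (a shared scale X, private elements P_i and block size k), a block decodes as shown below. The block size of 32 and the F8E8M0FNU exponent bias of 127 are taken from the specification's concrete formats; s is the raw 8-bit scale code, with s = 255 reserved for NaN.

```latex
\begin{aligned}
v_i &= X \cdot P_i, \qquad i = 1, \dots, k, \quad k = 32 \\
X   &= 2^{\,s - 127}, \qquad s \in \{0, \dots, 254\}
\end{aligned}
```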
Concrete MX-compliant formats:
The type names used in the MX specification correspond to the following XLA primitive types:
The FN type suffix denotes types that can represent finite values only (F) and have a special NaN encoding (N); note that the 4- and 6-bit FN types introduced below have no NaN encoding at all.
Important to note: XLA has both F8E4M3 and F8E4M3FN primitive types, which have different semantics. The MX spec uses the latter for the MXFP8 format.
New XLA primitive types
The primitive types necessary to implement MX floating-point formats in XLA were added to LLVM APFloat [1][2][3], to MLIR [4][5][6][7] and to JAX-ML [8][9]; the latter also makes them available in NumPy. The StableHLO RFC [10] is in review.
F8E8M0FNU
8-bit floating-point type with no sign bit, 8 exponent bits and no mantissa bits.
This type cannot encode negative values or zero, but it is only intended for scaling factors, where such values are not needed.
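To make the encoding concrete, here is a minimal Python sketch of F8E8M0FNU decoding, plus the FP32 round-to-zero shortcut mentioned under Type conversion. The bias of 127 and the single NaN code 0xFF follow the OCP MX specification; the function names are hypothetical and are not part of any XLA, MLIR or ml_dtypes API.

```python
import math
import struct

E8M0_BIAS = 127   # exponent bias of F8E8M0FNU, as in the OCP MX specification
E8M0_NAN = 0xFF   # the only non-finite code; there is no zero or negative code


def decode_e8m0(code: int) -> float:
    """Decode an 8-bit F8E8M0FNU code into the power of two it represents."""
    if not 0 <= code <= 0xFF:
        raise ValueError(f"not an 8-bit code: {code}")
    if code == E8M0_NAN:
        return math.nan
    return 2.0 ** (code - E8M0_BIAS)


def encode_e8m0_rz_from_f32(x: float) -> int:
    """FP32 -> F8E8M0FNU in round-to-zero mode by keeping only the exponent field.

    The sign bit is dropped and the mantissa is truncated, so a normal input maps
    to the largest power of two not exceeding |x|. FP32 zeros and subnormals
    (exponent field 0) map to 2**-127, and infinities/NaNs (exponent field 0xFF)
    map to the E8M0 NaN code. `x` must be representable as FP32.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # raw FP32 bit pattern
    return (bits >> 23) & 0xFF  # shift out 23 mantissa bits, mask off the sign


print(decode_e8m0(encode_e8m0_rz_from_f32(-3.0)))  # 2.0: sign ignored, 1.5 truncated to 1
```

The same idea applies to BF16 inputs with a shift of 7 instead of 23, since BF16 keeps FP32's 8-bit exponent but only 7 mantissa bits.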
F6E3M2FN
6-bit floating-point type with 1 sign bit, 3 exponent bits and 2 mantissa bits.
F6E2M3FN
6-bit floating-point type with 1 sign bit, 2 exponent bits and 3 mantissa bits.
F4E2M1FN
4-bit floating-point type with 1 sign bit, 2 exponent bits and 1 mantissa bit.
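The value sets of the three sub-byte types follow from their bit layouts. A minimal Python sketch of a decoder, assuming an IEEE-style exponent bias of 2^(E-1) - 1, subnormals when the exponent field is zero, and no infinity or NaN codes (as for the FN sub-byte types above); `decode_fn` is a hypothetical helper, not an existing API.

```python
def decode_fn(code: int, exp_bits: int, man_bits: int) -> float:
    """Decode one code of a finite-only sub-byte float type (1 sign bit, no inf/NaN).

    Assumes an IEEE-style bias of 2**(exp_bits - 1) - 1 and subnormals when the
    exponent field is 0; every other code, including all-ones exponents, is finite.
    """
    total_bits = 1 + exp_bits + man_bits
    if not 0 <= code < (1 << total_bits):
        raise ValueError(f"code {code} does not fit in {total_bits} bits")
    sign = -1.0 if (code >> (exp_bits + man_bits)) & 1 else 1.0
    exp_field = (code >> man_bits) & ((1 << exp_bits) - 1)
    fraction = (code & ((1 << man_bits) - 1)) / (1 << man_bits)
    bias = (1 << (exp_bits - 1)) - 1
    if exp_field == 0:  # zero or subnormal: no implicit leading 1
        return sign * fraction * 2.0 ** (1 - bias)
    return sign * (1.0 + fraction) * 2.0 ** (exp_field - bias)


# (exponent bits, mantissa bits), read off the type names.
FORMATS = {"F4E2M1FN": (2, 1), "F6E2M3FN": (2, 3), "F6E3M2FN": (3, 2)}

for name, (e, m) in FORMATS.items():
    magnitudes = [decode_fn(c, e, m) for c in range(1 << (e + m))]  # sign bit clear
    print(f"{name}: {len(magnitudes)} magnitudes, max {magnitudes[-1]}")
```

For F4E2M1FN the eight magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6; the printed maxima are 6.0, 7.5 and 28.0 for F4E2M1FN, F6E2M3FN and F6E3M2FN respectively.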
Composite types
MXFP8 data could be conceptually represented in XLA as a tuple type, e.g. (f8e4m3fn[…,N], f8e8m0fnu[…,N/32]), once the proposed primitive types are added. Similarly for MXFP6 and MXFP4.
A possible alternative is to add MXFP8 to the list of primitive types, but it is not primitive: it combines two tensors with different element types and shapes. This would add a lot of tech debt for no good reason.
Memory layout
F4E2M1FN tensors could be packed similarly to U4 tensors, where every byte stores two values. We can piggyback on the existing implementation for loads and stores.
6-bit types (F6E2M3FN and F6E3M2FN) could be packed so that every three bytes store four values. An alternative memory layout for sub-byte types is described in the eXmY paper [11], which could be used for 6-bit types, but this is out of scope of this RFC.
F8E8M0FNU tensors do not require a special memory layout and could be implemented similarly to the other FP8 types.
Type conversion
HLO convert op
The HLO convert op will be updated to support the new primitive types. The conversion will use the RN (round to nearest even) rounding mode, similarly to the other XLA floating-point type conversions.
XLA currently has two implementations of type conversion lowering: ElementalIrEmitter (used on CPU) and ExpandFloatOpsPass (used on GPU); both will need to be updated.
During the conversion from IEEE-754 types, infinities and exponent overflows will be clamped to the maximum absolute value (preserving the sign).
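A minimal Python sketch of the clamping step for the finite-only sub-byte types. The maximum magnitudes (6.0, 7.5 and 28.0) follow from the bit layouts above; `saturate` is a hypothetical helper, not the ElementalIrEmitter or ExpandFloatOpsPass code, and RNE rounding onto the target grid is a separate step not shown here.

```python
import math

# Largest finite magnitude of each finite-only type (no inf/NaN encodings).
MAX_FINITE = {"F4E2M1FN": 6.0, "F6E2M3FN": 7.5, "F6E3M2FN": 28.0}


def saturate(x: float, dtype: str) -> float:
    """Clamp an IEEE-754 value into the finite range of `dtype`, keeping its sign.

    Infinities and out-of-range values become +/-max; in-range values are
    returned unchanged and still need RNE rounding to the target grid.
    NaN has no encoding in these types and its handling is left open,
    so this sketch rejects it.
    """
    if math.isnan(x):
        raise ValueError("NaN encoding for this type is not specified")
    limit = MAX_FINITE[dtype]
    return math.copysign(min(abs(x), limit), x)
```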
The MX specification doesn't define how NaN values should be encoded for the types that don't support NaN (F4E2M1FN, F6E2M3FN, F6E3M2FN). Two possible options are to use the negative zero value, or to use the maximum absolute value (preserving the sign).
When converting a signed type to the F8E8M0FNU type, the sign will be ignored. There's a shortcut for converting from FP32 or BF16 types in RZ (round to zero) rounding mode: right shift to keep the exponent only.
Dequantization
In order to convert an MX format (a tuple of element and scaling tensors) to a wider type (e.g. FP16), one should upcast and multiply the element tensor by the broadcasted scaling tensor.
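A minimal Python sketch of the same computation for a single row, assuming the element tensor has already been upcast to floats and that the row length is a multiple of 32. The rule that a NaN scale makes its whole block NaN comes from the MX specification; `dequantize_mx` is a hypothetical name.

```python
import math

BLOCK_SIZE = 32   # one scale per 32 elements along the last dimension
E8M0_BIAS = 127
E8M0_NAN = 0xFF


def dequantize_mx(elements: list, scale_codes: list) -> list:
    """Dequantize one row of an MX tuple (elements, scales) to Python floats.

    `elements` are element values already upcast to float (e.g. decoded F8E4M3FN),
    `scale_codes` are raw F8E8M0FNU codes. Indexing the scales with i // BLOCK_SIZE
    stands in for broadcasting the scale tensor to the element shape.
    """
    if len(elements) != len(scale_codes) * BLOCK_SIZE:
        raise ValueError("expected exactly one scale per 32 elements")
    scales = [math.nan if c == E8M0_NAN else 2.0 ** (c - E8M0_BIAS) for c in scale_codes]
    # A NaN scale turns its whole block into NaN; a finite scale multiplies each element.
    return [x * scales[i // BLOCK_SIZE] for i, x in enumerate(elements)]


row = dequantize_mx([1.5] * 32 + [-0.5] * 32, [128, 125])
print(row[0], row[32])  # 3.0 -0.125
```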
MLIR example of converting MXFP8 format to FP16:
Quantization
Conversion of a floating point tensor to an MX format is described in the MX specification (section 6.3).
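For illustration, a minimal Python sketch of that algorithm for MXFP4 (F4E2M1FN elements), assuming finite inputs: the shared exponent is floor(log2(max |v|)) minus the largest element exponent (emax = 2 for F4E2M1FN), and each element is divided by the scale and rounded to the nearest F4E2M1FN value with ties to even, saturating at ±6.0. The function names and the all-zero-block convention are assumptions of this sketch, not taken from the RFC.

```python
import math

BLOCK_SIZE = 32
E8M0_BIAS = 127
E2M1_EMAX = 2  # exponent of the largest F4E2M1FN normal (6.0 = 1.5 * 2**2)
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # indexed by the 3 magnitude bits


def round_to_e2m1(x: float) -> float:
    """Round a finite float to F4E2M1FN: nearest value, ties to even, saturating."""
    mag = min(abs(x), E2M1_MAGNITUDES[-1])
    # On a tie, prefer the code whose mantissa bit (the lowest code bit) is 0.
    code = min(range(len(E2M1_MAGNITUDES)),
               key=lambda c: (abs(E2M1_MAGNITUDES[c] - mag), c & 1))
    return math.copysign(E2M1_MAGNITUDES[code], x)


def quantize_block_mxfp4(block: list) -> tuple:
    """Quantize one block of 32 finite floats to (E8M0 scale code, E2M1 values)."""
    if len(block) != BLOCK_SIZE:
        raise ValueError("MX blocks hold exactly 32 values")
    max_abs = max(abs(v) for v in block)
    if max_abs == 0.0:
        shared_exp = -E8M0_BIAS  # assumption: any scale works for an all-zero block
    else:
        # floor(log2(max_abs)) computed exactly: max_abs = m * 2**e with 0.5 <= m < 1
        shared_exp = math.frexp(max_abs)[1] - 1 - E2M1_EMAX
    shared_exp = max(-E8M0_BIAS, min(E8M0_BIAS, shared_exp))  # stay within E8M0 range
    scale = 2.0 ** shared_exp
    return shared_exp + E8M0_BIAS, [round_to_e2m1(v / scale) for v in block]


code, elems = quantize_block_mxfp4([100.0] + [1.0] * 31)
print(code, elems[0], elems[1])  # 131 6.0 0.0
```

In the example the scale is 2^4, so the outlier 100.0 saturates to 6.0 (96.0 after dequantization) while the remaining 1.0 values round to zero.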
MLIR example of converting FP32 to MXFP8 format:
The quantized tensor may contain non-finite values due to conversion overflow in MXFP8; these should be replaced by the maximum absolute value (preserving the sign), as saturating conversion is not available in StableHLO. This doesn't happen with MXFP6 and MXFP4, as their element types are finite-only.
HLO ops
Arithmetic ops
We can support the arithmetic ops on new primitive types in a way similar to FP8 by using the FloatNormalization compiler pass to upcast the smaller type to FP16, perform the operation and downcast back. Adjacent convert ops should be eliminated by the SimplifyFPConversions compiler pass.
As the MX formats will be represented by tuple types in XLA, doing any arithmetic on such composite types would require explicit quantization and dequantization around the arithmetic ops.
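The upcast, compute, downcast pattern can be sketched in Python for F4E2M1FN. This is an illustration under stated assumptions, not the FloatNormalization implementation: it computes in Python's 64-bit floats rather than FP16 (for a single add or multiply of two F4E2M1FN values the intermediate result is exact in either type), and it downcasts with RNE plus saturation as described for the convert op. The helper names are hypothetical.

```python
import math
import operator

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # F4E2M1FN, by magnitude code


def round_to_e2m1(x: float) -> float:
    """Downcast to F4E2M1FN: nearest value, ties to even mantissa bit, saturating at 6.0."""
    mag = min(abs(x), E2M1_MAGNITUDES[-1])
    code = min(range(len(E2M1_MAGNITUDES)),
               key=lambda c: (abs(E2M1_MAGNITUDES[c] - mag), c & 1))
    return math.copysign(E2M1_MAGNITUDES[code], x)


def emulate_e2m1_op(op, a: float, b: float) -> float:
    """Apply a binary op to two F4E2M1FN values via upcast -> compute -> downcast.

    `a` and `b` must already be F4E2M1FN values (as floats), so the upcast is exact.
    """
    return round_to_e2m1(op(a, b))


print(emulate_e2m1_op(operator.add, 1.0, 1.5))  # 2.5 is a tie, rounds to 2.0
print(emulate_e2m1_op(operator.mul, 4.0, 3.0))  # 12.0 saturates to 6.0
```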
Scaled dot op
A scaled dot op (which doesn't exist in HLO today) could accept a block-scaled format for the LHS input, the RHS input, or both. This means it would have three or four tensor parameters instead of two.
I propose using a custom call to represent a scaled dot op until it is no longer experimental; at that point we could introduce a new HLO op.
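To pin down what such an op would compute, here is a minimal Python sketch of the reference semantics for one row/column pair, assuming 32-element blocks along the contracting dimension and `None` standing in for an operand without scales (the three-parameter case). It is mathematically equivalent to dequantizing the scaled operands and using a regular dot; because the scales are powers of two, they can be applied once per block to the partial sums. `scaled_dot` and its parameter order are hypothetical and do not describe any existing custom call target.

```python
BLOCK_SIZE = 32
E8M0_BIAS = 127


def scaled_dot(lhs, lhs_scales, rhs, rhs_scales):
    """Reference semantics of a block-scaled dot product of two vectors.

    lhs/rhs: element values already decoded to float.
    lhs_scales/rhs_scales: raw F8E8M0FNU codes, one per 32-element block along the
    contracting dimension, or None for an operand that is not block scaled
    (the three-parameter variant). NaN scale codes are not handled here.
    """
    if len(lhs) != len(rhs) or len(lhs) % BLOCK_SIZE:
        raise ValueError("operands must have equal length, a multiple of 32")
    total = 0.0
    for blk in range(len(lhs) // BLOCK_SIZE):
        lo, hi = blk * BLOCK_SIZE, (blk + 1) * BLOCK_SIZE
        partial = sum(a * b for a, b in zip(lhs[lo:hi], rhs[lo:hi]))
        # Both scales are powers of two, so they combine into one per-block factor.
        exp = 0
        if lhs_scales is not None:
            exp += lhs_scales[blk] - E8M0_BIAS
        if rhs_scales is not None:
            exp += rhs_scales[blk] - E8M0_BIAS
        total += partial * 2.0 ** exp
    return total


print(scaled_dot([1.0] * 64, [128, 127], [2.0] * 64, None))  # 192.0
```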
Analysis
Microscaling formats
For MXFP8, a block of 32 values takes 33 bytes of memory: a 48% reduction in size compared to FP16 and a 3% overhead compared to FP8. The Q/DQ mean relative error is ~4.7% with F8E5M2 and ~2.4% with F8E4M3FN element types.
For MXFP6, a block takes 25 bytes of memory: a 61% reduction in size compared to FP16 and a 22% reduction compared to FP8. The Q/DQ mean relative error is ~5.0% with both F6E3M2FN and F6E2M3FN element types.
For MXFP4, a block takes 17 bytes of memory: a 73% reduction in size compared to FP16 and a 47% reduction compared to FP8. The Q/DQ mean relative error is ~16%.
The Q/DQ mean relative error was calculated by converting a uniform distribution of FP32 values to the MX format and back (quantize followed by dequantize, as specified above) and averaging the absolute delta divided by the absolute value. Different value distributions could yield different results.
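The sizes and percentages above follow directly from the block structure (32 packed elements plus one 8-bit scale). A short Python check, assuming the FP16 and FP8 baselines store the same 32 elements at 2 and 1 bytes each and ignoring the per-tensor FP8 scale:

```python
BLOCK_SIZE = 32   # elements per MX block
SCALE_BITS = 8    # one F8E8M0FNU scale per block

ELEMENT_BITS = {"MXFP8": 8, "MXFP6": 6, "MXFP4": 4}


def block_bytes(element_bits: int) -> int:
    """Storage for one MX block: 32 packed elements plus the shared scale."""
    return (BLOCK_SIZE * element_bits + SCALE_BITS) // 8


for name, bits in ELEMENT_BITS.items():
    size = block_bytes(bits)
    vs_fp16 = 1 - size / (BLOCK_SIZE * 2)   # FP16: 2 bytes per element
    vs_fp8 = size / BLOCK_SIZE - 1          # FP8: 1 byte per element, no per-block scale
    print(f"{name}: {size} bytes, {vs_fp16:.0%} smaller than FP16, {vs_fp8:+.0%} vs FP8")
```

This prints 33, 25 and 17 bytes, with 48%, 61% and 73% savings over FP16 and +3%, -22% and -47% relative to FP8.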
Comparison to FP8
FP8 tensors in XLA have an accompanying scaling factor scalar [12], which is used to dequantize the data (implicitly in the case of the dot operation). This has a few implications:
If the input data has outliers (e.g. a normal distribution), then the majority of the values will have their quantized accuracy reduced compared to a block scaled format, where such outliers would only affect the accuracy of their block.
To compute the tensor scaling factor, a tensor-wide reduction is necessary - this results in an extra collective operation, which could be slow in multi-host setups. With block scaled formats this could be avoided.
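The first point can be illustrated with a small Python experiment. To isolate the effect of scale granularity, both variants below use F4E2M1FN elements and the MX-style power-of-two scale; this is an assumption of the sketch, since XLA's FP8 path uses FP8 elements and an arbitrary scalar scale, so the exact numbers would differ. The helper names are hypothetical.

```python
import math

E2M1_EMAX = 2
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def round_to_e2m1(x):
    """Nearest F4E2M1FN value, ties to even, saturating at 6.0."""
    mag = min(abs(x), E2M1_MAGNITUDES[-1])
    code = min(range(8), key=lambda c: (abs(E2M1_MAGNITUDES[c] - mag), c & 1))
    return math.copysign(E2M1_MAGNITUDES[code], x)


def shared_exp(values):
    """Power-of-two scale exponent chosen as in MX: floor(log2(max|v|)) - emax."""
    return math.frexp(max(abs(v) for v in values))[1] - 1 - E2M1_EMAX


def qdq(values, exp):
    """Quantize to F4E2M1FN with scale 2**exp, then dequantize."""
    scale = 2.0 ** exp
    return [round_to_e2m1(v / scale) * scale for v in values]


# Two 32-element blocks: the first contains one outlier, the second does not.
data = [1000.0] + [1.0] * 31 + [1.0] * 32

per_tensor = qdq(data, shared_exp(data))
per_block = qdq(data[:32], shared_exp(data[:32])) + qdq(data[32:], shared_exp(data[32:]))

print("zeros with one tensor-wide scale:", per_tensor.count(0.0))
print("zeros with per-block scales:     ", per_block.count(0.0))
```

With a single tensor-wide scale, the outlier forces every 1.0 to flush to zero (63 zeros); with per-block scales only the outlier's own block is affected (31 zeros), and the second block round-trips exactly.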
StableHLO
The StableHLO RFC [10] proposes adding the MX floating-point primitive types; this is a requirement for adding these primitive types to XLA.
The MX floating-point formats will be represented similarly in StableHLO, using tuple composite types. The existing quantized types in StableHLO are integer-based and cannot represent MX block formats.
Footnotes
[1] LLVM PR#95392: [APFloat] Add APFloat support for FP4 data type
[2] LLVM PR#94735: [APFloat] Add APFloat support for FP6 data types
[3] LLVM PR#107127: [APFloat] Add APFloat support for E8M0 type
[4] LLVM PR#108877: [MLIR] Add f4E2M1FN type
[5] LLVM PR#107999: [MLIR] Add f6E2M3FN type
[6] LLVM PR#105573: [MLIR] Add f6E3M2FN type
[7] LLVM PR#111028: [MLIR] Add f8E8M0FNU type
[8] JAX-ML PR#181: Add sub-byte data types: float4_e2m1fn, float6_e2m3fn, float6_e3m2fn
[9] JAX-ML PR#166: Add float8_e8m0_fnu (E8M0) OCP MX scale format
[10] StableHLO PR#2581: [RFC] Microscaling data types (f4E2M1FN, f6E2M3FN, f6E3M2FN, f8E8M0FNU)
[11] eXmY: A Data Type and Technique for Arbitrary Bit Precision Quantization
[12] RFC: FP8 in XLA