From 05778b81f734ee0ac7ac9b6f5fc464563ea1f3fd Mon Sep 17 00:00:00 2001 From: Cerebra Catalyst Team Date: Wed, 16 Oct 2024 09:23:19 -0700 Subject: [PATCH] Update the accumulation dtype if FP8 precision is used for AQT. PiperOrigin-RevId: 686533904 --- aqt/jax/v2/aqt_dot_general.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index 22f4a43e..c07fa01a 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -117,6 +117,11 @@ def dot_general_raw_make( and 2 <= rhs_bits <= 8 ): dg_accumulator_dtype = jnp.int32 + elif ( + lhs_bits in fp8_numerics.fp8_map.keys() + or rhs_bits in fp8_numerics.fp8_map.keys() + ): + dg_accumulator_dtype = jnp.float32 else: dg_accumulator_dtype = None