diff --git a/aqt/jax/v2/numerics/fp_numerics.py b/aqt/jax/v2/numerics/fp_numerics.py index 84ec67c0..96e1ce7b 100644 --- a/aqt/jax/v2/numerics/fp_numerics.py +++ b/aqt/jax/v2/numerics/fp_numerics.py @@ -313,6 +313,7 @@ def radix2_round( key: jax.Array, stochastic_rounding: bool, test_noise_axis=None, + rounding_bias: int = 128, ): """FP stochastic rounding for a given mantissa and exponent. Returns bf16.""" nexp = cfg.nexp @@ -349,7 +350,12 @@ def radix2_round( noise = jax.lax.convert_element_type(rnd_bits, bits_dtype) & man_trunc_mask else: # example e2m1 in bf16: noise = 0b0000000000100000 (shift = 7-1 - 1) - noise = 1 << (man_trunc_bits - 1) # represents 0.5 in the container + # noise = (1<<7) << (man_trunc_bits - 1 - 7) represents 0.5 in the container + shift = man_trunc_bits - 1 - 7 + if shift < 0: + noise = rounding_bias >> (-shift) + else: + noise = rounding_bias << shift noise = jax.lax.convert_element_type(noise, bits_dtype) # This noise add might overflow up to sign bit if x(bf16) has max exp. # In bf16 this happens if x=nan or x=inf. diff --git a/aqt/jax/v2/numerics/fp_numerics_test.py b/aqt/jax/v2/numerics/fp_numerics_test.py index 2e40ee67..ec4e858d 100644 --- a/aqt/jax/v2/numerics/fp_numerics_test.py +++ b/aqt/jax/v2/numerics/fp_numerics_test.py @@ -589,5 +589,29 @@ def test_e1m2_vs_e0m3(self): y_e0m3, _ = e0m3.vjp_fwd(e0m3_input, context=context) assert (y_e1m2 == y_e0m3 * 2).all() + def test_radix2_ceil(self): + x = jnp.array([17, 31, 32, 46, 49, 64, 65], dtype=jnp.bfloat16) + cfg = fp_numerics.FpNumericsConfig( + nexp=8, + minexp=0, + nmant=1, + has_subnormals=True, + has_two_nan=False, + has_naninf=False, + radix=2, + ) + out = fp_numerics.radix2_round( + x, + cfg=cfg, + key=None, + stochastic_rounding=False, + test_noise_axis=None, + rounding_bias=255, + ) + assert ( + out == jnp.array([24, 32, 32, 48, 64, 64, 96], dtype=jnp.bfloat16) + ).all(), f"{out=}" + + if __name__ == "__main__": absltest.main()