Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rounding bias that controls the rounding threshold #750

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion aqt/jax/v2/numerics/fp_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def radix2_round(
key: jax.Array,
stochastic_rounding: bool,
test_noise_axis=None,
rounding_bias: int = 128,
):
"""FP stochastic rounding for a given mantissa and exponent. Returns bf16."""
nexp = cfg.nexp
Expand Down Expand Up @@ -349,7 +350,12 @@ def radix2_round(
noise = jax.lax.convert_element_type(rnd_bits, bits_dtype) & man_trunc_mask
else:
# example e2m1 in bf16: noise = 0b0000000000100000 (shift = 7-1 - 1)
noise = 1 << (man_trunc_bits - 1) # represents 0.5 in the container
# noise = (1<<7) << (man_trunc_bits - 1 - 7) represents 0.5 in the container
shift = man_trunc_bits - 1 - 7
if shift < 0:
noise = rounding_bias >> (-shift)
else:
noise = rounding_bias << shift
noise = jax.lax.convert_element_type(noise, bits_dtype)
# This noise add might overflow up to sign bit if x(bf16) has max exp.
# In bf16 this happens if x=nan or x=inf.
Expand Down
24 changes: 24 additions & 0 deletions aqt/jax/v2/numerics/fp_numerics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,5 +589,29 @@ def test_e1m2_vs_e0m3(self):
y_e0m3, _ = e0m3.vjp_fwd(e0m3_input, context=context)
assert (y_e1m2 == y_e0m3 * 2).all()

def test_radix2_ceil(self):
x = jnp.array([17, 31, 32, 46, 49, 64, 65], dtype=jnp.bfloat16)
cfg = fp_numerics.FpNumericsConfig(
nexp=8,
minexp=0,
nmant=1,
has_subnormals=True,
has_two_nan=False,
has_naninf=False,
radix=2,
)
out = fp_numerics.radix2_round(
x,
cfg=cfg,
key=None,
stochastic_rounding=False,
test_noise_axis=None,
rounding_bias=255,
)
assert (
out == jnp.array([24, 32, 32, 48, 64, 64, 96], dtype=jnp.bfloat16)
).all(), f"{out=}"


if __name__ == "__main__":
absltest.main()