Skip to content

Commit

Permalink
add an option to broadcast random noises for stochastic rounding
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703849281
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Jan 14, 2025
1 parent 46ae0bc commit 477de80
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 7 deletions.
12 changes: 10 additions & 2 deletions aqt/jax/v2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import copy
import functools
from typing import Literal, TypeAlias
from typing import Literal, Sequence, TypeAlias

from aqt.jax.v2 import aqt_dot_general
from aqt.jax.v2 import aqt_quantizer
Expand Down Expand Up @@ -176,6 +176,7 @@ def set_stochastic_rounding(
vjp_lhs_stochastic_rounding: bool,
vjp_rhs_stochastic_rounding: bool,
implementation: str,
noise_sharing_axes: Sequence[int] = (),
):
"""Configure stochastic rounding implementation."""
noise_implementations = {
Expand All @@ -184,7 +185,14 @@ def set_stochastic_rounding(
}
msg = f'{implementation} not supported.'
assert implementation in noise_implementations.keys(), msg
noise_fn = noise_implementations[implementation]
if noise_sharing_axes:
noise_fn = functools.partial(
noise_implementations[implementation],
noise_sharing_axes=noise_sharing_axes,
)
else:
# for backward compatibility of the config tests.
noise_fn = noise_implementations[implementation]
assert isinstance(
cfg.dlhs.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer
)
Expand Down
32 changes: 27 additions & 5 deletions aqt/jax/v2/stochastic_rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
"""Efficient stochastic rounding implementation."""

from typing import Callable
from typing import Callable, Sequence
from aqt.jax.v2 import utils
import jax
import jax.numpy as jnp
Expand All @@ -22,25 +22,47 @@
NoiseFn = Callable[[tuple[int, ...], jax.Array], jnp.ndarray]


def _degenerate_noise_shape(
shape: tuple[int, ...],
noise_sharing_axes: Sequence[int] = (),
) -> tuple[int, ...]:
"""Degenerate the given shape to 1 for the broadcasting axes."""
return tuple(
1 if axis in noise_sharing_axes else i for axis, i in enumerate(shape)
)


@utils.flax_slots_kw_only_dataclass
class JaxUniform:
"""Jax uniform noise."""

def __call__(self, shape: tuple[int, ...], key: jax.Array) -> jnp.ndarray:
return jax.random.uniform(key, shape) - 0.5
def __call__(
self,
shape: tuple[int, ...],
key: jax.Array,
noise_sharing_axes: Sequence[int] = (),
) -> jnp.ndarray:
noise_shape = _degenerate_noise_shape(shape, noise_sharing_axes)
return jax.random.uniform(key, noise_shape) - 0.5


@utils.flax_slots_kw_only_dataclass
class RandomCenteredUniform:
"""Customized efficient implementation for random centered uniform noise."""

def __call__(self, shape: tuple[int, ...], key: jax.Array) -> jnp.ndarray:
def __call__(
self,
shape: tuple[int, ...],
key: jax.Array,
noise_sharing_axes: Sequence[int] = (),
) -> jnp.ndarray:
"""Generates uniform number in [-0.5, 0.5]."""
dtype = jnp.dtype('uint16')
nbits = jnp.iinfo(dtype).bits
noise_shape = _degenerate_noise_shape(shape, noise_sharing_axes)

# Generate random bits.
bits = jax.random.bits(key, shape, dtype)
bits = jax.random.bits(key, noise_shape, dtype)

# Align bits with the mantissa of f32.
nmant = jnp.finfo(jnp.float32).nmant
Expand Down
54 changes: 54 additions & 0 deletions aqt/jax/v2/stochastic_rounding_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test for stochastic rounding."""

from absl.testing import absltest
from absl.testing import parameterized
from aqt.jax.v2 import stochastic_rounding
import jax
from numpy import testing as np_testing


JaxUniform = stochastic_rounding.JaxUniform
RandomCenteredUniform = stochastic_rounding.RandomCenteredUniform


class StochasticRoundingTest(parameterized.TestCase):

@parameterized.named_parameters(
("jax_uniform", JaxUniform()),
("random_centered_uniform", RandomCenteredUniform()),
)
def test_range(self, noise_fn):
noises = noise_fn(shape=(10000,), key=jax.random.PRNGKey(0))
np_testing.assert_array_less(noises, 0.5)
np_testing.assert_array_less(-0.5, noises)

@parameterized.named_parameters(
("jax_uniform", JaxUniform()),
("random_centered_uniform", RandomCenteredUniform()),
)
def test_shape(self, noise_fn):
noise_sharing_axes = (0,)
noises = noise_fn(
shape=(2, 3, 4),
key=jax.random.PRNGKey(0),
noise_sharing_axes=noise_sharing_axes,
)
self.assertEqual(noises.shape, (1, 3, 4))


if __name__ == "__main__":
absltest.main()

0 comments on commit 477de80

Please sign in to comment.