From 99ba09a670224a11c0a4adb51afe2037a0bcd52a Mon Sep 17 00:00:00 2001 From: Yichi Zhang Date: Fri, 12 Jul 2024 14:35:04 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 651889166 --- aqt/jax/v2/calibration.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/aqt/jax/v2/calibration.py b/aqt/jax/v2/calibration.py index b183ca4e..7dede933 100644 --- a/aqt/jax/v2/calibration.py +++ b/aqt/jax/v2/calibration.py @@ -104,6 +104,37 @@ def get_bound( return abs_max.astype(x.dtype) +@utils.flax_slots_kw_only_dataclass +class AbsMeanCalibration(Calibration): + """Simple scale * mean(abs(x)) calibration. + + Attributes: + scale: Set it to something. IntNumerics.clip_gradient=True is likely to be + important. + """ + + scale: float + p: float + + def get_bound( + self, + x: jnp.ndarray, + shared_axes: Sequence[utils.AxisIdx] | None, + context: utils.Context | None = None, + ) -> jnp.ndarray: + """Calibration.""" + del context + assert shared_axes is not None + + abs_sum = jnp.sum(jnp.abs(x) ** self.p, axis=shared_axes, keepdims=True) + count = jnp.sum(x != 0.0, axis=shared_axes, keepdims=True) + count = jnp.where(count == 0.0, jnp.ones_like(count), count) + abs_mean = (abs_sum / count) ** (1.0 / self.p) + abs_mean = abs_mean * self.scale + abs_mean = jnp.where(abs_mean == 0.0, jnp.ones_like(abs_mean), abs_mean) + return abs_mean.astype(x.dtype) + + @utils.flax_slots_kw_only_dataclass class SnrBasedAutoCalibration(Calibration): """Automatically finds the best scale based on SNR values.