Skip to content

Commit

Permalink
internal
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 651889166
  • Loading branch information
ychzhang authored and copybara-github committed Aug 7, 2024
1 parent 31b8fa4 commit 99ba09a
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions aqt/jax/v2/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,37 @@ def get_bound(
return abs_max.astype(x.dtype)


@utils.flax_slots_kw_only_dataclass
class AbsMeanCalibration(Calibration):
"""Simple scale * mean(abs(x)) calibration.
Attributes:
scale: Set it to something. IntNumerics.clip_gradient=True is likely to be
important.
"""

scale: float
p: float

def get_bound(
self,
x: jnp.ndarray,
shared_axes: Sequence[utils.AxisIdx] | None,
context: utils.Context | None = None,
) -> jnp.ndarray:
"""Calibration."""
del context
assert shared_axes is not None

abs_sum = jnp.sum(jnp.abs(x) ** self.p, axis=shared_axes, keepdims=True)
count = jnp.sum(x != 0.0, axis=shared_axes, keepdims=True)
count = jnp.where(count == 0.0, jnp.ones_like(count), count)
abs_mean = (abs_sum / count) ** (1.0 / self.p)
abs_mean = abs_mean * self.scale
abs_mean = jnp.where(abs_mean == 0.0, jnp.ones_like(abs_mean), abs_mean)
return abs_mean.astype(x.dtype)


@utils.flax_slots_kw_only_dataclass
class SnrBasedAutoCalibration(Calibration):
"""Automatically finds the best scale based on SNR values.
Expand Down

0 comments on commit 99ba09a

Please sign in to comment.