From dd10639318cd8224326bde5701ff473e79eb20ea Mon Sep 17 00:00:00 2001 From: Cerebra Catalyst Team Date: Tue, 24 Sep 2024 11:27:19 -0700 Subject: [PATCH] Support per-channel static quant with fake scale factors PiperOrigin-RevId: 678328776 --- aqt/jax/v2/calibration.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aqt/jax/v2/calibration.py b/aqt/jax/v2/calibration.py index d21ac68f..7e6fcc69 100644 --- a/aqt/jax/v2/calibration.py +++ b/aqt/jax/v2/calibration.py @@ -77,7 +77,7 @@ def get_scale_and_bias( numerics_: numerics.AqtNumerics, context: utils.Context | None = None, ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: - del shared_axes, context + del context if isinstance(self.bound, float) and self.bound <= 0.0: raise ValueError(f'{self.bound=} should be positive.') dtype = self.dtype if self.dtype is not None else x.dtype @@ -85,7 +85,10 @@ def get_scale_and_bias( # TODO(yichizh): hardcode bf16 for the scales, subject to quality evaluation bound = self.bound if jnp.isscalar(bound): - bound = jnp.full(x.shape, bound, x.dtype) + bound_shape = list(x.shape) + for ax in shared_axes: + bound_shape[ax] = 1 + bound = jnp.ones(bound_shape, dtype=x.dtype) * bound scale = bound / numerics_.get_quant_bound() scale = ceil_to_po2(scale) if self.po2_scale else scale