diff --git a/aqt/jax/v2/calibration.py b/aqt/jax/v2/calibration.py index d21ac68f..7e6fcc69 100644 --- a/aqt/jax/v2/calibration.py +++ b/aqt/jax/v2/calibration.py @@ -77,7 +77,7 @@ def get_scale_and_bias( numerics_: numerics.AqtNumerics, context: utils.Context | None = None, ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: - del shared_axes, context + del context if isinstance(self.bound, float) and self.bound <= 0.0: raise ValueError(f'{self.bound=} should be positive.') dtype = self.dtype if self.dtype is not None else x.dtype @@ -85,7 +85,10 @@ def get_scale_and_bias( # TODO(yichizh): hardcode bf16 for the scales, subject to quality evaluation bound = self.bound if jnp.isscalar(bound): - bound = jnp.full(x.shape, bound, x.dtype) + bound_shape = list(x.shape) + for ax in shared_axes: + bound_shape[ax] = 1 + bound = jnp.ones(bound_shape, dtype=x.dtype) * bound scale = bound / numerics_.get_quant_bound() scale = ceil_to_po2(scale) if self.po2_scale else scale