Skip to content

Commit

Permalink
Support per-channel static quant with fake scale factors
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 678328776
  • Loading branch information
Cerebra Catalyst Team authored and copybara-github committed Sep 24, 2024
1 parent 449c561 commit dd10639
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions aqt/jax/v2/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,18 @@ def get_scale_and_bias(
numerics_: numerics.AqtNumerics,
context: utils.Context | None = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
del shared_axes, context
del context
if isinstance(self.bound, float) and self.bound <= 0.0:
raise ValueError(f'{self.bound=} should be positive.')
dtype = self.dtype if self.dtype is not None else x.dtype

# TODO(yichizh): hardcode bf16 for the scales, subject to quality evaluation
bound = self.bound
if jnp.isscalar(bound):
bound = jnp.full(x.shape, bound, x.dtype)
bound_shape = list(x.shape)
for ax in shared_axes:
bound_shape[ax] = 1
bound = jnp.ones(bound_shape, dtype=x.dtype) * bound
scale = bound / numerics_.get_quant_bound()
scale = ceil_to_po2(scale) if self.po2_scale else scale

Expand Down

0 comments on commit dd10639

Please sign in to comment.