Skip to content

Commit

Permalink
Add support asymmetric fake-quantization to AQTv2.
Browse files Browse the repository at this point in the history
Integration of native quantization with biases will require computing the cross terms. See [#725](#725)

Itemized changes:

- Add `IntAsymmetric` to handle asymmetric integer numerics.
  - this class forgoes some of the more research-y parameters present on `IntSymmetric`.
- Add `MinMaxCalibration` to calculate the scale and bias for asymmetric quantization.

I additionally tested this change by training MNIST models using `flax_e2e_model`. With symmetric quantization the model fails to converge for `config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN` losses). With asymmetric quantization the model converges even with `config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
  • Loading branch information
phoenix-meadowlark authored and copybara-github committed Sep 20, 2024
1 parent b907430 commit 47d405a
Show file tree
Hide file tree
Showing 5 changed files with 373 additions and 50 deletions.
104 changes: 101 additions & 3 deletions aqt/jax/v2/aqt_dot_general_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def _modify_dg(
fwd_lhs_tricky_clip_and_round: bool = False,
local_aqt: aqt.LocalAqt | None = None,
clip_gradient: bool = False,
use_asymmetric: bool = False,
) -> aqt.DotGeneral:
dg = copy.deepcopy(readonly_dg)
if fwd_lhs_tricky_clip_and_round:
Expand Down Expand Up @@ -256,11 +257,15 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True):
# that the scales are not too large.
def disable_quant(c):
_disable_quant_types(c)
if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.IntSymmetric):
int_numerics_types = (
int_numerics.IntSymmetric,
int_numerics.IntAsymmetric,
)
if isinstance(c.dg_quantizer.lhs.numerics, int_numerics_types):
c.dg_quantizer.lhs.numerics = (
c.dg_quantizer.lhs.numerics.replace(round=False)
)
if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.IntSymmetric):
if isinstance(c.dg_quantizer.rhs.numerics, int_numerics_types):
c.dg_quantizer.rhs.numerics = (
c.dg_quantizer.rhs.numerics.replace(round=False)
)
Expand Down Expand Up @@ -291,6 +296,11 @@ def disable_quant(c):
dg.fwd.dg_quantizer.rhs.numerics.replace(clip_gradient=clip_gradient)
)

if use_asymmetric:
# TODO(aqt): use native asymmetric quantization once it is supported.
# https://github.com/google/aqt/issues/725
config.set_asymmetric_quantization(dg, use_fake_quant=True)

return dg


Expand All @@ -307,6 +317,7 @@ def _aqt_dg_full_lr_diff(
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
clip_gradient: bool = False,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
dg = _modify_dg(
readonly_dg,
Expand All @@ -319,6 +330,7 @@ def _aqt_dg_full_lr_diff(
fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round,
local_aqt=local_aqt,
clip_gradient=clip_gradient,
use_asymmetric=use_asymmetric,
)
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
return lambda lhs, rhs: dg(lhs, rhs, dims)
Expand All @@ -335,6 +347,7 @@ def _aqt_dg_full(
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
clip_gradient: bool = False,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
return _aqt_dg_full_lr_diff(
lhs_dequant_mode=dequant_mode,
Expand All @@ -348,6 +361,7 @@ def _aqt_dg_full(
readonly_dg=readonly_dg,
dims=dims,
clip_gradient=clip_gradient,
use_asymmetric=use_asymmetric,
)


Expand All @@ -359,13 +373,15 @@ def _aqt_dg_raw_lr_diff(
*,
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
dg = _modify_dg(
readonly_dg,
lhs_dequant_mode=lhs_dequant_mode,
rhs_dequant_mode=rhs_dequant_mode,
lhs_calibration_mode=lhs_calibration_mode,
rhs_calibration_mode=rhs_calibration_mode,
use_asymmetric=use_asymmetric,
)
dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None)
dg.fwd.dg_quantizer.init_calibration()
Expand All @@ -378,6 +394,7 @@ def _aqt_dg_raw(
*,
readonly_dg: aqt.DotGeneral,
dims: jax.lax.DotDimensionNumbers,
use_asymmetric: bool = False,
) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
return _aqt_dg_raw_lr_diff(
dequant_mode,
Expand All @@ -386,6 +403,7 @@ def _aqt_dg_raw(
calibration_mode,
readonly_dg=readonly_dg,
dims=dims,
use_asymmetric=use_asymmetric,
)


Expand Down Expand Up @@ -557,6 +575,15 @@ def test_dot_general_calibration_with_contracting_axis(
dtype=jnp.float32,
clip_gradient=False,
):
is_quantized = not all([
isinstance(dg.fwd.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
isinstance(dg.dlhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
isinstance(dg.dlhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
isinstance(dg.drhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics),
isinstance(dg.drhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics),
])

readonly_dg = dg
del dg

Expand All @@ -571,9 +598,24 @@ def test_dot_general_calibration_with_contracting_axis(
dims=dims,
clip_gradient=clip_gradient,
)
asym_dg_full = functools.partial(
_aqt_dg_full,
readonly_dg=readonly_dg,
dims=dims,
clip_gradient=clip_gradient,
# This should be removed once asymmetric quant supports use_fwd_quant.
use_fwd_quant=False,
use_asymmetric=True,
)
aqt_dg_raw = functools.partial(
_aqt_dg_raw, readonly_dg=readonly_dg, dims=dims
)
asym_dg_raw = functools.partial(
_aqt_dg_raw,
readonly_dg=readonly_dg,
dims=dims,
use_asymmetric=True,
)
modify_dg = functools.partial(_modify_dg, readonly_dg=readonly_dg)
check = functools.partial(_check_result_eq, lhs=lhs, rhs=rhs, gra=gra)

Expand Down Expand Up @@ -609,6 +651,20 @@ def test_dot_general_calibration_with_contracting_axis(
dict(test_gradient=False),
),
])
check([
("default ", asym_dg_full(aqt.DequantMode.OUTPUT), dict()),
("FQ ", asym_dg_full(aqt.DequantMode.THIS_INPUT), dict()),
(
"raw fwd ",
asym_dg_raw(aqt.DequantMode.OUTPUT),
dict(test_gradient=False),
),
(
"raw fwd FQ ",
asym_dg_raw(aqt.DequantMode.THIS_INPUT),
dict(test_gradient=False),
),
])

check([
(
Expand All @@ -631,6 +687,30 @@ def test_dot_general_calibration_with_contracting_axis(
),
])

if is_quantized:
# Asymmetric quantization does not currently support forward quantization.
with self.assertRaisesRegex(NotImplementedError, r"biases.*forward"):
check([
(
"fwd_quant=F",
aqt_dg_full(
aqt.DequantMode.OUTPUT,
use_fwd_quant=False,
use_asymmetric=True,
),
dict(),
),
(
"fwd_quant=T",
aqt_dg_full(
aqt.DequantMode.OUTPUT,
use_fwd_quant=True,
use_asymmetric=True,
),
dict(),
),
])

check([
(
"default ",
Expand All @@ -641,14 +721,32 @@ def test_dot_general_calibration_with_contracting_axis(
dict(),
),
(
"default ",
"FQ ",
aqt_dg_full(
aqt.DequantMode.THIS_INPUT,
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
),
dict(),
),
])
check([
(
"default ",
asym_dg_full(
aqt.DequantMode.OUTPUT,
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
),
dict(),
),
(
"FQ ",
asym_dg_full(
aqt.DequantMode.THIS_INPUT,
local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
),
dict(),
),
])

if isinstance(
readonly_dg.fwd.dg_quantizer.lhs.numerics,
Expand Down
47 changes: 47 additions & 0 deletions aqt/jax/v2/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Union
from aqt.jax.v2 import aqt_tensor
from aqt.jax.v2 import utils
from aqt.jax.v2.numerics import int_numerics
from aqt.jax.v2.numerics import numerics
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -392,3 +393,49 @@ def _calculate_snr(
snr = jnp.log(1 + signal / noise)

return snr


@utils.flax_slots_kw_only_dataclass
class MinMaxCalibration(Calibration):
"""Calibration between the min and max values.
Attributes:
eps: Optional epsilon to add to the bound to avoid division by zero. Inf
filtering is also performed by QTensor.quant() after division.
"""

eps: float | None = None

def get_scale_and_bias(
self,
x: jnp.ndarray,
shared_axes: Sequence[utils.AxisIdx] | None,
numerics_: int_numerics.IntAsymmetric,
context: utils.Context | None = None,
) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]:
del context
msg = (
'Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot'
' to set them.'
)
assert shared_axes is not None, msg
if not isinstance(numerics_, int_numerics.IntAsymmetric):
raise NotImplementedError(
'MinMaxCalibration only supports int_numerics.IntAsymmetric, but got '
f'{numerics}'
)
dtype = self.dtype if self.dtype is not None else x.dtype

# Scale the full width of the input to the width of the quantization range.
x_min = jnp.min(x, axis=shared_axes, keepdims=True)
x_max = jnp.max(x, axis=shared_axes, keepdims=True)
bound = x_max - x_min
if self.eps is not None:
bound += self.eps
scale = bound / numerics_.get_quant_bound()

# Calculate bias s.t. quant(min(x)) = (min(x) + bias) / scale = quant_min.
quant_min, _ = numerics_.get_quant_range()
bias = quant_min * scale - x_min

return [scale.astype(dtype)], [bias.astype(dtype)]
86 changes: 82 additions & 4 deletions aqt/jax/v2/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,35 @@ def set_dg_raw_context(cfg_raw: DotGeneralRaw, key: Optional[jax.Array]):
return ret_cfg


def set_fwd_dequant_mode(
cfg: DotGeneral,
def set_dequant_mode(
cfg: DotGeneralRaw,
*,
lhs_dequant_mode: Optional[DequantMode] = None,
rhs_dequant_mode: Optional[DequantMode] = None,
):
"""Sets the dequant mode for the lhs and rhs of a single dot general raw."""
if lhs_dequant_mode is not None:
cfg.fwd.lhs.dequant_mode = lhs_dequant_mode
cfg.lhs.dequant_mode = lhs_dequant_mode
if rhs_dequant_mode is not None:
cfg.fwd.rhs.dequant_mode = rhs_dequant_mode
cfg.rhs.dequant_mode = rhs_dequant_mode

fake_quant = DequantMode.THIS_INPUT in [lhs_dequant_mode, rhs_dequant_mode]
if fake_quant and jnp.issubdtype(cfg.dg_accumulator_dtype, jnp.integer):
# Fake-quantization is not compatible with integer accumulation.
cfg.dg_accumulator_dtype = None


def set_fwd_dequant_mode(
cfg: DotGeneral,
*,
lhs_dequant_mode: Optional[DequantMode] = None,
rhs_dequant_mode: Optional[DequantMode] = None,
):
set_dequant_mode(
cfg.fwd,
lhs_dequant_mode=lhs_dequant_mode,
rhs_dequant_mode=rhs_dequant_mode,
)


def set_fwd_calibration_mode(
Expand Down Expand Up @@ -404,6 +423,65 @@ def set_bits(
return cfg


def _get_asym_numerics(numerics_: numerics.AqtNumerics):
"""Gets the asymmetric equivalent of the given numerics."""
if isinstance(
numerics_, (int_numerics.IntSymmetric, int_numerics.IntAsymmetric)
):
# pytype: disable=attribute-error
return int_numerics.IntAsymmetric(
bits=numerics_.bits,
clip=numerics_.clip,
clip_gradient=numerics_.clip_gradient,
round=numerics_.round,
noise_fn=numerics_.noise_fn,
dtype=numerics_.dtype,
)
# pytype: enable=attribute-error
elif isinstance(numerics_, no_numerics.NoNumerics):
return numerics_
else:
raise NotImplementedError(
'Asymmetric quantization currently only supports integer numerics,'
f' but got {numerics_}'
)


def _set_asymmetric_quantization(cfg: DotGeneralRaw, use_fake_quant: bool):
"""Replaces symmetric quantization with asymmetric quantization."""
set_numerics(
cfg,
_get_asym_numerics(cfg.dg_quantizer.lhs.numerics),
_get_asym_numerics(cfg.dg_quantizer.rhs.numerics),
)

def replace_calibration(quantizer: aqt_quantizer.Quantizer):
if isinstance(quantizer.calibration, functools.partial):
quantizer.calibration = functools.partial(
calibration.MinMaxCalibration, **quantizer.calibration.keywords
)
else:
quantizer.calibration = calibration.MinMaxCalibration

replace_calibration(cfg.dg_quantizer.lhs)
replace_calibration(cfg.dg_quantizer.rhs)

# Only fake quantization currently supports quantization biases.
if use_fake_quant:
set_dequant_mode(
cfg,
lhs_dequant_mode=DequantMode.THIS_INPUT,
rhs_dequant_mode=DequantMode.THIS_INPUT,
)


def set_asymmetric_quantization(cfg: DotGeneral, *, use_fake_quant: bool):
"""Replaces symmetric quantization with asymmetric quantization."""
_set_asymmetric_quantization(cfg.fwd, use_fake_quant)
_set_asymmetric_quantization(cfg.dlhs, use_fake_quant)
_set_asymmetric_quantization(cfg.drhs, use_fake_quant)


def set_scale_and_bias_dtype(cfg: DotGeneral, dtype: jnp.dtype):
"""Set the dtype for all scales and biases in the given DotGeneral config."""
assert isinstance(
Expand Down
Loading

0 comments on commit 47d405a

Please sign in to comment.