From cef287ad7a6aa9612cb6490e4f588ddaf3167322 Mon Sep 17 00:00:00 2001
From: Phoenix Meadowlark
Date: Thu, 11 Jul 2024 17:07:03 -0700
Subject: [PATCH] Add support for asymmetric fake-quantization to AQTv2.

Integration of native quantization with biases will require computing the
cross terms, likely in the AQT operation quantizer
(`DefaultGeneralQuantizer`).

Itemized changes:

- `AqtNumerics`:
  - Rename `AqtNumerics.abs_val_mapped_to` to `AqtNumerics.get_scaled_bound`
    to reflect that the calibration bound may span the whole quantization
    range (instead of ~half the range for a strictly linear transformation).
  - Refactor `IntNumerics` into `BaseIntNumerics`, `SymIntNumerics` and
    `AsymIntNumerics`.
    - `AsymIntNumerics` doesn't need `preserve_zero` or `preserve_max_val`.
- Add `MinMaxCalibration`.

I additionally tested this change by training MNIST models using
`flax_e2e_model`. With symmetric quantization the model fails to converge for
`config.config_v4(fwd_bits=2, dlhs_bits=None, drhs_bits=None)` (due to `NaN`
losses). With asymmetric quantization the model converges even with
`config.config_v4(fwd_bits=2, dlhs_bits=2, drhs_bits=4)`.

PiperOrigin-RevId: 651580879
---
 README.md                                    |   2 +-
 aqt/jax/v2/aqt_conv_general_test.py          |  28 +-
 aqt/jax/v2/aqt_dot_general_test.py           | 171 ++++++++--
 aqt/jax/v2/aqt_quantizer.py                  |  35 +-
 aqt/jax/v2/calibration.py                    | 306 ++++++++++++------
 aqt/jax/v2/config.py                         | 160 +++++++--
 aqt/jax/v2/config_test.py                    | 298 ++++++++---------
 aqt/jax/v2/examples/flax_e2e_model_test.py   | 103 ++++--
 .../gptq/examples/gptq_flax_e2e_model.py     |   4 +-
 .../gptq/gptq_dot_general_quantizer.py       |   4 +-
 aqt/jax/v2/flax/aqt_flax_calibration.py      |  27 ++
 .../v2/flax/delayed_scaling_calibration.py   |  15 +
 aqt/jax/v2/numerics/fp8_numerics.py          |   5 +-
 aqt/jax/v2/numerics/fp_numerics.py           |   7 +-
 aqt/jax/v2/numerics/int_numerics.py          | 109 ++++++-
 aqt/jax/v2/numerics/int_numerics_test.py     |   3 +-
 aqt/jax/v2/numerics/no_numerics.py           |   5 +-
 aqt/jax/v2/numerics/numerics.py              |  18 +-
 aqt/jax/v2/utils.py                          |   2 +-
 19 files changed, 908 insertions(+), 394 deletions(-)

diff --git a/README.md b/README.md
index b00475a7..64376f08 100644
--- a/README.md
+++ b/README.md
@@ -169,7 +169,7 @@
 from aqt.jax.v2 import utils as aqt_utils
 from aqt.jax.v2.numerics import int_numerics
 q = aqt_quantizer.Quantizer(
-    numerics=int_numerics.IntNumerics(
+    numerics=int_numerics.SymIntNumerics(
         bits=4,
         preserve_zero=True,
         preserve_max_val=True,
diff --git a/aqt/jax/v2/aqt_conv_general_test.py b/aqt/jax/v2/aqt_conv_general_test.py
index 8f8e0dcf..95764f06 100644
--- a/aqt/jax/v2/aqt_conv_general_test.py
+++ b/aqt/jax/v2/aqt_conv_general_test.py
@@ -15,6 +15,7 @@
 from absl.testing import absltest
 from absl.testing import parameterized
 from aqt.jax.v2 import aqt_quantizer
+from aqt.jax.v2 import config
 import aqt.jax.v2.aqt_conv_general as aqt_conv
 import flax.linen.linear as fl
 import jax
@@ -48,13 +49,17 @@ def test_conv_general_dilated(
       rhs_maxval=20.0,
       seed=0,
   ):
-    dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits)
-
+    dg_raw_conv = aqt_conv.conv_general_dilated_make(
+        2, lhs_bits, rhs_bits, initialize_calibration=False
+    )
+    # Power-of-2 scales allow FQ and AQT to be exactly the same.
+    dg_quantizer = dg_raw_conv.dg_quantizer
     if dg_raw_conv.lhs:
-      # Power-of-2 scales allow FQ and AQT to be exactly the same.
- dg_raw_conv.dg_quantizer.lhs.po2_scale = True + config.set_quantizer_calibration_config(dg_quantizer.lhs, po2_scale=True) + dg_quantizer.lhs.init_calibration() if dg_raw_conv.rhs: - dg_raw_conv.dg_quantizer.rhs.po2_scale = True + config.set_quantizer_calibration_config(dg_quantizer.rhs, po2_scale=True) + dg_quantizer.rhs.init_calibration() batch_n = 10 contr_n = 20 @@ -94,12 +99,17 @@ def test_conv_general_dilated_quantized( seed=0, ): """Check that passing quantized lhs/rhs to aqt_conv_fn works.""" - dg_raw_conv = aqt_conv.conv_general_dilated_make(2, lhs_bits, rhs_bits) + dg_raw_conv = aqt_conv.conv_general_dilated_make( + 2, lhs_bits, rhs_bits, initialize_calibration=False + ) + # Power-of-2 scales allow FQ and AQT to be exactly the same. + dg_quantizer = dg_raw_conv.dg_quantizer if dg_raw_conv.lhs: - # Power-of-2 scales allow FQ and AQT to be exactly the same. - dg_raw_conv.dg_quantizer.lhs.po2_scale = True + config.set_quantizer_calibration_config(dg_quantizer.lhs, po2_scale=True) + dg_quantizer.lhs.init_calibration() if dg_raw_conv.rhs: - dg_raw_conv.dg_quantizer.rhs.po2_scale = True + config.set_quantizer_calibration_config(dg_quantizer.rhs, po2_scale=True) + dg_quantizer.rhs.init_calibration() batch_n = 10 contr_n = 20 diff --git a/aqt/jax/v2/aqt_dot_general_test.py b/aqt/jax/v2/aqt_dot_general_test.py index 4a6bb48d..169b0c15 100644 --- a/aqt/jax/v2/aqt_dot_general_test.py +++ b/aqt/jax/v2/aqt_dot_general_test.py @@ -165,9 +165,12 @@ class _TrickyNumerics(numerics.AqtNumerics, flax.struct.PyTreeNode): def get_dtype(self): return self.dtype - def abs_val_mapped_to(self) -> jnp.ndarray: + def get_scaled_bound(self) -> jnp.ndarray: return jnp.array(1.0) + def get_quant_range(self) -> tuple[jnp.ndarray, jnp.ndarray]: + return -self.get_scaled_bound(), self.get_scaled_bound() + def fwd(self, x, context): del context return jax.lax.round(x, jax.lax.RoundingMethod.TO_NEAREST_EVEN) @@ -193,6 +196,7 @@ def _modify_dg( fwd_lhs_tricky_clip_and_round: bool = False, local_aqt: aqt.LocalAqt | None = None, clip_gradient: bool = False, + use_asymmetric: bool = False, ) -> aqt.DotGeneral: dg = copy.deepcopy(readonly_dg) if fwd_lhs_tricky_clip_and_round: @@ -200,14 +204,6 @@ def _modify_dg( dg.fwd.dg_quantizer.lhs.numerics = _TrickyNumerics() dg.fwd.dg_accumulator_dtype = None - # Setting po2_scale is ensuring that fake_quant and full dot_general - # have the same numerics when scales are power of two (po2). - # We are passing dims to config so that we can reuse it in fake_quant. - # Power-of-2 scales allow FQ and AQT to be exactly the same. - def _apply_po2_scale(c): - c.dg_quantizer.lhs.po2_scale = True - c.dg_quantizer.rhs.po2_scale = True - def _apply_dequant_mode(c, lhs_dequant_mode, rhs_dequant_mode): c.lhs.dequant_mode = lhs_dequant_mode c.rhs.dequant_mode = rhs_dequant_mode @@ -231,7 +227,13 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True): disable_lhs_quant = lhs_dequant_mode == aqt.DequantMode.THIS_INPUT disable_rhs_quant = rhs_dequant_mode == aqt.DequantMode.THIS_INPUT for c in [dg.fwd, dg.dlhs, dg.drhs]: - _apply_po2_scale(c) + # Setting po2_scale is ensuring that fake_quant and full dot_general + # have the same numerics when scales are power of two (po2). + # We are passing dims to config so that we can reuse it in fake_quant. + # Power-of-2 scales allow FQ and AQT to be exactly the same. 
+ config.set_quantizer_calibration_config(c.dg_quantizer.lhs, po2_scale=True) + config.set_quantizer_calibration_config(c.dg_quantizer.rhs, po2_scale=True) + _apply_dequant_mode(c, lhs_dequant_mode, rhs_dequant_mode) _apply_calibration_mode( c, lhs_calibration_mode, rhs_calibration_mode @@ -244,11 +246,15 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True): # that the scales are not too large. def disable_quant(c): _disable_quant_types(c) - if isinstance(c.dg_quantizer.lhs.numerics, int_numerics.IntNumerics): + int_numerics_types = ( + int_numerics.SymIntNumerics, + int_numerics.AsymIntNumerics, + ) + if isinstance(c.dg_quantizer.lhs.numerics, int_numerics_types): c.dg_quantizer.lhs.numerics = ( c.dg_quantizer.lhs.numerics.replace(round=False) ) - if isinstance(c.dg_quantizer.rhs.numerics, int_numerics.IntNumerics): + if isinstance(c.dg_quantizer.rhs.numerics, int_numerics_types): c.dg_quantizer.rhs.numerics = ( c.dg_quantizer.rhs.numerics.replace(round=False) ) @@ -273,15 +279,18 @@ def disable_quant(c): dg.drhs.local_aqt = local_aqt # When using abs-max scaling, this should be a no-op. - if isinstance(dg.fwd.dg_quantizer.lhs.numerics, int_numerics.IntNumerics): + if isinstance(dg.fwd.dg_quantizer.lhs.numerics, int_numerics.SymIntNumerics): dg.fwd.dg_quantizer.lhs.numerics = ( dg.fwd.dg_quantizer.lhs.numerics.replace(clip_gradient=clip_gradient) ) - if isinstance(dg.fwd.dg_quantizer.rhs.numerics, int_numerics.IntNumerics): + if isinstance(dg.fwd.dg_quantizer.rhs.numerics, int_numerics.SymIntNumerics): dg.fwd.dg_quantizer.rhs.numerics = ( dg.fwd.dg_quantizer.rhs.numerics.replace(clip_gradient=clip_gradient) ) + if use_asymmetric: + config.set_asymmetric_quantization(dg) + return dg @@ -297,6 +306,7 @@ def _aqt_dg_full_lr_diff( readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, clip_gradient: bool = False, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: dg = _modify_dg( readonly_dg, @@ -308,6 +318,7 @@ def _aqt_dg_full_lr_diff( fwd_lhs_tricky_clip_and_round=fwd_lhs_tricky_clip_and_round, local_aqt=local_aqt, clip_gradient=clip_gradient, + use_asymmetric=use_asymmetric, ) dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None) return lambda lhs, rhs: dg(lhs, rhs, dims) @@ -323,6 +334,7 @@ def _aqt_dg_full( readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, clip_gradient: bool = False, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: return _aqt_dg_full_lr_diff( dequant_mode, @@ -334,7 +346,8 @@ def _aqt_dg_full( local_aqt, readonly_dg=readonly_dg, dims=dims, - clip_gradient=clip_gradient + clip_gradient=clip_gradient, + use_asymmetric=use_asymmetric, ) @@ -346,6 +359,7 @@ def _aqt_dg_raw_lr_diff( *, readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: dg = _modify_dg( readonly_dg, @@ -353,6 +367,7 @@ def _aqt_dg_raw_lr_diff( rhs_dequant_mode=rhs_dequant_mode, lhs_calibration_mode=lhs_calibration_mode, rhs_calibration_mode=rhs_calibration_mode, + use_asymmetric=use_asymmetric, ) dg = config.set_context(dg, key=jax.random.PRNGKey(4), train_step=None) dg.fwd.dg_quantizer.init_calibration() @@ -365,6 +380,7 @@ def _aqt_dg_raw( *, readonly_dg: aqt.DotGeneral, dims: jax.lax.DotDimensionNumbers, + use_asymmetric: bool = False, ) -> Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]: return _aqt_dg_raw_lr_diff( dequant_mode, @@ -373,6 +389,7 @@ def _aqt_dg_raw( calibration_mode, 
readonly_dg=readonly_dg, dims=dims, + use_asymmetric=use_asymmetric, ) @@ -391,7 +408,7 @@ def test_empty(self): def test_fq_noise(self, preserve_zero, prec, v, seed): key = jax.random.PRNGKey(seed) quantizer = config.quantizer_make(prec) - if isinstance(quantizer.numerics, int_numerics.IntNumerics): + if isinstance(quantizer.numerics, int_numerics.SymIntNumerics): quantizer.numerics.preserve_zero = preserve_zero if not preserve_zero: quantizer.numerics.dtype = None @@ -438,8 +455,9 @@ def test_fake_quant( maxval=10.0, shape=(20, 1), ): - quantizer = config.quantizer_make(bits) - quantizer.po2_scale = True + quantizer = config.quantizer_make(bits, initialize_calibration=False) + config.set_quantizer_calibration_config(quantizer, po2_scale=True) + quantizer.init_calibration() quantizer.calib_shared_axes = (0,) x = jnp.linspace(-maxval, maxval, num=shape[0]).reshape(shape) grad = jnp.ones(shape) * 12345.0 @@ -543,6 +561,24 @@ def test_dot_general_calibration_with_contracting_axis( dtype=jnp.float32, clip_gradient=False, ): + # This should be removed once asymmetric quant supports use_fwd_quant. + test_asym = not any([ + dg.fwd.lhs.use_fwd_quant, + dg.fwd.rhs.use_fwd_quant, + dg.dlhs.lhs.use_fwd_quant, + dg.dlhs.rhs.use_fwd_quant, + dg.drhs.lhs.use_fwd_quant, + dg.drhs.rhs.use_fwd_quant, + ]) + is_quantized = not all([ + isinstance(dg.fwd.dg_quantizer.lhs.numerics, no_numerics.NoNumerics), + isinstance(dg.fwd.dg_quantizer.rhs.numerics, no_numerics.NoNumerics), + isinstance(dg.dlhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics), + isinstance(dg.dlhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics), + isinstance(dg.drhs.dg_quantizer.lhs.numerics, no_numerics.NoNumerics), + isinstance(dg.drhs.dg_quantizer.rhs.numerics, no_numerics.NoNumerics), + ]) + readonly_dg = dg del dg @@ -557,9 +593,25 @@ def test_dot_general_calibration_with_contracting_axis( dims=dims, clip_gradient=clip_gradient, ) + asym_dg_full = functools.partial( + _aqt_dg_full, + readonly_dg=readonly_dg, + dims=dims, + clip_gradient=clip_gradient, + # As an argument to _modify_dg this must be None, not False. + # Unrelated things happen when False. + use_fwd_quant=None, + use_asymmetric=True, + ) aqt_dg_raw = functools.partial( _aqt_dg_raw, readonly_dg=readonly_dg, dims=dims ) + asym_dg_raw = functools.partial( + _aqt_dg_raw, + readonly_dg=readonly_dg, + dims=dims, + use_asymmetric=True, + ) modify_dg = functools.partial(_modify_dg, readonly_dg=readonly_dg) check = functools.partial(_check_result_eq, lhs=lhs, rhs=rhs, gra=gra) @@ -595,19 +647,57 @@ def test_dot_general_calibration_with_contracting_axis( dict(test_gradient=False), ), ]) + if test_asym: + check([ + ("default ", asym_dg_full(aqt.DequantMode.OUTPUT), dict()), + ("FQ ", asym_dg_full(aqt.DequantMode.THIS_INPUT), dict()), + ( + "raw fwd ", + asym_dg_raw(aqt.DequantMode.OUTPUT), + dict(test_gradient=False), + ), + ( + "raw fwd FQ ", + asym_dg_raw(aqt.DequantMode.THIS_INPUT), + dict(test_gradient=False), + ), + ]) check([ ( - "fwd_quant=T", + "fwd_quant=F", aqt_dg_full(aqt.DequantMode.OUTPUT, use_fwd_quant=False), dict(), ), ( - "fwd_quant=F", + "fwd_quant=T", aqt_dg_full(aqt.DequantMode.OUTPUT, use_fwd_quant=True), dict(), ), ]) + if test_asym and is_quantized: + # Asymmetric quantization does not currently support forward quantization. 
+ with self.assertRaisesRegex(NotImplementedError, r"biases.*forward"): + check([ + ( + "fwd_quant=F", + aqt_dg_full( + aqt.DequantMode.OUTPUT, + use_fwd_quant=False, + use_asymmetric=True, + ), + dict(), + ), + ( + "fwd_quant=T", + aqt_dg_full( + aqt.DequantMode.OUTPUT, + use_fwd_quant=True, + use_asymmetric=True, + ), + dict(), + ), + ]) check([ ( @@ -619,7 +709,7 @@ def test_dot_general_calibration_with_contracting_axis( dict(), ), ( - "default ", + "FQ ", aqt_dg_full( aqt.DequantMode.THIS_INPUT, local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2), @@ -627,10 +717,29 @@ def test_dot_general_calibration_with_contracting_axis( dict(), ), ]) + if test_asym: + check([ + ( + "default ", + asym_dg_full( + aqt.DequantMode.OUTPUT, + local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2), + ), + dict(), + ), + ( + "FQ ", + asym_dg_full( + aqt.DequantMode.THIS_INPUT, + local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2), + ), + dict(), + ), + ]) if isinstance( readonly_dg.fwd.dg_quantizer.lhs.numerics, - int_numerics.IntNumerics, + int_numerics.SymIntNumerics, ): check([ ( @@ -1059,7 +1168,7 @@ def test_local_aqt(self, shard_count, lhs, expected_product): def test_per_tensor(self): # TODO(lew): bits=8 started failing in VLP colab due x/x != 1.0 sometimes bits = 4 - my_numerics = int_numerics.IntNumerics( + my_numerics = int_numerics.SymIntNumerics( bits=bits, preserve_zero=True, preserve_max_val=False, @@ -1074,7 +1183,6 @@ def test_per_tensor(self): calib_shared_axes="per_tensor", scale_stop_grad=True, calibration=calibration.AbsMaxCalibration, - po2_scale=False, context=utils.Context(key=None, train_step=None), ) # TODO(lew): Perhaps post_init call could work? @@ -1091,12 +1199,18 @@ def test_per_tensor(self): def test_per_subchannel(self): # TODO(lew): bits=8 started failing in VLP colab due x/x != 1.0 sometimes bits = 4 - quantizer = aqt_quantizer.quantizer_make(bits) + quantizer = aqt_quantizer.quantizer_make(bits, initialize_calibration=False) x = jnp.arange(0, 64).reshape((4, 4, 4)) - tiling_state = tiled_dot_general.generate_tiling_state(x, [ - tiled_dot_general.AxisTiling(axis=2, tile_size=2), - ]) + # NOTE: The scale dtype must be set to a float dtype when quantizing an + # integer input, as jax does not support taking the inverse of an integer. 
+ config.set_quantizer_calibration_config(quantizer, dtype=jnp.float32) + quantizer.init_calibration() + + tiling_state = tiled_dot_general.generate_tiling_state( + x, + [tiled_dot_general.AxisTiling(axis=2, tile_size=2)], + ) qx, _ = quantizer.quant( x, calibration_axes=[0, 2], @@ -1104,6 +1218,7 @@ def test_per_subchannel(self): ) self.assertEqual(qx.qvalue.shape, (4, 4, 2, 2)) self.assertEqual(qx.scale[0].shape, (1, 4, 2, 1)) + self.assertEqual(qx.scale[0].dtype, jnp.float32) x = qx.dequant() self.assertEqual(x.shape, (4, 4, 4)) diff --git a/aqt/jax/v2/aqt_quantizer.py b/aqt/jax/v2/aqt_quantizer.py index 772b039a..cf20f427 100644 --- a/aqt/jax/v2/aqt_quantizer.py +++ b/aqt/jax/v2/aqt_quantizer.py @@ -14,20 +14,20 @@ """Configuration dataclasses.""" from typing import Literal, Sequence + from aqt.jax.v2 import aqt_tensor from aqt.jax.v2 import calibration from aqt.jax.v2 import tiled_dot_general from aqt.jax.v2 import utils - from aqt.jax.v2.numerics import int_numerics from aqt.jax.v2.numerics import no_numerics from aqt.jax.v2.numerics import numerics import jax -import jax.numpy as jnp AbstractAqtNumerics = numerics.AqtNumerics AbstractAqtCalibration = calibration.Calibration +Axes = Sequence[utils.AxisIdx] AxisTiling = tiled_dot_general.AxisTiling TilingState = tiled_dot_general.TilingState @@ -38,19 +38,12 @@ class Quantizer: """Configuration of quantization of one tensor.""" numerics: AbstractAqtNumerics = utils.static_field() - calib_shared_axes: Sequence[utils.AxisIdx] | Literal["per_tensor"] | None = ( - utils.static_field() - ) + calib_shared_axes: Axes | Literal["per_tensor"] | None = utils.static_field() scale_stop_grad: bool = utils.static_field() # noise+clip+round # We apply gradient of clip_and_round in bwd pass. calibration: type[AbstractAqtCalibration] = utils.static_field() _calibrator: AbstractAqtCalibration | None = utils.static_field(default=None) - # Round up the calibration to power of 2 (po2). - po2_scale: bool = utils.static_field() - # The dtype of the quantization scale array. If not set, the scale array will - # be in the same dtype as the input. - scale_dtype: jnp.dtype | None = utils.static_field(default=None) # TODO(yichizh): Factor out auxiliary dataclasses into a separate file. context: utils.Context @@ -129,27 +122,20 @@ def calibrate( shared_axes = self.calib_shared_axes or calibration_axes assert self._calibrator is not None, "forgot self.init_calibration()?" - bound = self._calibrator.get_bound(x, shared_axes, self.context) - abs_max_mapped_to = self.numerics.abs_val_mapped_to() - scale = bound / abs_max_mapped_to - - if self.po2_scale: - # With floor the biggest value (we are using jnp.max) is in the range of - # clipping and therefore have a correct gradient. - scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale))) - scale = jax.lax.reciprocal(scale) + + scale, bias = self._calibrator.get_scale_and_bias( + x, shared_axes, self.numerics, self.context + ) if self.scale_stop_grad: # TODO(lew): Does not matter in DG, because we are using custom gradient. # We should take that into account somehow. 
scale = jax.lax.stop_gradient(scale) - if self.scale_dtype is not None: - scale = scale.astype(self.scale_dtype) qt = aqt_tensor.QTensor( qvalue=None, - scale=[scale], + scale=scale, scale_t=None, - bias=[], + bias=bias, dequant_dtype=dequant_dtype, tiling_state=tiling_state, ) @@ -188,7 +174,7 @@ def quantizer_make( else: pz = False if n_bits == 1 else True dtype = utils.infer_dtype_from_bits(n_bits) if pz else None - effective_numerics = int_numerics.IntNumerics( + effective_numerics = int_numerics.SymIntNumerics( bits=n_bits, preserve_zero=pz, preserve_max_val=preserve_max_val, @@ -203,7 +189,6 @@ def quantizer_make( calib_shared_axes=None, scale_stop_grad=True, calibration=calibration.AbsMaxCalibration, - po2_scale=False, context=utils.Context(key=None, train_step=None), ) # TODO(lew): We should try to move to to class constructor or post-init. diff --git a/aqt/jax/v2/calibration.py b/aqt/jax/v2/calibration.py index ddb18b4d..e2c56d0d 100644 --- a/aqt/jax/v2/calibration.py +++ b/aqt/jax/v2/calibration.py @@ -18,21 +18,46 @@ from typing import Union from aqt.jax.v2 import aqt_tensor from aqt.jax.v2 import utils +from aqt.jax.v2.numerics import int_numerics from aqt.jax.v2.numerics import numerics +import jax import jax.numpy as jnp +def ceil_to_po2(scale: jnp.ndarray) -> jnp.ndarray: + # With floor the biggest value (we are using jnp.max) is in the range of + # clipping and therefore have a correct gradient. + scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale))) + scale = jax.lax.reciprocal(scale) + return scale + + @utils.flax_slots_kw_only_dataclass class Calibration(abc.ABC): - """Abstract class for calibration.""" + """Abstract class for scale and bias calibration.""" + + # The dtype of the quantization scale and bias arrays. If not set, the arrays + # will be in the same dtype as the input. + dtype: jnp.dtype | None = utils.static_field(default=None) + # Round up the calibration to power of 2 (po2). + po2_scale: bool = utils.static_field(default=False) @abc.abstractmethod - def get_bound( + def get_scale_and_bias( self, x: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, context: utils.Context | None = None, - ) -> jnp.ndarray: + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: + """Returns the quantizaiton scale and bias for the given input tensor.""" + # NOTE: The scale and bias calculation are handled by the Calibration + # class because there is not a single order in which they should be + # calculated. In the case of symmetric quantization, the scale depends on + # the bias as the bias shifts the symmetric upperbound. In the case of + # asymmetric quantization, the bias depends on the scale as the scale + # determines how far the bias should shift the input s.t. the minimum + # quantized value aligns with the minimum quantization bucket. pass def init_calibration(self): @@ -41,21 +66,28 @@ def init_calibration(self): @utils.flax_slots_kw_only_dataclass class ConstantCalibration(Calibration): - """Calibration with a constant value.""" + """Calibration with a constant per-tensor value.""" bound: Union[jnp.ndarray, float] + bias: Union[jnp.ndarray, float] | None = None - def get_bound( + def get_scale_and_bias( self, x: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, context: utils.Context | None = None, - ) -> jnp.ndarray: - """Calibration.""" + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: del shared_axes, context assert self.bound > 0, 'Bound should be positive.' 
+ dtype = self.dtype if self.dtype is not None else x.dtype + # TODO(yichizh): hardcode bf16 for the scales, subject to quality evaluation - return jnp.asarray(self.bound).reshape((1,) * len(x.shape)).astype(x.dtype) + bound = jnp.full(x.shape, self.bound, x.dtype) + scale = bound / numerics_.get_scaled_bound() + scale = ceil_to_po2(scale) if self.po2_scale else scale + bias = [] if self.bias is None else [jnp.full(x.shape, self.bias, dtype)] + return [scale.astype(dtype)], bias @utils.flax_slots_kw_only_dataclass @@ -63,179 +95,200 @@ class AbsMaxCalibration(Calibration): """Simple max(abs(x)) calibration. Attributes: - scale: Set it to something like 0.3, 0.1, 0.03. If scale < 1.0, setting - IntNumerics.clip_gradient=True is likely to be important. + clipping_factor: Set it to something like 0.3, 0.1, 0.03. If clipping_factor + < 1.0, setting IntNumerics.clip_gradient=True is likely to be important. """ - scale: float | None = None + clipping_factor: float | None = None - def get_bound( + def get_scale_and_bias( self, x: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, context: utils.Context | None = None, - ) -> jnp.ndarray: + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: """Calibration. Args: x: The input tensor. shared_axes: Axes that share a calibration bound. For AbsMaxCalibration, it should not be None. + numerics_: An `AqtNumerics` object containing information regarding + quantization. Used to create the scale and bias arrays. context: The quantization context. Returns: - The bound tensor containing the bound values for each group (can + The scale tensor containing the scale values for each group (can potentially be a subchannel). Its shape will be the same as `x.shape` but - with `shared_axes` collapsed to 1. + with `shared_axes` collapsed to 1. Bias is not supported. """ del context - msg = ( 'Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot' ' to set them.' ) assert shared_axes is not None, msg + dtype = self.dtype if self.dtype is not None else x.dtype - # NOTE: If you want to clip, consider using clip and clip_gradient in - # int_numerics.IntNumerics. + # NOTE: If you use a clipping_factor, consider using clip and clip_gradient + # in int_numerics.IntNumerics. abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True) # TODO(yichizh): the zero filtering is not needed anymore because inf is # filtered when calculating the reciprocal of scaling factor abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max) - if self.scale is not None: - abs_max = abs_max * self.scale - return abs_max.astype(x.dtype) + if self.clipping_factor is not None: + abs_max = abs_max * self.clipping_factor + + scale = abs_max / numerics_.get_scaled_bound() + scale = ceil_to_po2(scale) if self.po2_scale else scale + return [scale.astype(dtype)], [] @utils.flax_slots_kw_only_dataclass class AbsMeanCalibration(Calibration): - """Simple scale * mean(abs(x)) calibration. + """Simple clipping_factor * mean(abs(x) ** p) ** (1 / p) calibration. Attributes: - scale: Set it to something. IntNumerics.clip_gradient=True is likely to be - important. + clipping_factor: If clipping_factor < 1.0, setting + IntNumerics.clip_gradient=True is likely to be important. + p: Set it to 1 for mean of absolute scaling. 
""" - scale: float + clipping_factor: float p: float - def get_bound( + def get_scale_and_bias( self, x: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, context: utils.Context | None = None, - ) -> jnp.ndarray: + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: """Calibration.""" del context assert shared_axes is not None + dtype = self.dtype if self.dtype is not None else x.dtype abs_sum = jnp.sum(jnp.abs(x) ** self.p, axis=shared_axes, keepdims=True) count = jnp.sum(x != 0.0, axis=shared_axes, keepdims=True) count = jnp.where(count == 0.0, jnp.ones_like(count), count) abs_mean = (abs_sum / count) ** (1.0 / self.p) - abs_mean = abs_mean * self.scale + abs_mean = abs_mean * self.clipping_factor abs_mean = jnp.where(abs_mean == 0.0, jnp.ones_like(abs_mean), abs_mean) - return abs_mean.astype(x.dtype) + + scale = abs_mean / numerics_.get_scaled_bound() + scale = ceil_to_po2(scale) if self.po2_scale else scale + return [scale.astype(dtype)], [] @utils.flax_slots_kw_only_dataclass class SnrBasedAutoCalibration(Calibration): - """Automatically finds the best scale based on SNR values. + """Automatically finds the best clipping factors based on SNR values. - The best scale is determined by the SNR (signal-to-noise ratio) values of the - quantized tensor. The SNR is calculated by the following formula: + The best clipping factors are determined by the SNR (signal-to-noise ratio) + values of the quantized tensor. The SNR is calculated by the following + formula: SNR = log(1 + signal / noise) where signal = sum(x ** 2) and noise = sum(err ** 2). - An SNR value is calculated for each scale per subchannel group. Scales that - produce the highest SNR value for each subchannel group are selected as the - best scale. + An SNR value is calculated for each clipping factor per subchannel group. + Clipping factors that produce the highest SNR value for each subchannel group + are selected and used to calculate the best quantization scale. Attributes: - numerics: An `AqtNumerics` object containing information regarding - quantization such as target dtype. Also used to actually quantize (round - and clip) the tensor when calculating the SNR values. - scale_search_space: A sequence of scale values, a.k.a. clipping factors, to - search for the best scale. + auto_clip_search_config: A sequence of clipping factors to use to search for + the best per-channel quantization scale. """ - numerics: numerics.AqtNumerics - auto_scale_search_config: utils.AutoScaleSearchConfig + auto_clip_search_config: utils.AutoClipSearchConfig - def get_bound( + def get_scale_and_bias( self, x: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, context: utils.Context | None = None, - ) -> jnp.ndarray: - """Produces the max bound for quantization based on SNR values. + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: + """Produces the scale for quantization based on SNR values. Args: x: The input tensor. shared_axes: Axes that each subchannel group is shared across. + numerics_: An `AqtNumerics` object containing information regarding + quantization such as target dtype. Also used to actually quantize (round + and clip) the tensor when calculating the SNR values. context: The quantization context. Returns: - The bound tensor containing the bound values for each subchannel group. + The scale tensor containing the scale values for each subchannel group. Its shape will be the same as `x.shape` but with `shared_axes` collapsed - to 1. + to 1. 
Biases are not supported. """ - # Determine the shape of the best_scale_values. There will be one scale - # value per subchannel group, so it shape will be the same as `x.shape` but - # with `shared_axes` collapsed to 1. - scales_shape = list(x.shape) + dtype = self.dtype if self.dtype is not None else x.dtype + + # Determine the shape of the best_subchannel_clip_factors. There will be one + # clip factor per subchannel group, so it shape will be the same as + # `x.shape` but with `shared_axes` collapsed to 1. + clip_shape = list(x.shape) for i in shared_axes: - scales_shape[i] = 1 + clip_shape[i] = 1 - # Default value of 1.0 (max value). One scale value per subchannel group. - best_scale_values = jnp.ones((*scales_shape,), dtype=jnp.float32) + # Default factor of 1.0 (max value). One clip factor per subchannel group. + best_subchannel_clip_factors = jnp.ones(clip_shape, dtype=jnp.float32) # Start with the worst possible SNR value of zeros, essentially representing - # infinite noise. This will be updated as we search through the scale - # values. - max_snr_values = jnp.zeros((*scales_shape,), dtype=jnp.float32) + # infinite noise. This will be updated as we search through the clip + # factors. + max_snr_values = jnp.zeros(clip_shape, dtype=jnp.float32) abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True) abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max) - # Iteratively search through the scale value search space, identifying and - # updating the best SNR and corresponding scale value for each subchannel + # Iteratively search through the clip factors search space, identifying and + # updating the best SNR and corresponding clip factor for each subchannel # group. Essentially it is performing the "find the max value" for each - # subgroup in O(num_scales) time. - for scale in self.auto_scale_search_config: - # Replace the new highest SNR values and corresponding scale values - # after evaluating for `scale`. - best_scale_values, max_snr_values = self._update_best_scale_and_max_snr( - best_scale_values, - max_snr_values, - scale, - x, - abs_max, - shared_axes, - context, + # subgroup in O(auto_clip_search_config) time. + for clip_factor in self.auto_clip_search_config: + # Replace the new highest SNR values and corresponding clip factors + # after evaluating for `clip`. + best_subchannel_clip_factors, max_snr_values = ( + self._update_best_clip_factors_and_max_snr( + best_subchannel_clip_factors, + max_snr_values, + clip_factor, + x, + abs_max, + shared_axes, + numerics_, + context, + ) ) - # TODO(b/339746869): Generate a simple report for the scale distribution. - best_abs_max = abs_max * best_scale_values - return best_abs_max.astype(x.dtype) + # TODO(b/339746869): Generate a simple report for the clip distribution. + bound = abs_max * best_subchannel_clip_factors + scale = bound / numerics_.get_scaled_bound() + scale = ceil_to_po2(scale) if self.po2_scale else scale + return [scale.astype(dtype)], [] - def _update_best_scale_and_max_snr( + def _update_best_clip_factors_and_max_snr( self, - current_scale_values: jnp.ndarray, + current_clip_factors: jnp.ndarray, current_snr_values: jnp.ndarray, - scale: float, + clip_factor: float, x: jnp.ndarray, abs_max: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx], + numerics_: numerics.AqtNumerics, context: utils.Context, ) -> tuple[jnp.ndarray, jnp.ndarray]: - """Updates the best scale and max SNR values given a `scale` value. + """Updates the best clip factors and max SNR values given a `clip_factor`. 
- Given a `scale` value, this function calculates the SNR value for each + Given a `clip_factor`, this function calculates the SNR value for each subchannel group. It then identifies the subchannel groups that have higher - SNR values than `current_snr_values` and updates the best scale and max SNR - values for those groups. + SNR values than `current_snr_values` and updates the best clip factors and + max SNR values for those groups. - `current_scale_values`, `current_snr_values`, and `abs_max` are expected to + `current_clip_factors`, `current_snr_values`, and `abs_max` are expected to have the same shape, which is the same as `x.shape` but with `shared_axes` collapsed to 1. @@ -243,41 +296,47 @@ def _update_best_scale_and_max_snr( the SNR values for each subchannel group. Args: - current_scale_values: The current best scale values for each subchannel + current_clip_factors: The current best clip factors for each subchannel group. current_snr_values: The current best SNR values for each subchannel group. - scale: The scale value to be evaluated. + clip_factor: The clip factor to be evaluated. x: The input tensor. abs_max: The absolute max value for each subchannel group. shared_axes: Axes that each subchannel group is shared across. + numerics_: An `AqtNumerics` object containing information regarding + quantization such as target dtype. Also used to actually quantize (round + and clip) the tensor when calculating the SNR values. context: The quantization context. Returns: - The (updated best scale values, updated best SNR values) tuple. + The (updated best clip factors, updated best SNR values) tuple. """ - # Note that all subchannel groups are scaled by the same candidate scale - # value. - scaled_abs_max = abs_max * scale - scaled_abs_max = scaled_abs_max.astype(x.dtype) + # Note that all subchannel groups are clipped by the same candidate clip + # factor. + clipped_abs_max = abs_max * clip_factor + clipped_abs_max = clipped_abs_max.astype(x.dtype) - snr_values = self._calculate_snr(x, scaled_abs_max, shared_axes, context) + snr_values = self._calculate_snr( + x, clipped_abs_max, shared_axes, numerics_, context + ) - # Update the best scale values and SNR values for subchannel groups that + # Update the best clipping factors and SNR values for subchannel groups that # have higher SNR values. - updated_scale_values = jnp.where( + updated_clip_factors = jnp.where( snr_values > current_snr_values, - scale, - current_scale_values, + clip_factor, + current_clip_factors, ) updated_snr_values = jnp.maximum(snr_values, current_snr_values) - return updated_scale_values, updated_snr_values + return updated_clip_factors, updated_snr_values def _calculate_snr( self, x: jnp.ndarray, bound: jnp.ndarray, shared_axes: Sequence[utils.AxisIdx], + numerics_: numerics.AqtNumerics, context: utils.Context, ) -> jnp.ndarray: """Calculates the quantization signal-to-noise ratio (SNR) for the given bound. @@ -293,21 +352,27 @@ def _calculate_snr( shared_axes: Axes that each subchannel group is shared across. SNR values will be calculated for each dimension in `x.shape` except the shared axes. + numerics_: An `AqtNumerics` object containing information regarding + quantization such as target dtype. Also used to actually quantize (round + and clip) the tensor when calculating the SNR values. context: The quantization context. Returns: The SNR tensor containing the SNR values for each subchannel group. Its shape will be the same as `x.shape` but with `shared_axes` collapsed to 1. 
""" - abs_max_mapped_to = self.numerics.abs_val_mapped_to() - scale = bound / abs_max_mapped_to + scale = bound / numerics_.get_scaled_bound() q_tensor = aqt_tensor.QTensor( - qvalue=None, scale=[scale], scale_t=None, bias=[], dequant_dtype=x.dtype + qvalue=None, + scale=[scale], + scale_t=None, + bias=[], + dequant_dtype=x.dtype, ).quant(x) # This actually quantizes the tensor (clips, rounds, etc). - quantized_tensor, _ = self.numerics.vjp_fwd(q_tensor.qvalue, context) + quantized_tensor, _ = numerics_.vjp_fwd(q_tensor.qvalue, context) q_tensor = q_tensor.replace(qvalue=quantized_tensor) dequantized_tensor = q_tensor.dequant() @@ -318,3 +383,46 @@ def _calculate_snr( snr = jnp.log(1 + signal / noise) return snr + + +@utils.flax_slots_kw_only_dataclass +class MinMaxCalibration(Calibration): + """Calibration between the min and max values. + + Attributes: + eps: Optional epsilon to add to the bound to avoid division by zero. + """ + + eps: float | None = None + + def get_scale_and_bias( + self, + x: jnp.ndarray, + shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, + context: utils.Context | None = None, + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: + del context + msg = ( + 'Perhaps you are using DequantMode.THIS_INPUT (fake_quant) and forgot' + ' to set them.' + ) + assert shared_axes is not None, msg + if not isinstance(numerics_, int_numerics.AsymIntNumerics): + raise NotImplementedError( + f'MinMaxCalibration only supports AsymIntNumerics, but got {numerics}' + ) + dtype = self.dtype if self.dtype is not None else x.dtype + + x_min = jnp.min(x, axis=shared_axes, keepdims=True) + x_max = jnp.max(x, axis=shared_axes, keepdims=True) + bound = x_max - x_min + if self.eps is not None: + bound += self.eps + scale = bound / numerics_.get_scaled_bound() + + # Calculate bias s.t. quant(min(x)) = (min(x) + bias) / scale = quant_min. + quant_min, _ = numerics_.get_quant_range() + bias = quant_min * scale - x_min + + return [scale.astype(dtype)], [bias.astype(dtype)] diff --git a/aqt/jax/v2/config.py b/aqt/jax/v2/config.py index a8564bd1..8f13dc0c 100644 --- a/aqt/jax/v2/config.py +++ b/aqt/jax/v2/config.py @@ -87,16 +87,35 @@ def set_dg_raw_context(cfg_raw: DotGeneralRaw, key: Optional[jax.Array]): return ret_cfg -def set_fwd_dequant_mode( - cfg: DotGeneral, +def set_dequant_mode( + cfg: DotGeneralRaw, *, lhs_dequant_mode: Optional[DequantMode] = None, rhs_dequant_mode: Optional[DequantMode] = None, ): + """Sets the dequant mode for the lhs and rhs of a single dot general.""" if lhs_dequant_mode is not None: - cfg.fwd.lhs.dequant_mode = lhs_dequant_mode + cfg.lhs.dequant_mode = lhs_dequant_mode if rhs_dequant_mode is not None: - cfg.fwd.rhs.dequant_mode = rhs_dequant_mode + cfg.rhs.dequant_mode = rhs_dequant_mode + + fake_quant = DequantMode.THIS_INPUT in [lhs_dequant_mode, rhs_dequant_mode] + if fake_quant and jnp.issubdtype(cfg.dg_accumulator_dtype, jnp.integer): + # Fake-quantization is not compatible with integer accumulation. + cfg.dg_accumulator_dtype = None + + +def set_fwd_dequant_mode( + cfg: DotGeneral, + *, + lhs_dequant_mode: Optional[DequantMode] = None, + rhs_dequant_mode: Optional[DequantMode] = None, +): + set_dequant_mode( + cfg.fwd, + lhs_dequant_mode=lhs_dequant_mode, + rhs_dequant_mode=rhs_dequant_mode, + ) def set_fwd_calibration_mode( @@ -142,7 +161,9 @@ def set_fwd_rhs_dtype_int2(cfg: DotGeneral): # of 128, we use this setter to enable int2 dtype. 
# Remove this setter and enable int2 in utils.infer_dtype_from_bits() # when XLA supports general int2 casting. - assert isinstance(cfg.fwd.dg_quantizer.rhs.numerics, int_numerics.IntNumerics) + assert isinstance( + cfg.fwd.dg_quantizer.rhs.numerics, int_numerics.SymIntNumerics + ) assert cfg.fwd.dg_quantizer.rhs.numerics.bits == 2 # Disable pytype check since jnp.int2 is only dynamically to jax # when ml_dtypes package has it. @@ -276,7 +297,7 @@ def set_int_numerics_preserve_zero(cfg: DotGeneral, preserve_zero: bool): for dot_general_raw in [cfg.fwd, cfg.dlhs, cfg.drhs]: dg_quantizer = dot_general_raw.dg_quantizer for q_numerics in [dg_quantizer.lhs.numerics, dg_quantizer.rhs.numerics]: - if isinstance(q_numerics, int_numerics.IntNumerics): + if isinstance(q_numerics, int_numerics.SymIntNumerics): q_numerics.preserve_zero = preserve_zero updated_dtype = ( utils.infer_dtype_from_bits(q_numerics.bits) # pytype: disable=attribute-error @@ -286,10 +307,31 @@ def set_int_numerics_preserve_zero(cfg: DotGeneral, preserve_zero: bool): q_numerics.dtype = updated_dtype -def set_auto_calib_scale( - cfg: DotGeneral, auto_scale_search_config: utils.AutoScaleSearchConfig +def set_quantizer_calibration_config( + quantizer: aqt_quantizer.Quantizer, + *, + new_calibration_cls: type[calibration.Calibration] | None = None, + **config_kwargs, +): + """Updates the calibration config of a quantizer.""" + assert isinstance(quantizer, aqt_quantizer.Quantizer), type(quantizer) + + old_calibration_cls = quantizer.calibration + keywords = {} + if isinstance(old_calibration_cls, functools.partial): + keywords = old_calibration_cls.keywords + old_calibration_cls = old_calibration_cls.func + keywords.update(config_kwargs) + + if new_calibration_cls is None: + new_calibration_cls = old_calibration_cls + quantizer.calibration = functools.partial(new_calibration_cls, **keywords) + + +def set_auto_calib_clipping_config( + cfg: DotGeneral, auto_clip_search_config: utils.AutoClipSearchConfig ) -> None: - """Update `cfg`'s quantizers' calibration to use auto scale search. + """Update `cfg`'s quantizers' calibration to use auto clipping search. Currently only supports the weights (rhs) of `DotGeneral`, since the iterative process of finding the scale tensors might be too slow for the activations @@ -297,7 +339,7 @@ def set_auto_calib_scale( Args: cfg: The config to be updated. - auto_scale_search_config: The config for auto scale search. + auto_clip_search_config: The config for auto clipping search. 
""" assert isinstance( cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer @@ -310,17 +352,15 @@ def set_auto_calib_scale( ) for dot_general_raw in [cfg.fwd, cfg.dlhs, cfg.drhs]: - dg_quantizer = dot_general_raw.dg_quantizer - dg_rhs_quantizer = dg_quantizer.rhs - dg_rhs_quantizer.calibration = functools.partial( - calibration.SnrBasedAutoCalibration, - numerics=dg_rhs_quantizer.numerics, - auto_scale_search_config=auto_scale_search_config, + set_quantizer_calibration_config( + dot_general_raw.dg_quantizer.rhs, + new_calibration_cls=calibration.SnrBasedAutoCalibration, + auto_clip_search_config=auto_clip_search_config, ) -def set_absmax_calib_scale(cfg: DotGeneral, scale: float): - """Set AbsMaxCalibration scale and update clip_gradient accordingly.""" +def set_absmax_calib_clipping_factor(cfg: DotGeneral, clipping_factor: float): + """Set AbsMaxCalibration clipping_factor and clip_gradient accordingly.""" assert isinstance( cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer ) @@ -338,14 +378,15 @@ def set_absmax_calib_scale(cfg: DotGeneral, scale: float): if isinstance(calibration_cls, functools.partial): calibration_cls = calibration_cls.func assert calibration_cls == calibration.AbsMaxCalibration, ( - 'scale is only available in AbsMaxCalibration, while' + 'clipping_factor is only available in AbsMaxCalibration, while' f' {quantizer.calibration} is used in current config.' ) - quantizer.calibration = functools.partial( - calibration.AbsMaxCalibration, scale=scale + + set_quantizer_calibration_config( + quantizer, clipping_factor=clipping_factor ) - if scale < 1.0 and isinstance( - quantizer.numerics, int_numerics.IntNumerics + if clipping_factor < 1.0 and isinstance( + quantizer.numerics, int_numerics.SymIntNumerics ): quantizer.numerics.clip_gradient = True @@ -374,7 +415,7 @@ def get_numerics(bits): else: pz = False if bits == 1 else True dtype = utils.infer_dtype_from_bits(bits) if pz else None - effective_numerics = int_numerics.IntNumerics( + effective_numerics = int_numerics.SymIntNumerics( bits=bits, preserve_zero=pz, preserve_max_val=False, @@ -399,8 +440,62 @@ def get_numerics(bits): return cfg -def set_scale_dtype(cfg: DotGeneral, scale_dtype: jnp.dtype): - """Set the dtype for all scales in the given DotGeneral config.""" +def _get_asym_numerics(numerics_: numerics.AqtNumerics): + """Gets the asymmetric equivalent of the given numerics.""" + if isinstance( + numerics_, (int_numerics.SymIntNumerics, int_numerics.AsymIntNumerics) + ): + # pytype: disable=attribute-error + return int_numerics.AsymIntNumerics( + bits=numerics_.bits, + clip=numerics_.clip, + clip_gradient=numerics_.clip_gradient, + round=numerics_.round, + noise_fn=numerics_.noise_fn, + dtype=numerics_.dtype, + ) + # pytype: enable=attribute-error + elif isinstance(numerics_, no_numerics.NoNumerics): + return numerics_ + else: + raise NotImplementedError( + 'Asymmetric quantization currently only supports integer numerics,' + f' but got {numerics_}' + ) + + +def _set_asymmetric_quantization(cfg: DotGeneralRaw): + """Replaces symmetric quantization with asymmetric quantization.""" + set_numerics( + cfg, + _get_asym_numerics(cfg.dg_quantizer.lhs.numerics), + _get_asym_numerics(cfg.dg_quantizer.rhs.numerics), + ) + + set_quantizer_calibration_config( + cfg.dg_quantizer.lhs, new_calibration_cls=calibration.MinMaxCalibration + ) + set_quantizer_calibration_config( + cfg.dg_quantizer.rhs, new_calibration_cls=calibration.MinMaxCalibration + ) + + # Only fake quantization currently supports 
quantization biases. + set_dequant_mode( + cfg, + lhs_dequant_mode=DequantMode.THIS_INPUT, + rhs_dequant_mode=DequantMode.THIS_INPUT, + ) + + +def set_asymmetric_quantization(cfg: DotGeneral): + """Replaces symmetric quantization with asymmetric quantization.""" + _set_asymmetric_quantization(cfg.fwd) + _set_asymmetric_quantization(cfg.dlhs) + _set_asymmetric_quantization(cfg.drhs) + + +def set_scale_and_bias_dtype(cfg: DotGeneral, dtype: jnp.dtype): + """Set the dtype for all scales and biases in the given DotGeneral config.""" assert isinstance( cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer ) @@ -410,12 +505,12 @@ def set_scale_dtype(cfg: DotGeneral, scale_dtype: jnp.dtype): assert isinstance( cfg.drhs.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer ) - cfg.fwd.dg_quantizer.lhs.scale_dtype = scale_dtype - cfg.fwd.dg_quantizer.rhs.scale_dtype = scale_dtype - cfg.dlhs.dg_quantizer.lhs.scale_dtype = scale_dtype - cfg.dlhs.dg_quantizer.rhs.scale_dtype = scale_dtype - cfg.drhs.dg_quantizer.lhs.scale_dtype = scale_dtype - cfg.drhs.dg_quantizer.rhs.scale_dtype = scale_dtype + set_quantizer_calibration_config(cfg.fwd.dg_quantizer.lhs, dtype=dtype) + set_quantizer_calibration_config(cfg.fwd.dg_quantizer.rhs, dtype=dtype) + set_quantizer_calibration_config(cfg.dlhs.dg_quantizer.lhs, dtype=dtype) + set_quantizer_calibration_config(cfg.dlhs.dg_quantizer.rhs, dtype=dtype) + set_quantizer_calibration_config(cfg.drhs.dg_quantizer.lhs, dtype=dtype) + set_quantizer_calibration_config(cfg.drhs.dg_quantizer.rhs, dtype=dtype) ################################################################################ @@ -439,7 +534,6 @@ def quantizer() -> aqt_quantizer.Quantizer: calib_shared_axes=None, scale_stop_grad=True, calibration=calibration.AbsMaxCalibration, - po2_scale=False, context=utils.Context(key=None, train_step=None), ) diff --git a/aqt/jax/v2/config_test.py b/aqt/jax/v2/config_test.py index 2615c2aa..6620b5df 100644 --- a/aqt/jax/v2/config_test.py +++ b/aqt/jax/v2/config_test.py @@ -20,6 +20,15 @@ import jax.numpy as jnp +def _dot_general_full_init_calibration(cfg): + cfg.fwd.dg_quantizer.lhs.init_calibration() + cfg.fwd.dg_quantizer.rhs.init_calibration() + cfg.dlhs.dg_quantizer.lhs.init_calibration() + cfg.dlhs.dg_quantizer.rhs.init_calibration() + cfg.drhs.dg_quantizer.lhs.init_calibration() + cfg.drhs.dg_quantizer.rhs.init_calibration() + + class AqtConfigTest(parameterized.TestCase): def _retrieve_quantizers(self, dot_general_raws): @@ -48,37 +57,37 @@ def test_config_v4(self): rhs=Tensor(use_fwd_quant=False, dequant_mode=, calibration_mode=), - dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=IntNumerics(bits=8, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=SymIntNumerics(bits=8, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), - rhs=Quantizer(numerics=IntNumerics(bits=8, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + rhs=Quantizer(numerics=SymIntNumerics(bits=8, + preserve_zero=True, + 
preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -93,37 +102,37 @@ def test_config_v4(self): rhs=Tensor(use_fwd_quant=False, dequant_mode=, calibration_mode=), - dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=IntNumerics(bits=7, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=RandomCenteredUniform(), - dtype=), + dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=SymIntNumerics(bits=7, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=RandomCenteredUniform(), + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), - rhs=Quantizer(numerics=IntNumerics(bits=7, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + rhs=Quantizer(numerics=SymIntNumerics(bits=7, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -140,37 +149,37 @@ def test_config_v4(self): rhs=Tensor(use_fwd_quant=False, dequant_mode=, calibration_mode=), - dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=IntNumerics(bits=6, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=RandomCenteredUniform(), - dtype=), + dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=SymIntNumerics(bits=6, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=RandomCenteredUniform(), + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), - rhs=Quantizer(numerics=IntNumerics(bits=6, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + rhs=Quantizer(numerics=SymIntNumerics(bits=6, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -182,6 +191,7 @@ def test_config_v4(self): allow_dummy_gradient_into_qtensor=False, dot_general=), apply_custom_vjp_on_jax=True)""" + _dot_general_full_init_calibration(cfg) utils.test_pprint_eq(cfg, expected_cfg_str, remove_memory_addresses=True) def 
test_configv4_original(self): @@ -191,37 +201,37 @@ def test_configv4_original(self): rhs=Tensor(use_fwd_quant=False, dequant_mode=, calibration_mode=), - dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=IntNumerics(bits=8, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=SymIntNumerics(bits=8, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), - rhs=Quantizer(numerics=IntNumerics(bits=8, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + rhs=Quantizer(numerics=SymIntNumerics(bits=8, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -236,37 +246,37 @@ def test_configv4_original(self): rhs=Tensor(use_fwd_quant=False, dequant_mode=, calibration_mode=), - dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=IntNumerics(bits=8, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=JaxUniform(), - dtype=), + dg_quantizer=DefaultDotGeneralQuantizer(lhs=Quantizer(numerics=SymIntNumerics(bits=8, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=JaxUniform(), + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), - rhs=Quantizer(numerics=IntNumerics(bits=8, - preserve_zero=True, - preserve_max_val=False, - clip=True, - clip_gradient=False, - round=True, - noise_fn=None, - dtype=), + rhs=Quantizer(numerics=SymIntNumerics(bits=8, + preserve_zero=True, + preserve_max_val=False, + clip=True, + clip_gradient=False, + round=True, + noise_fn=None, + dtype=), calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -286,9 +296,9 @@ def test_configv4_original(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), @@ -297,9 +307,9 @@ def test_configv4_original(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -309,9 +319,9 @@ 
def test_configv4_original(self): allow_dummy_gradient_into_qtensor=False, dot_general=), apply_custom_vjp_on_jax=True)""" - utils.test_pprint_eq( - config.config_v4(), expected_cfg_str, remove_memory_addresses=True - ) + cfg = config.config_v4() + _dot_general_full_init_calibration(cfg) + utils.test_pprint_eq(cfg, expected_cfg_str, remove_memory_addresses=True) def test_config_fwd_fp8(self): expected_cfg = """DotGeneral(fwd=DotGeneralRaw(lhs=Tensor(use_fwd_quant=False, @@ -327,9 +337,9 @@ def test_config_fwd_fp8(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), @@ -340,9 +350,9 @@ def test_config_fwd_fp8(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -362,9 +372,9 @@ def test_config_fwd_fp8(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), @@ -373,9 +383,9 @@ def test_config_fwd_fp8(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -395,9 +405,9 @@ def test_config_fwd_fp8(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=)), @@ -406,9 +416,9 @@ def test_config_fwd_fp8(self): calib_shared_axes=None, scale_stop_grad=True, calibration=, - _calibrator=None, - po2_scale=False, - scale_dtype=None, + _calibrator=AbsMaxCalibration(dtype=None, + po2_scale=False, + clipping_factor=None), context=Context(key=None, train_step=None, quant_mode=))), @@ -418,9 +428,9 @@ def test_config_fwd_fp8(self): allow_dummy_gradient_into_qtensor=False, dot_general=), apply_custom_vjp_on_jax=True)""" - utils.test_pprint_eq( - config.config_fwd_fp8(), expected_cfg, remove_memory_addresses=True - ) + cfg = config.config_fwd_fp8() + _dot_general_full_init_calibration(cfg) + utils.test_pprint_eq(cfg, expected_cfg, remove_memory_addresses=True) def test_set_int_numerics_preserve_zero(self): cfg = config.config_v4() @@ -436,19 +446,19 @@ def test_set_int_numerics_preserve_zero(self): def test_set_absmax_calib_scale(self): cfg = config.config_v4() for quantizer in self._retrieve_quantizers([cfg.fwd, cfg.dlhs, cfg.drhs]): - self.assertIsNone(quantizer.calibration().scale) + self.assertIsNone(quantizer.calibration().clipping_factor) for quantizer in self._retrieve_quantizers([cfg.fwd, cfg.dlhs]): self.assertFalse(quantizer.numerics.clip_gradient) - config.set_absmax_calib_scale(cfg, scale=3) + config.set_absmax_calib_clipping_factor(cfg, clipping_factor=3) for quantizer in self._retrieve_quantizers([cfg.fwd, cfg.dlhs, cfg.drhs]): - self.assertAlmostEqual(quantizer.calibration().scale, 3) + 
self.assertAlmostEqual(quantizer.calibration().clipping_factor, 3) for quantizer in self._retrieve_quantizers([cfg.fwd, cfg.dlhs]): self.assertFalse(quantizer.numerics.clip_gradient) - config.set_absmax_calib_scale(cfg, scale=0.1) + config.set_absmax_calib_clipping_factor(cfg, clipping_factor=0.1) for quantizer in self._retrieve_quantizers([cfg.fwd, cfg.dlhs]): self.assertTrue(quantizer.numerics.clip_gradient) diff --git a/aqt/jax/v2/examples/flax_e2e_model_test.py b/aqt/jax/v2/examples/flax_e2e_model_test.py index 324e0167..5cadcff7 100644 --- a/aqt/jax/v2/examples/flax_e2e_model_test.py +++ b/aqt/jax/v2/examples/flax_e2e_model_test.py @@ -50,6 +50,7 @@ class MnistTest(parameterized.TestCase): "drhs_accumulator_dtype": jnp.int32, # overwrite the default None }, 8, + False, ), ( { @@ -58,34 +59,78 @@ class MnistTest(parameterized.TestCase): "dlhs_accumulator_dtype": None, }, 4, + False, + ), + ( + { + "fwd_bits": 2, + "dlhs_bits": 2, + }, + 2, + False, + ), + ( + { + "fwd_bits": 2, + "dlhs_bits": 2, + }, + 2, + True, ), ]) - def test_mnist_training(self, configs, bits): + def test_mnist_training(self, configs, bits, use_asymmetric=False): aqt_cfg = config.config_v4(**configs) + if use_asymmetric: + config.set_asymmetric_quantization(aqt_cfg) target_loss = { - 8: { - "cpu": [ - 3.122317314147949218750000000000, - 3.122316360473632812500000000000, - 3.122316837310791015625000000000, # colab - ], - "TPU v2": [3.198328018188476562500000000000], - "TPU v3": [3.198328018188476562500000000000], - "TPU v4": [3.198297500610351562500000000000], - "TPU v5 lite": [3.200393676757812500000000000000], + False: { # use_asymmetric + 8: { # bits + "cpu": [ + 3.122317314147949218750000000000, + 3.122316360473632812500000000000, + 3.122316837310791015625000000000, # colab + ], + "TPU v2": [3.198328018188476562500000000000], + "TPU v3": [3.198328018188476562500000000000], + "TPU v4": [3.198297500610351562500000000000], + "TPU v5 lite": [3.200393676757812500000000000000], + }, + 4: { + "cpu": [2.258865118026733398437500000000], + "TPU v2": [2.302409172058105468750000000000], + "TPU v3": [2.302409172058105468750000000000], + "TPU v4": [2.302409172058105468750000000000], + "TPU v5 lite": [2.302415609359741210937500000000], + }, + 2: { + "cpu": [ + 2.067147254943847656250000000000, + 2.067147493362426757812500000000, + ], + "TPU v2": [2.052407503128051757812500000000], + "TPU v3": [2.052407503128051757812500000000], + "TPU v4": [2.052407741546630859375000000000], + "TPU v5 lite": [2.054144620895385742187500000000], + }, }, - 4: { - "cpu": [2.258865118026733398437500000000], - "TPU v2": [2.302409172058105468750000000000], - "TPU v3": [2.302409172058105468750000000000], - "TPU v4": [2.302409172058105468750000000000], - "TPU v5 lite": [2.302415609359741210937500000000], + True: { + 2: { + "cpu": [ + 3.539640426635742187500000000000, + 3.539642810821533203125000000000, + ], + "TPU v2": [2.984576702117919921875000000000], + "TPU v3": [2.984576702117919921875000000000], + "TPU v4": [2.984576702117919921875000000000], + "TPU v5 lite": [2.982401847839355468750000000000], + }, }, } # below 3 lines are differences between config_v4/v3 and fully_quantized config.set_stochastic_rounding(aqt_cfg, True, True, "jax.uniform") - aqt_cfg.dlhs.rhs.use_fwd_quant = True - aqt_cfg.drhs.rhs.use_fwd_quant = True + if not use_asymmetric: + aqt_cfg.dlhs.rhs.use_fwd_quant = True + aqt_cfg.drhs.rhs.use_fwd_quant = True def forward(model, apply_fn): return apply_fn( @@ -115,16 +160,21 @@ def forward(model, apply_fn): ) device_kind = 
jax.devices()[0].device_kind - expected_train_loss = target_loss[bits][device_kind] + expected_train_loss = target_loss[use_asymmetric][bits][device_kind] if train_loss not in expected_train_loss: - msg = "train_loss changed. Consider updating with the following:\n" - msg += f' "{device_kind}": [{train_loss:.30f}]' + msg = ( + "train_loss changed. Consider updating with the following:\n" + f' "{device_kind}": [{train_loss:.30f}]\n' + f" expected one of: {expected_train_loss}" + ) self.fail(msg) # Run forward once more in the same mode to get logits for testing below. logits_s1, _ = forward(state.model, state.cnn_eval.apply) # Stage 2: Model conversion (quantized weights freezing) + # if use_asymmetric: + # return # Exit early out of the serving tests. apply_serving, model_serving = flax_e2e_model.serving_conversion(state) @@ -224,6 +274,15 @@ def forward(model, apply_fn): }, } + if use_asymmetric: + for qtensor in [ + expected_aqt_pytree["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qlhs"], + expected_aqt_pytree["aqt"]["Dense_0"]["AqtDotGeneral_0"]["qrhs"], + expected_aqt_pytree["aqt"]["Dense_1"]["AqtDotGeneral_0"]["qrhs"], + ]: + # Bias has the same shape and dtype as the scale. + qtensor["frozen"].bias = qtensor["frozen"].scale + serving_pytree = jax.tree_util.tree_map( lambda x: (x.dtype, x.shape), model_serving ) diff --git a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py index 58cf6ad7..21c00ba5 100644 --- a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py +++ b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model.py @@ -42,10 +42,10 @@ def update_cfg_with_gptq(aqt_cfg: aqt_dot_general.DotGeneral) -> None: aqt_cfg.fwd.dg_quantizer, aqt_dot_general.DefaultDotGeneralQuantizer ) assert isinstance( - aqt_cfg.fwd.dg_quantizer.lhs.numerics, int_numerics.IntNumerics + aqt_cfg.fwd.dg_quantizer.lhs.numerics, int_numerics.SymIntNumerics ) assert isinstance( - aqt_cfg.fwd.dg_quantizer.rhs.numerics, int_numerics.IntNumerics + aqt_cfg.fwd.dg_quantizer.rhs.numerics, int_numerics.SymIntNumerics ) lhs_bits = aqt_cfg.fwd.dg_quantizer.lhs.numerics.bits rhs_bits = aqt_cfg.fwd.dg_quantizer.rhs.numerics.bits diff --git a/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py b/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py index 510730c1..a8f23d29 100644 --- a/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py +++ b/aqt/jax/v2/extensions/gptq/gptq_dot_general_quantizer.py @@ -270,11 +270,11 @@ def calibrate( # Follow the quantization mode and num_bits of the kernel. 
if self.is_rhs_kernel: quant_mode = _get_quant_mode(self.rhs.context) - assert isinstance(self.rhs.numerics, int_numerics.IntNumerics) + assert isinstance(self.rhs.numerics, int_numerics.SymIntNumerics) num_bits = self.rhs.numerics.bits else: quant_mode = _get_quant_mode(self.lhs.context) - assert isinstance(self.lhs.numerics, int_numerics.IntNumerics) + assert isinstance(self.lhs.numerics, int_numerics.SymIntNumerics) num_bits = self.lhs.numerics.bits if quant_mode == utils.QuantMode.TRAIN: diff --git a/aqt/jax/v2/flax/aqt_flax_calibration.py b/aqt/jax/v2/flax/aqt_flax_calibration.py index dd6b2e79..0d5ff269 100644 --- a/aqt/jax/v2/flax/aqt_flax_calibration.py +++ b/aqt/jax/v2/flax/aqt_flax_calibration.py @@ -17,6 +17,7 @@ from aqt.jax.v2 import calibration from aqt.jax.v2 import utils +from aqt.jax.v2.numerics import numerics import flax.linen as nn from jax import numpy as jnp @@ -73,6 +74,19 @@ def get_bound( # Maybe wait for the JAX language upgrade to have a better support for this? return sum_of_max.value / count.value + def get_scale_and_bias( + self, + x: jnp.ndarray, + shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, + context: utils.Context | None = None, + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: + dtype = self.dtype if self.dtype is not None else x.dtype + bound = self.get_bound(x, shared_axes, context) + scale = bound / numerics_.get_scaled_bound() + scale = calibration.ceil_to_po2(scale) if self.po2_scale else scale + return [scale.astype(dtype)], [] + # TODO: b/335764538 - Check the math correctness of the module. class WeightedStatsCalibration(calibration.Calibration, nn.Module): @@ -222,3 +236,16 @@ def init_var_fn(init_val: float) -> jnp.ndarray: + self.max_dev_coeff * self._max_dev() + self.const_bound_coeff ) + + def get_scale_and_bias( + self, + x: jnp.ndarray, + shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, + context: utils.Context | None = None, + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: + dtype = self.dtype if self.dtype is not None else x.dtype + bound = self.get_bound(x, shared_axes, context) + scale = bound / numerics_.get_scaled_bound() + scale = calibration.ceil_to_po2(scale) if self.po2_scale else scale + return [scale.astype(dtype)], [] diff --git a/aqt/jax/v2/flax/delayed_scaling_calibration.py b/aqt/jax/v2/flax/delayed_scaling_calibration.py index 0f69ac5d..7f9b3b98 100644 --- a/aqt/jax/v2/flax/delayed_scaling_calibration.py +++ b/aqt/jax/v2/flax/delayed_scaling_calibration.py @@ -16,6 +16,7 @@ from aqt.jax.v2 import calibration from aqt.jax.v2 import utils +from aqt.jax.v2.numerics import numerics import flax.linen as nn import jax from jax import numpy as jnp @@ -59,6 +60,7 @@ def get_bound( shared_axes: Sequence[utils.AxisIdx] | None, context: utils.Context | None = None, ) -> jnp.ndarray: + del shared_axes # Right now we just support per_tensor calibration (i.e. one value). # To support per_axis calibration, we would need to be able to change the # shape of the mutable arrays. 
For example, right now amax_history has @@ -89,6 +91,19 @@ def get_bound( amax_history_mutable_arr[:] = new_history[:] return new_bound.reshape((1,) * len(x.shape)) + def get_scale_and_bias( + self, + x: jnp.ndarray, + shared_axes: Sequence[utils.AxisIdx] | None, + numerics_: numerics.AqtNumerics, + context: utils.Context | None = None, + ) -> tuple[list[jnp.ndarray], list[jnp.ndarray]]: + dtype = self.dtype if self.dtype is not None else x.dtype + bound = self.get_bound(x, shared_axes, context) + scale = bound / numerics_.get_scaled_bound() + scale = calibration.ceil_to_po2(scale) if self.po2_scale else scale + return [scale.astype(dtype)], [] + def compute_bound(self, amax, prev_bound): new_bound = jnp.copy(amax) new_bound = jnp.where(amax > 0.0, new_bound, prev_bound) diff --git a/aqt/jax/v2/numerics/fp8_numerics.py b/aqt/jax/v2/numerics/fp8_numerics.py index 25476bd6..ba528f13 100644 --- a/aqt/jax/v2/numerics/fp8_numerics.py +++ b/aqt/jax/v2/numerics/fp8_numerics.py @@ -73,9 +73,12 @@ def _get_edge_of_last_fp8_bucket(self): def get_dtype(self): return self.dtype - def abs_val_mapped_to(self): + def get_scaled_bound(self): return self._get_edge_of_last_fp8_bucket() + def get_quant_range(self): + return -self.get_scaled_bound(), self.get_scaled_bound() + def vjp_fwd(self, x, context): match self.dtype: case jnp.float8_e4m3fn: diff --git a/aqt/jax/v2/numerics/fp_numerics.py b/aqt/jax/v2/numerics/fp_numerics.py index 9eaba24c..d5e05ed7 100644 --- a/aqt/jax/v2/numerics/fp_numerics.py +++ b/aqt/jax/v2/numerics/fp_numerics.py @@ -389,9 +389,12 @@ class FpNumerics(numerics.AqtNumerics): stochastic_rounding: bool = utils.static_field(default=False) clip_gradient: bool = utils.static_field(default=False) - def abs_val_mapped_to(self): + def get_scaled_bound(self): return fp_largest_representable(cfg=self.cfg) + def get_quant_range(self): + return -self.get_scaled_bound(), self.get_scaled_bound() + def get_dtype(self): return jnp.bfloat16 @@ -409,6 +412,6 @@ def vjp_bwd(self, res, grad): ret = grad if self.clip_gradient: (x,) = res - clip_bound = self.abs_val_mapped_to() + clip_bound = self.get_scaled_bound() ret *= (-clip_bound <= x) * (x <= clip_bound) return (ret, None) diff --git a/aqt/jax/v2/numerics/int_numerics.py b/aqt/jax/v2/numerics/int_numerics.py index 4a3a2169..b273eecd 100644 --- a/aqt/jax/v2/numerics/int_numerics.py +++ b/aqt/jax/v2/numerics/int_numerics.py @@ -21,8 +21,30 @@ import jax.numpy as jnp +def _maybe_apply_noise_fn(x, noise_fn, context): + if noise_fn is None: + return x + + assert context.key is not None, ( + 'noise_fn is set, requesting stochastic rounding, but RNG was not ' + 'passed in Context.key' + ) + return (x + noise_fn(x.shape, context.key)).astype(x.dtype) + + +def _apply_gradient_of_clip(res, grad, lower_clip_bound, upper_clip_bound): + # Gradient of the clip function. + # For boundary values we will have full gradient. + # When using max(abs(x)) scaling, x is always in the interior, and the + # gradient clip is always 1. So, we can always set clip_gradient to false. + # However, other types of scaling may result in x being outside (i.e., there + # is clipping). In that case it may be desirable to make the gradient zero. 
+ (x,) = res + return grad * (lower_clip_bound <= x) * (x <= upper_clip_bound) + + @utils.flax_slots_kw_only_dataclass -class IntNumerics(numerics.AqtNumerics): +class SymIntNumerics(numerics.AqtNumerics): """Numerics for int8, int4, binary, etc.""" bits: int @@ -58,12 +80,24 @@ def get_edge_of_last_int_bucket(self): def get_center_of_last_int_bucket(self): return self.get_edge_of_last_int_bucket() - 0.5 - def abs_val_mapped_to(self): + def get_scaled_bound(self): if self.preserve_max_val: return self.get_center_of_last_int_bucket() else: return self.get_edge_of_last_int_bucket() + def get_quant_range(self): + if self.clip and self.round: + # Quantization values are guaranteed to be within the restricted signed + # quantization range. + sint_max = 2.0 ** (self.bits - 1) - 1 + return -sint_max, sint_max + else: + # Values are guaranteed to be within [sint_min, sint_max + 1]. The range + # may be more restricted depending on the full configuration. + sint_min = -(2.0 ** (self.bits - 1)) + return sint_min, -sint_min + def _get_fwd_clip_bound(self): # If we are not rounding, we just clip to bucket edges. fwd_clip_bound = self.get_edge_of_last_int_bucket() @@ -84,13 +118,7 @@ def vjp_fwd(self, x, context): input_dtype = x.dtype assert self.bits <= 22, 'Too many bits, float32 has less precision.' - # Maybe noise - if self.noise_fn: - assert context.key is not None, ( - 'noise_fn is set, requestic stochastic rounding, but RNG was not ' - 'passed in Context.key' - ) - x = (x + self.noise_fn(x.shape, context.key)).astype(input_dtype) + x = _maybe_apply_noise_fn(x, self.noise_fn, context) if self.clip: fwd_clip_bound = self._get_fwd_clip_bound() @@ -111,15 +139,62 @@ def vjp_fwd(self, x, context): return x, res def vjp_bwd(self, res, grad): - # Gradient of the clip function. - # For boundary values we will have full gradient. - # When using abs(max(x)) scaling, x is always in the interior, and the - # gradient clip is always 1. So, we can always set clip_gradient to false. - # However, other types of scaling may result in x being outside (i.e., there - # is clipping). In that case it may be desirable to make the gradient zero. ret = grad if self.clip_gradient: - (x,) = res clip_bound = self._get_fwd_clip_bound() - ret *= (-clip_bound <= x) * (x <= clip_bound) + ret = _apply_gradient_of_clip(res, grad, -clip_bound, clip_bound) + return (ret, None) + + +@utils.flax_slots_kw_only_dataclass +class AsymIntNumerics(numerics.AqtNumerics): + """Base numerics for sint8, sint4, binary, etc.""" + + bits: int + clip: bool + clip_gradient: bool + round: bool + noise_fn: Optional[stochastic_rounding.NoiseFn] + dtype: Optional[Any] = None + + def get_dtype(self): + return self.dtype + + def get_scaled_bound(self): + return 2.0**self.bits - 1 + + def get_quant_range(self): + if self.bits > 1: + # Full signed int range. + sint_max = 2.0 ** (self.bits - 1) - 1 + sint_min = -(2.0 ** (self.bits - 1)) + return sint_min, sint_max + else: + # Boolean range. + return 0.0, 1.0 + + def vjp_fwd(self, x, context): + """Forward pass.""" + res = (x,) + input_dtype = x.dtype + return_dtype = self.dtype if self.dtype is not None else input_dtype + assert self.bits <= 22, 'Too many bits, float32 has less precision.' 
+
+    x = _maybe_apply_noise_fn(x, self.noise_fn, context)
+
+    if self.clip:
+      lower_clip_bound, upper_clip_bound = self.get_quant_range()
+      x = jnp.clip(x, lower_clip_bound, upper_clip_bound)
+
+    if self.round:
+      x = lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN)
+
+    # Maybe cast: return dtype is either int or the input dtype
+    x = x.astype(return_dtype)
+    return x, res
+
+  def vjp_bwd(self, res, grad):
+    ret = grad
+    if self.clip_gradient:
+      ret = _apply_gradient_of_clip(res, grad, *self.get_quant_range())
     return (ret, None)
diff --git a/aqt/jax/v2/numerics/int_numerics_test.py b/aqt/jax/v2/numerics/int_numerics_test.py
index 4a72c701..7e44d689 100644
--- a/aqt/jax/v2/numerics/int_numerics_test.py
+++ b/aqt/jax/v2/numerics/int_numerics_test.py
@@ -72,7 +72,7 @@ def test_quant_range(
     sint_min_restricted = sint_min + 1
 
     def quantize(x):
-      numerics_ = int_numerics.IntNumerics(
+      numerics_ = int_numerics.SymIntNumerics(
           bits=bits,
           preserve_zero=preserve_zero,
           preserve_max_val=preserve_max_val,
@@ -89,7 +89,6 @@ def quantize(x):
           calib_shared_axes=None,
           scale_stop_grad=True,
           calibration=calibration.AbsMaxCalibration,
-          po2_scale=False,
           context=utils.Context(key=None, train_step=None),
       )
       q.init_calibration()
diff --git a/aqt/jax/v2/numerics/no_numerics.py b/aqt/jax/v2/numerics/no_numerics.py
index 60fa2fbb..bf8eafbc 100644
--- a/aqt/jax/v2/numerics/no_numerics.py
+++ b/aqt/jax/v2/numerics/no_numerics.py
@@ -33,7 +33,10 @@ class NoNumerics(numerics.AqtNumerics):
   def get_dtype(self):
     return None
 
-  def abs_val_mapped_to(self):
+  def get_scaled_bound(self):
+    pass
+
+  def get_quant_range(self):
     pass
 
   def vjp_fwd(self, x, context):
diff --git a/aqt/jax/v2/numerics/numerics.py b/aqt/jax/v2/numerics/numerics.py
index a7cc83c9..4e1aac81 100644
--- a/aqt/jax/v2/numerics/numerics.py
+++ b/aqt/jax/v2/numerics/numerics.py
@@ -26,12 +26,15 @@ def get_dtype(self):
     pass
 
   @abc.abstractmethod
-  def abs_val_mapped_to(self):
-    """The value returned is the end of quantization range.
+  def get_scaled_bound(self):
+    """Returns the width the scale corresponds to in the quantization range.
 
-    It could be biggest value that can be represented by numerical format
-    exactly. E.g. in case of int8, 127 . Or it could be edge of the last bucket.
-    Edge in case of int8, 127.5
+    For symmetric scaling (relative to a fixed zero point) it could be the
+    biggest value the numerical format can represent exactly (127 for int8),
+    or the edge of the last bucket (127.5 for int8).
+
+    For asymmetric scaling, it corresponds to the width of the entire
+    quantization range, e.g. 255 for int8.
     """
     pass
 
@@ -42,3 +45,8 @@ def vjp_fwd(self, x, context):
   @abc.abstractmethod
   def vjp_bwd(self, res, grad):
     pass
+
+  @abc.abstractmethod
+  def get_quant_range(self):
+    """Returns the minimum and maximum values of the quantization range."""
+    pass
diff --git a/aqt/jax/v2/utils.py b/aqt/jax/v2/utils.py
index ada7fe64..e3a45499 100644
--- a/aqt/jax/v2/utils.py
+++ b/aqt/jax/v2/utils.py
@@ -42,7 +42,7 @@
 
 # Specifies the scale values to search for. Used with `SnrBasedAutoCalibration`
 # for auto scale search.
-AutoScaleSearchConfig: TypeAlias = Sequence[float]
+AutoClipSearchConfig: TypeAlias = Sequence[float]
 
 
 def assert_shape(shape: Sequence[int], shape_template: ShapeTemplate, msg: str):
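
As a reference for how the new `get_scaled_bound()` / `get_quant_range()` pair is
meant to be used by a calibration, below is a minimal sketch of the symmetric
(abs-max) and asymmetric (min-max) scale/bias arithmetic. `abs_max_scale` and
`min_max_scale_and_bias` are hypothetical helpers written only for illustration,
not APIs added by this change, and the `x ~= q * scale + bias` dequantization
convention is an assumption based on the e2e test comment that the bias has the
same shape and dtype as the scale.

import jax.numpy as jnp

def abs_max_scale(x, bits=8, preserve_max_val=False):
  # Symmetric: the bound is max(|x|) and it maps onto ~half the signed range,
  # i.e. the last bucket edge (127.5 for int8) or its center (127) when
  # preserving the max value, mirroring SymIntNumerics.get_scaled_bound().
  scaled_bound = 2.0 ** (bits - 1) - (1.0 if preserve_max_val else 0.5)
  return jnp.max(jnp.abs(x)) / scaled_bound  # no bias needed

def min_max_scale_and_bias(x, bits=8):
  # Asymmetric: the bound is the full span max(x) - min(x) and it maps onto
  # the whole (2**bits - 1)-wide range, mirroring
  # AsymIntNumerics.get_scaled_bound(); a bias then shifts min(x) onto the
  # lowest bucket returned by get_quant_range().
  quant_min = -(2.0 ** (bits - 1))
  scaled_bound = 2.0 ** bits - 1
  lo, hi = jnp.min(x), jnp.max(x)
  scale = (hi - lo) / scaled_bound
  bias = lo - quant_min * scale  # so that (lo - bias) / scale == quant_min
  return scale, bias

x = jnp.array([-0.3, 0.1, 0.9])
print(jnp.round(x / abs_max_scale(x)))  # ~[-42., 14., 128.] before clipping
scale, bias = min_max_scale_and_bias(x)
print(jnp.round((x - bias) / scale))    # [-128., -43., 127.]

At 8 bits both variants behave similarly; at very low bit widths the asymmetric
form can still spend every bucket on a skewed value distribution, which is where
min-max style calibration pays off.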