From 7f4ee3a7db6540ff94e4fafb0855d259d217f83b Mon Sep 17 00:00:00 2001
From: Yichi Zhang
Date: Wed, 13 Nov 2024 12:11:45 -0800
Subject: [PATCH] Enable fake quant in the bwd with local_aqt

PiperOrigin-RevId: 696225640
---
 aqt/jax/v2/aqt_dot_general.py      |  40 +++++++++--
 aqt/jax/v2/aqt_dot_general_test.py | 107 ++++++++++++++++++++---------
 aqt/jax/v2/config.py               |  18 +++++
 3 files changed, 128 insertions(+), 37 deletions(-)

diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py
index 404991db..0cf78bac 100644
--- a/aqt/jax/v2/aqt_dot_general.py
+++ b/aqt/jax/v2/aqt_dot_general.py
@@ -773,6 +773,9 @@ def __call__(
         lhs, rhs, dimension_numbers, self.rhs.use_fwd_quant
     )
     assert isinstance(rhs, jnp.ndarray)
+    orig_lhs_shape = lhs.shape
+    orig_rhs_shape = rhs.shape
+    orig_dimension_numbers = dimension_numbers
 
     # TODO(lew): Define cutsom_vjp on tiled_dot_general and replace local_aqt.
     if self.local_aqt is not None:
@@ -818,17 +821,29 @@ def __call__(
 
     # TODO(lew): mt.x above should be clipped for clipping calibrations
     out = _qtensor_dot_general(
-        lhs_qt, rhs_qt, dimension_numbers, self, jnp.promote_types(lhs, rhs)
+        lhs_qt,
+        rhs_qt,
+        dimension_numbers,
+        self,
+        jnp.promote_types(lhs, rhs),
+        orig_lhs_shape,
+        orig_rhs_shape,
+        orig_dimension_numbers,
     )
 
     out = out.dequant()
 
     res = DotGeneralRes(lhs=lhs_res, rhs=rhs_res)
     if self.local_aqt is not None:
-      (lhs_ca, rhs_ca), _ = dimension_numbers
-      assert len(lhs_ca) == len(rhs_ca)
-      if len(lhs_ca) > 0:
-        out = jnp.sum(out, axis=0)
+      fq = (
+          self.lhs.dequant_mode == DequantMode.THIS_INPUT
+          and self.rhs.dequant_mode == DequantMode.THIS_INPUT
+      )
+      if not fq:  # Fake quant needs no sum over the tile count axis.
+        (lhs_ca, rhs_ca), _ = dimension_numbers
+        assert len(lhs_ca) == len(rhs_ca)
+        if len(lhs_ca) > 0:
+          out = jnp.sum(out, axis=0)
       # We are not supporting local AQT in fwd pass, so no res needed.
       res = None
     return out, res
@@ -841,6 +856,9 @@ def _qtensor_dot_general(
     cfg: Any,
     # dequant_dtype: DType,
     dequant_dtype: jnp.dtype,
+    orig_lhs_shape: tuple[int, ...],
+    orig_rhs_shape: tuple[int, ...],
+    orig_dimension_numbers: jax.lax.DotDimensionNumbers,
 ) -> aqt_tensor.QTensor:
   """QTensor lax.dot_general replacement."""
 
@@ -858,6 +876,18 @@ def _maybe_dequant(
   # Dequantize before the lax dg call if in fake quant mode
   lhs_qin = _maybe_dequant(lhs_qt, cfg.lhs)
   rhs_qin = _maybe_dequant(rhs_qt, cfg.rhs)
+  if (
+      cfg.lhs.dequant_mode == DequantMode.THIS_INPUT
+      and cfg.rhs.dequant_mode == DequantMode.THIS_INPUT
+      and cfg.local_aqt is not None
+  ):
+    # Revert the local_aqt reshaping and reuse the original dimension numbers.
+    # The correct order of operations for fake quant is:
+    #   reshape (local_aqt), quant, dequant, reshape back.
+    # The original dims and shapes therefore must be passed into qtensor_dg.
+    dimension_numbers = orig_dimension_numbers
+    lhs_qin = lhs_qin.reshape(orig_lhs_shape)
+    rhs_qin = rhs_qin.reshape(orig_rhs_shape)
 
   dtype_ms = (
       f'Found {cfg.dg_accumulator_dtype=}, {lhs_qin.dtype=} and'
diff --git a/aqt/jax/v2/aqt_dot_general_test.py b/aqt/jax/v2/aqt_dot_general_test.py
index e79b260c..5e976ccb 100644
--- a/aqt/jax/v2/aqt_dot_general_test.py
+++ b/aqt/jax/v2/aqt_dot_general_test.py
@@ -34,23 +34,6 @@
 import scipy.stats
 
 
-def _apply_po2_scale(quantizer):
-  if quantizer.calibration is None:
-    return
-
-  calibration_cls = quantizer.calibration
-  # TODO(lew): Remove partial inspection wherever possible.
-  # Partial inspection is needed because the current implementation of delayed
-  # calibration initialization requires the configuration to be set via
-  # functools.partial.
-  keywords = {}
-  if isinstance(calibration_cls, functools.partial):
-    keywords = calibration_cls.keywords
-    calibration_cls = calibration_cls.func
-  keywords.update(po2_scale=True)
-  quantizer.calibration = functools.partial(calibration_cls, **keywords)
-
-
 def test_jaxpr_dtype(f, dg_raws: list[aqt.DotGeneralRaw], float_dtype):
   """Tests whether dot_generals in f conform to dtypes inside of dg_raws."""
 
@@ -98,6 +81,13 @@ def rand_unif(shape, maxval, seed, dtype=jnp.float32):
   )
 
 
+def rand_int(shape, maxval, seed, dtype=int):
+  key = jax.random.PRNGKey(seed)
+  return jax.random.randint(
+      key=key, shape=shape, minval=-maxval, maxval=maxval, dtype=dtype
+  )
+
+
 # The main test strategy is :
 # - Test that FQ is sensible (improve fq_noise_test)
 # - Quantization noise (rounding error) is of proper value
@@ -253,8 +243,8 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True):
   # have the same numerics when scales are power of two (po2).
   # We are passing dims to config so that we can reuse it in fake_quant.
   # Power-of-2 scales allow FQ and AQT to be exactly the same.
-  _apply_po2_scale(c.dg_quantizer.lhs)
-  _apply_po2_scale(c.dg_quantizer.rhs)
+  config.set_quantizer_po2_scale(c.dg_quantizer.lhs)
+  config.set_quantizer_po2_scale(c.dg_quantizer.rhs)
 
   _apply_dequant_mode(c, lhs_dequant_mode, rhs_dequant_mode)
   _apply_calibration_mode(c, lhs_calibration_mode, rhs_calibration_mode)
@@ -456,6 +446,41 @@ def assert_clt(noise: jnp.ndarray):
     custom_1_noise = noise_fn(shape, jax.random.PRNGKey(11))
     assert_clt(custom_1_noise)
 
+  def test_sum(self):
+    # This is a test showing that the following will produce *different* results:
+    # (1) Contract 2 axes using jax.lax.dot_general.
+    # (2) First contract 1 axis using jax.lax.dot_general, then contract the
+    #     other axis using jnp.sum.
+    # This means that fake quant subchannel will produce different results
+    # from the real quant subchannel.
+    key = jax.random.PRNGKey(0)
+    key1, key2 = jax.random.split(key)
+    dtype = jnp.float32
+    lhs = jax.random.normal(key1, shape=(2, 3, 4), dtype=dtype)
+    rhs = jax.random.normal(key2, shape=(3, 4, 5), dtype=dtype)
+
+    def product1(x, y):
+      return jax.lax.dot_general(
+          x, y, dimension_numbers=(((1, 2), (0, 1)), ((), ()))
+      )
+
+    def product2(x, y):
+      intermediate = jax.lax.dot_general(
+          x, y, dimension_numbers=(((1,), (0,)), ((2,), (1,)))
+      )
+      return jnp.sum(intermediate, axis=0)
+
+    def print_err(x, y):
+      print(f"{x=}")
+      print(f"{y=}")
+      mse = jnp.mean(jnp.square(x - y))
+      print(f"{mse=}")
+
+    out1 = product1(lhs, rhs)
+    out2 = product2(lhs, rhs)
+    print_err(out1, out2)
+    assert not (out1 == out2).all()
+
   @parameterized.parameters([
       dict(bits=1),
   ])
@@ -466,7 +491,7 @@ def test_fake_quant(
       shape=(20, 1),
   ):
     quantizer = config.quantizer_make(bits, initialize_calibration=False)
-    _apply_po2_scale(quantizer)
+    config.set_quantizer_po2_scale(quantizer)
     quantizer.init_calibration()
     quantizer.calib_shared_axes = (0,)
     x = jnp.linspace(-maxval, maxval, num=shape[0]).reshape(shape)
@@ -516,8 +541,8 @@ def test_fake_quant(
       # can't keep in the product of int8*int8 accurately.
       # It just so happens that this test does not fail but others do.
      # We do this test anyway, to catch jax-compilation-time errors.
-      dict(dg=config.dot_general_make(2, 2), dtype=jnp.bfloat16),
-      dict(dg=config.dot_general_make(8, 8), dtype=jnp.bfloat16),
+      dict(dg=config.dot_general_make(2, 2), dtype=jnp.float32),
+      dict(dg=config.dot_general_make(8, 8), dtype=jnp.float32),
       dict(dg=config.dot_general_make(None, 8)),
       dict(dg=config.dot_general_make(8, None)),
       dict(
@@ -529,6 +554,7 @@ def test_fake_quant(
           gra_shape=(4, 3, 6),
       ),
       dict(
+          # This tests fake subchannel implementation as well.
           dg=fqt_param_dict(
               s=10,
               use_fwd_quant=True,
@@ -574,9 +600,10 @@ def test_dot_general_calibration_with_contracting_axis(
     readonly_dg = dg
     del dg
 
-    lhs = rand_unif(lhs_shape, lhs_maxval, seed, dtype)
-    rhs = rand_unif(rhs_shape, rhs_maxval, seed + 1, dtype)
-    gra = rand_unif(gra_shape, gra_maxval, seed + 2, dtype)
+    # Use integer inputs to avoid the mismatch between fake & real quant results.
+    lhs = rand_int(lhs_shape, lhs_maxval, seed).astype(dtype)
+    rhs = rand_int(rhs_shape, rhs_maxval, seed + 1).astype(dtype)
+    gra = rand_int(gra_shape, gra_maxval, seed + 2).astype(dtype)
 
     # Prepare utility functions for test.
     aqt_dg_full = functools.partial(
@@ -741,8 +768,8 @@ def lax_dg(lhs, rhs):
       # can't keep in the product of int8*int8 accurately.
       # It just so happens that this test does not fail but others do.
       # We do this test anyway, to catch jax-compilation-time errors.
-      dict(dg=config.dot_general_make(2, 2), dtype=jnp.bfloat16),
-      dict(dg=config.dot_general_make(8, 8), dtype=jnp.bfloat16),
+      dict(dg=config.dot_general_make(2, 2), dtype=jnp.float32),
+      dict(dg=config.dot_general_make(8, 8), dtype=jnp.float32),
       dict(dg=config.dot_general_make(None, 8)),
       dict(dg=config.dot_general_make(8, None)),
       dict(
@@ -780,9 +807,10 @@ def test_dot_general_calibration_with_remaining_axis(
     readonly_dg = dg
     del dg
 
-    lhs = rand_unif(lhs_shape, lhs_maxval, seed, dtype)
-    rhs = rand_unif(rhs_shape, rhs_maxval, seed + 1, dtype)
-    gra = rand_unif(gra_shape, gra_maxval, seed + 2, dtype)
+    # Use integer inputs to avoid the mismatch between fake & real quant results.
+    lhs = rand_int(lhs_shape, lhs_maxval, seed).astype(dtype)
+    rhs = rand_int(rhs_shape, rhs_maxval, seed + 1).astype(dtype)
+    gra = rand_int(gra_shape, gra_maxval, seed + 2).astype(dtype)
 
     # Prepare utility functions for test.
     aqt_dg_full = functools.partial(
@@ -1084,8 +1112,19 @@ def dg(lhs, rhs):
           lhs=[1270.0, 10.0, 1270000.0, 10000.0],
           expected_product=1280000.0,
       ),
+      dict(
+          # subchannel:          [1270, 10]  [1270000, 10000]
+          # scale=bound/127:     10          10000
+          # pow2 scale:          16          16384
+          # qvalue=rnd(x/scale): [79, 1]     [78, 1]
+          # product: (79 + 1) * 16 + (78 + 1) * 16384 = 1295616
+          shard_count=2,
+          lhs=[1270.0, 10.0, 1270000.0, 10000.0],
+          expected_product=1295616.0,
+          pw2_scale=True,
+      ),
   ])
-  def test_local_aqt(self, shard_count, lhs, expected_product):
+  def test_local_aqt(self, shard_count, lhs, expected_product, pw2_scale=False):
     # create a config that quantizes both forward and backward passes to int8
     # set the number of shards (local aqt) to 2
     dg = config.fully_quantized(
@@ -1093,11 +1132,15 @@ def test_local_aqt(self, shard_count, lhs, expected_product):
         bwd_bits=8,
         use_stochastic_rounding=False,
         drhs_local_aqt=aqt.LocalAqt(contraction_axis_shard_count=shard_count),
+        use_fwd_quant=False,  # Do not fold the rhs scale into the lhs.
     )
     dg.fwd.dg_quantizer.lhs.numerics.preserve_max_val = True
     dg.fwd.dg_quantizer.rhs.numerics.preserve_max_val = True
     dg.drhs.dg_quantizer.lhs.numerics.preserve_max_val = True
     dg.drhs.dg_quantizer.rhs.numerics.preserve_max_val = True
+    if pw2_scale:
+      config.set_quantizer_po2_scale(dg.drhs.dg_quantizer.lhs)
+      config.set_quantizer_po2_scale(dg.drhs.dg_quantizer.rhs)
     dg_f = lambda lhs, rhs: dg(
         lhs,
         rhs,
@@ -1107,7 +1150,7 @@ def test_local_aqt(self, shard_count, lhs, expected_product):
     rhs = jnp.array([1.0])
     output, bprop = jax.vjp(dg_f, lhs, rhs)
     _, drhs = bprop(jnp.ones_like(output))
-    assert drhs == expected_product
+    assert drhs == expected_product, f"{drhs=}, {expected_product=}"
 
   def test_per_tensor(self):
     # TODO(lew): bits=8 started failing in VLP colab due x/x != 1.0 sometimes
diff --git a/aqt/jax/v2/config.py b/aqt/jax/v2/config.py
index 2581a62c..7b5b7e3c 100644
--- a/aqt/jax/v2/config.py
+++ b/aqt/jax/v2/config.py
@@ -540,6 +540,24 @@ def _update_dtype(quantizer: aqt_quantizer.Quantizer):
   _update_dtype(cfg.drhs.dg_quantizer.rhs)
 
 
+def set_quantizer_po2_scale(quantizer: aqt_quantizer.Quantizer):
+  """Set the po2_scale flag for the given quantizer."""
+  if quantizer.calibration is None:
+    return
+
+  calibration_cls = quantizer.calibration
+  # TODO(lew): Remove partial inspection wherever possible.
+  # Partial inspection is needed because the current implementation of delayed
+  # calibration initialization requires the configuration to be set via
+  # functools.partial.
+  keywords = {}
+  if isinstance(calibration_cls, functools.partial):
+    keywords = calibration_cls.keywords
+    calibration_cls = calibration_cls.func
+  keywords.update(po2_scale=True)
+  quantizer.calibration = functools.partial(calibration_cls, **keywords)
+
+
 ################################################################################
 # Functions below are auxiliary config creators.
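
Below the patch, a standalone sketch of the subchannel fake-quant order described
by the new comment in _qtensor_dot_general (reshape for local_aqt, quant, dequant,
reshape back, then reuse the original dimension numbers with no sum over a
tile-count axis). This is illustrative only, written in plain jnp; the helper name
fake_quant_subchannel is made up and is not part of the AQT API.

import jax.numpy as jnp

def fake_quant_subchannel(x, shard_count, bits=8):
  # Split the contraction axis into (shard_count, tile) so that each tile gets
  # its own scale, mimicking LocalAqt(contraction_axis_shard_count=shard_count).
  tiles = x.reshape(shard_count, -1)
  maxval = 2.0 ** (bits - 1) - 1
  scale = jnp.max(jnp.abs(tiles), axis=1, keepdims=True) / maxval
  qvalue = jnp.round(tiles / scale)  # quant
  dequant = qvalue * scale           # dequant
  return dequant.reshape(x.shape)    # reshape back to the original shape

x = jnp.array([1270.0, 10.0, 1270000.0, 10000.0])
fq_x = fake_quant_subchannel(x, shard_count=2)
# fq_x already carries the per-tile quantization error, so it can be fed to a
# plain dot_general with the *original* dimension_numbers; no extra sum over a
# tile-count axis is needed, which is why the __call__ path above skips that sum
# when both dequant modes are DequantMode.THIS_INPUT.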
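
And a usage sketch mirroring the updated test_local_aqt: configure int8
quantization with the drhs contraction axis sharded into two local-AQT tiles,
then force power-of-two scales with the new config.set_quantizer_po2_scale
helper. Illustrative only; the import aliases are assumptions (the test refers
to the modules as aqt and config), and fwd_bits is assumed since that keyword
falls outside the hunk context above.

from aqt.jax.v2 import aqt_dot_general as aqt
from aqt.jax.v2 import config

dg = config.fully_quantized(
    fwd_bits=8,  # Assumed; this argument is not visible in the hunk above.
    bwd_bits=8,
    use_stochastic_rounding=False,
    drhs_local_aqt=aqt.LocalAqt(contraction_axis_shard_count=2),
    use_fwd_quant=False,  # Do not fold the rhs scale into the lhs.
)
config.set_quantizer_po2_scale(dg.drhs.dg_quantizer.lhs)
config.set_quantizer_po2_scale(dg.drhs.dg_quantizer.rhs)

# dg is a drop-in dot_general replacement; under jax.vjp (as in the test) the
# drhs gradient is now computed from two locally quantized tiles with po2
# scales, which is what the new expected_product=1295616.0 case checks.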