Enable fake quant in the bwd with local_aqt #748

Open · wants to merge 1 commit into main
40 changes: 35 additions & 5 deletions aqt/jax/v2/aqt_dot_general.py
@@ -773,6 +773,9 @@ def __call__(
lhs, rhs, dimension_numbers, self.rhs.use_fwd_quant
)
assert isinstance(rhs, jnp.ndarray)
orig_lhs_shape = lhs.shape
orig_rhs_shape = rhs.shape
orig_dimension_numbers = dimension_numbers

# TODO(lew): Define custom_vjp on tiled_dot_general and replace local_aqt.
if self.local_aqt is not None:
@@ -818,17 +821,29 @@ def __call__(

# TODO(lew): mt.x above should be clipped for clipping calibrations
out = _qtensor_dot_general(
lhs_qt, rhs_qt, dimension_numbers, self, jnp.promote_types(lhs, rhs)
lhs_qt,
rhs_qt,
dimension_numbers,
self,
jnp.promote_types(lhs, rhs),
orig_lhs_shape,
orig_rhs_shape,
orig_dimension_numbers,
)

out = out.dequant()

res = DotGeneralRes(lhs=lhs_res, rhs=rhs_res)
if self.local_aqt is not None:
(lhs_ca, rhs_ca), _ = dimension_numbers
assert len(lhs_ca) == len(rhs_ca)
if len(lhs_ca) > 0:
out = jnp.sum(out, axis=0)
fq = (
self.lhs.dequant_mode == DequantMode.THIS_INPUT
and self.rhs.dequant_mode == DequantMode.THIS_INPUT
)
if not fq: # Fake quant does not require a sum over the tile count axis
(lhs_ca, rhs_ca), _ = dimension_numbers
assert len(lhs_ca) == len(rhs_ca)
if len(lhs_ca) > 0:
out = jnp.sum(out, axis=0)
# We do not support local AQT in the fwd pass, so no res is needed.
res = None
return out, res
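Only the real-quant branch above needs the extra jnp.sum. A minimal sketch of why (the tiling reshape below is illustrative, not the library's actual local_aqt layout): splitting the contraction axis into tiles turns the tile count into a leading batch axis of the dot_general output, which then has to be reduced; the fake-quant path reshapes the operands back to their original shapes before the dot_general, so that axis never appears.

import jax
import jax.numpy as jnp

lhs = jnp.arange(8.0).reshape(2, 4)  # (batch, contraction)
rhs = jnp.arange(8.0).reshape(4, 2)  # (contraction, features)

# Tiled contraction: split the contraction axis into 2 tiles of size 2 and
# contract only the tile_size axis; tile_count becomes a leading batch axis.
lhs_t = lhs.reshape(2, 2, 2).transpose(1, 0, 2)  # (tile_count, batch, tile_size)
rhs_t = rhs.reshape(2, 2, 2)                     # (tile_count, tile_size, features)
partial = jax.lax.dot_general(
    lhs_t, rhs_t, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
out_tiled = jnp.sum(partial, axis=0)  # reduce the tile_count axis

# Untiled contraction with the original shapes: no tile axis, no extra sum.
out_plain = lhs @ rhs
assert jnp.allclose(out_tiled, out_plain)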
@@ -841,6 +856,9 @@ def _qtensor_dot_general(
cfg: Any,
# dequant_dtype: DType,
dequant_dtype: jnp.dtype,
orig_lhs_shape: tuple[int, ...],
orig_rhs_shape: tuple[int, ...],
orig_dimension_numbers: jax.lax.DotDimensionNumbers,
) -> aqt_tensor.QTensor:
"""QTensor lax.dot_general replacement."""

@@ -858,6 +876,18 @@ def _maybe_dequant(
# Dequantize before the lax dg call if in fake quant mode
lhs_qin = _maybe_dequant(lhs_qt, cfg.lhs)
rhs_qin = _maybe_dequant(rhs_qt, cfg.rhs)
if (
cfg.lhs.dequant_mode == DequantMode.THIS_INPUT
and cfg.rhs.dequant_mode == DequantMode.THIS_INPUT
and cfg.local_aqt is not None
):
# Revert local_aqt reshaping and reuse the original dimension numbers.
# The correct order to compute fake quant is:
# reshape (local_aqt), quant, dequant, reshape back.
# The original shapes and dimension numbers therefore have to be passed
# into _qtensor_dot_general.
dimension_numbers = orig_dimension_numbers
lhs_qin = lhs_qin.reshape(orig_lhs_shape)
rhs_qin = rhs_qin.reshape(orig_rhs_shape)

dtype_ms = (
f'Found {cfg.dg_accumulator_dtype=}, {lhs_qin.dtype=} and'
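As context for the comment above, a minimal sketch of the fake-quant subchannel ordering (reshape for tiling, quantize, dequantize, reshape back, then the dot with the original dimension numbers). The helper name and the per-tile abs-max calibration are assumptions for illustration, not the library's implementation:

import jax.numpy as jnp

def fake_quant_subchannel(x, tile_count, bits=8):
  # x: (batch, contraction). Split the contraction axis so that every tile
  # (subchannel) gets its own scale, quantize and dequantize per tile, then
  # undo the tiling reshape so the caller can use the original dims.
  batch, contraction = x.shape
  xt = x.reshape(batch, tile_count, contraction // tile_count)
  scale = jnp.max(jnp.abs(xt), axis=-1, keepdims=True) / (2 ** (bits - 1) - 1)
  xt_fq = jnp.round(xt / scale) * scale     # quantize, then dequantize
  return xt_fq.reshape(batch, contraction)  # reshape back to the input shape

lhs = jnp.array([[1270.0, 10.0, 1270000.0, 10000.0]])
rhs = jnp.ones((4, 1))
out = fake_quant_subchannel(lhs, tile_count=2) @ rhs  # original dims are reused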
107 changes: 75 additions & 32 deletions aqt/jax/v2/aqt_dot_general_test.py
@@ -34,23 +34,6 @@
import scipy.stats


def _apply_po2_scale(quantizer):
if quantizer.calibration is None:
return

calibration_cls = quantizer.calibration
# TODO(lew): Remove partial inspection wherever possible.
# Partial inspection is needed because the current implementation of delayed
# calibration initialization requires the configuration to be set via
# functools.partial.
keywords = {}
if isinstance(calibration_cls, functools.partial):
keywords = calibration_cls.keywords
calibration_cls = calibration_cls.func
keywords.update(po2_scale=True)
quantizer.calibration = functools.partial(calibration_cls, **keywords)


def test_jaxpr_dtype(f, dg_raws: list[aqt.DotGeneralRaw], float_dtype):
"""Tests whether dot_generals in f conform to dtypes inside of dg_raws."""

@@ -98,6 +81,13 @@ def rand_unif(shape, maxval, seed, dtype=jnp.float32):
)


def rand_int(shape, maxval, seed, dtype=int):
key = jax.random.PRNGKey(seed)
return jax.random.randint(
key=key, shape=shape, minval=-maxval, maxval=maxval, dtype=dtype
)


# The main test strategy is:
# - Test that FQ is sensible (improve fq_noise_test)
# - Quantization noise (rounding error) is of proper value
@@ -253,8 +243,8 @@ def _disable_quant_types(c, on_lhs=True, on_rhs=True):
# have the same numerics when scales are power of two (po2).
# We are passing dims to config so that we can reuse it in fake_quant.
# Power-of-2 scales allow FQ and AQT to be exactly the same.
_apply_po2_scale(c.dg_quantizer.lhs)
_apply_po2_scale(c.dg_quantizer.rhs)
config.set_quantizer_po2_scale(c.dg_quantizer.lhs)
config.set_quantizer_po2_scale(c.dg_quantizer.rhs)

_apply_dequant_mode(c, lhs_dequant_mode, rhs_dequant_mode)
_apply_calibration_mode(c, lhs_calibration_mode, rhs_calibration_mode)
@@ -456,6 +446,41 @@ def assert_clt(noise: jnp.ndarray):
custom_1_noise = noise_fn(shape, jax.random.PRNGKey(11))
assert_clt(custom_1_noise)

def test_sum(self):
# This test shows that the following two computations produce *different*
# results:
# (1) contract 2 axes in a single jax.lax.dot_general;
# (2) first contract 1 axis with jax.lax.dot_general, then contract the
# other axis with jnp.sum.
# This means that the fake quant subchannel path can produce results that
# differ from the real quant subchannel path.
key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(key)
dtype = jnp.float32
lhs = jax.random.normal(key1, shape=(2, 3, 4), dtype=dtype)
rhs = jax.random.normal(key2, shape=(3, 4, 5), dtype=dtype)

def product1(x, y):
return jax.lax.dot_general(
x, y, dimension_numbers=(((1, 2), (0, 1)), ((), ()))
)

def product2(x, y):
intermediate = jax.lax.dot_general(
x, y, dimension_numbers=(((1,), (0,)), ((2,), (1,)))
)
return jnp.sum(intermediate, axis=0)

def print_err(x, y):
print(f"{x=}")
print(f"{y=}")
mse = jnp.mean(jnp.square(x - y))
print(f"{mse=}")

out1 = product1(lhs, rhs)
out2 = product2(lhs, rhs)
print_err(out1, out2)
assert not (out1 == out2).all()
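The disagreement this test relies on comes from floating-point accumulation order: both computations add up exactly the same products, but float32 addition is not associative, so contracting two axes at once and contracting one axis then summing the other can differ in the low bits. A tiny standalone illustration (not part of the test):

import jax.numpy as jnp

a = jnp.float32(1e8)
b = jnp.float32(1.0)
print((a + b) - a)  # 0.0: b is absorbed when added to the large value first
print((a - a) + b)  # 1.0: the same terms, accumulated in a different order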

@parameterized.parameters([
dict(bits=1),
])
@@ -466,7 +491,7 @@ def test_fake_quant(
shape=(20, 1),
):
quantizer = config.quantizer_make(bits, initialize_calibration=False)
_apply_po2_scale(quantizer)
config.set_quantizer_po2_scale(quantizer)
quantizer.init_calibration()
quantizer.calib_shared_axes = (0,)
x = jnp.linspace(-maxval, maxval, num=shape[0]).reshape(shape)
@@ -516,8 +541,8 @@ def test_fake_quant(
# can't keep in the product of int8*int8 accurately.
# It just so happens that this test does not fail but others do.
# We do this test anyway, to catch jax-compilation-time errors.
dict(dg=config.dot_general_make(2, 2), dtype=jnp.bfloat16),
dict(dg=config.dot_general_make(8, 8), dtype=jnp.bfloat16),
dict(dg=config.dot_general_make(2, 2), dtype=jnp.float32),
dict(dg=config.dot_general_make(8, 8), dtype=jnp.float32),
dict(dg=config.dot_general_make(None, 8)),
dict(dg=config.dot_general_make(8, None)),
dict(
@@ -529,6 +554,7 @@ def test_fake_quant(
gra_shape=(4, 3, 6),
),
dict(
# This also tests the fake quant subchannel implementation.
dg=fqt_param_dict(
s=10,
use_fwd_quant=True,
@@ -574,9 +600,10 @@ def test_dot_general_calibration_with_contracting_axis(
readonly_dg = dg
del dg

lhs = rand_unif(lhs_shape, lhs_maxval, seed, dtype)
rhs = rand_unif(rhs_shape, rhs_maxval, seed + 1, dtype)
gra = rand_unif(gra_shape, gra_maxval, seed + 2, dtype)
# Use integer inputs to avoid the mismatch between fake & real quant results
lhs = rand_int(lhs_shape, lhs_maxval, seed).astype(dtype)
rhs = rand_int(rhs_shape, rhs_maxval, seed + 1).astype(dtype)
gra = rand_int(gra_shape, gra_maxval, seed + 2).astype(dtype)

# Prepare utility functions for test.
aqt_dg_full = functools.partial(
@@ -741,8 +768,8 @@ def lax_dg(lhs, rhs):
# can't keep in the product of int8*int8 accurately.
# It just so happens that this test does not fail but others do.
# We do this test anyway, to catch jax-compilation-time errors.
dict(dg=config.dot_general_make(2, 2), dtype=jnp.bfloat16),
dict(dg=config.dot_general_make(8, 8), dtype=jnp.bfloat16),
dict(dg=config.dot_general_make(2, 2), dtype=jnp.float32),
dict(dg=config.dot_general_make(8, 8), dtype=jnp.float32),
dict(dg=config.dot_general_make(None, 8)),
dict(dg=config.dot_general_make(8, None)),
dict(
@@ -780,9 +807,10 @@ def test_dot_general_calibration_with_remaining_axis(
readonly_dg = dg
del dg

lhs = rand_unif(lhs_shape, lhs_maxval, seed, dtype)
rhs = rand_unif(rhs_shape, rhs_maxval, seed + 1, dtype)
gra = rand_unif(gra_shape, gra_maxval, seed + 2, dtype)
# Use integer inputs to avoid the mismatch between fake & real quant results
lhs = rand_int(lhs_shape, lhs_maxval, seed).astype(dtype)
rhs = rand_int(rhs_shape, rhs_maxval, seed + 1).astype(dtype)
gra = rand_int(gra_shape, gra_maxval, seed + 2).astype(dtype)

# Prepare utility functions for test.
aqt_dg_full = functools.partial(
@@ -1084,20 +1112,35 @@ def dg(lhs, rhs):
lhs=[1270.0, 10.0, 1270000.0, 10000.0],
expected_product=1280000.0,
),
dict(
# subchannel: [1270, 10] [1270000, 10000]
# scale=bound/127: 10 10000
# pow2 scale: 16 16384
# qvalue=rnd(x/scale): [79, 1] [78, 1]
# product: 80*16 + 79*16384 = 1295616
shard_count=2,
lhs=[1270.0, 10.0, 1270000.0, 10000.0],
expected_product=1295616.0,
pw2_scale=True,
),
])
def test_local_aqt(self, shard_count, lhs, expected_product):
def test_local_aqt(self, shard_count, lhs, expected_product, pw2_scale=False):
# Create a config that quantizes both the forward and backward passes to int8
# and shards the drhs contraction axis (local AQT) into shard_count shards.
dg = config.fully_quantized(
fwd_bits=8,
bwd_bits=8,
use_stochastic_rounding=False,
drhs_local_aqt=aqt.LocalAqt(contraction_axis_shard_count=shard_count),
use_fwd_quant=False, # Do not multiply the rhs scale into the lhs
)
dg.fwd.dg_quantizer.lhs.numerics.preserve_max_val = True
dg.fwd.dg_quantizer.rhs.numerics.preserve_max_val = True
dg.drhs.dg_quantizer.lhs.numerics.preserve_max_val = True
dg.drhs.dg_quantizer.rhs.numerics.preserve_max_val = True
if pw2_scale:
config.set_quantizer_po2_scale(dg.drhs.dg_quantizer.lhs)
config.set_quantizer_po2_scale(dg.drhs.dg_quantizer.rhs)
dg_f = lambda lhs, rhs: dg(
lhs,
rhs,
@@ -1107,7 +1150,7 @@ def test_local_aqt(self, shard_count, lhs, expected_product):
rhs = jnp.array([1.0])
output, bprop = jax.vjp(dg_f, lhs, rhs)
_, drhs = bprop(jnp.ones_like(output))
assert drhs == expected_product
assert drhs == expected_product, f"{drhs=}, {expected_product=}"
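A plain-Python recomputation of the po2 expectation used above, assuming the scale is rounded up to the next power of two (which matches the 16 and 16384 in the parameter comment):

import math

tiles = [[1270.0, 10.0], [1270000.0, 10000.0]]
total = 0.0
for tile in tiles:
  scale = max(abs(v) for v in tile) / 127.0    # 10.0 and 10000.0
  scale = 2.0 ** math.ceil(math.log2(scale))   # po2 scales: 16.0 and 16384.0
  q = [round(v / scale) for v in tile]         # [79, 1] and [78, 1]
  total += sum(q) * scale                      # 80*16 + 79*16384
assert total == 1295616.0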

def test_per_tensor(self):
# TODO(lew): bits=8 started failing in VLP colab due x/x != 1.0 sometimes
18 changes: 18 additions & 0 deletions aqt/jax/v2/config.py
@@ -540,6 +540,24 @@ def _update_dtype(quantizer: aqt_quantizer.Quantizer):
_update_dtype(cfg.drhs.dg_quantizer.rhs)


def set_quantizer_po2_scale(quantizer: aqt_quantizer.Quantizer):
"""Set the po2_scale flag for the given quantizer."""
if quantizer.calibration is None:
return

calibration_cls = quantizer.calibration
# TODO(lew): Remove partial inspection wherever possible.
# Partial inspection is needed because the current implementation of delayed
# calibration initialization requires the configuration to be set via
# functools.partial.
keywords = {}
if isinstance(calibration_cls, functools.partial):
keywords = calibration_cls.keywords
calibration_cls = calibration_cls.func
keywords.update(po2_scale=True)
quantizer.calibration = functools.partial(calibration_cls, **keywords)
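A usage sketch mirroring the calls in aqt_dot_general_test.py: the po2 switch has to happen before init_calibration(), because the helper rewrites the functools.partial that the delayed calibration initialization consumes.

quantizer = config.quantizer_make(8, initialize_calibration=False)
config.set_quantizer_po2_scale(quantizer)
quantizer.init_calibration()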


################################################################################
# Functions below are auxiliary config creators.
