Add fake-quantization support for quantization biases to AQTv2.
PiperOrigin-RevId: 660127656
phoenix-meadowlark authored and copybara-github committed Aug 6, 2024
1 parent f769736 commit 31b8fa4
Showing 13 changed files with 158 additions and 30 deletions.
5 changes: 3 additions & 2 deletions aqt/jax/v2/aqt_conv_general.py
@@ -60,7 +60,7 @@ def my_conv_general_dilated(
# In Flax, lhs is the inputs, rhs is the kernel.
# lhs layout is B, spatials..., Ci
# rhs layout is: spatials..., Ci, Co
# out layous it: B, spatials..., Co
# out layout is: B, spatials..., Co
#
# we need to share these axes: lhs[1:] , rhs[:-1]
# we have a scale/invscale per: lhs[0] / out[0] and rhs[-1] / out[-1]
@@ -103,6 +103,7 @@ def my_conv_general_dilated(
qvalue=out,
scale=[],
scale_t=None,
bias=[],
dequant_dtype=jnp.promote_types(lhs, rhs),
)
assert out.scale is not None # pytype help
@@ -112,7 +113,7 @@ def my_conv_general_dilated(

# # Future scale granularity optimization.
# In 1x1 conv, each pixel (spatial location) can have different scales
# in 1xN (rows x colums) conv each row can have different scale, but
# in 1xN (rows x columns) conv each row can have different scale, but
# columns need to share the scales , because we are adding pixels across.
#
# For patch convs we could have separate scales per patch.
21 changes: 19 additions & 2 deletions aqt/jax/v2/aqt_dot_general.py
@@ -114,7 +114,7 @@ def dot_general_raw_make(
else:
dg_accumulator_dtype = None

# DotGeneralRaw should create a this quantizer on defualt.
# DotGeneralRaw should create a this quantizer on default.
# Then setter can change it.

# initialize_calibration=False because that has to be delayed to be called
@@ -574,6 +574,11 @@ def _maybe_use_fwd_quant(
)
if use_fwd_quant:
assert fwd_quantized, msg
if rhs.qx.bias:
raise NotImplementedError(
'Quantization biases are not supported in forward quantization.'
)

scale_t = transpose.rhs_scale_transpose_for_lhs_input(
rhs.qx.scale[0], dimension_numbers, lhs.shape
)
@@ -664,6 +669,17 @@ def __call__(
self.allow_dummy_gradient_into_qtensor
)

msg = (
'biases are only supported in fake quant mode, but got a {arg} bias '
'and self.{arg}.dequant_mode == {mode} != DequantMode.THIS_INPUT'
)
assert not (
lhs_qt.bias and self.lhs.dequant_mode != DequantMode.THIS_INPUT
), msg.format(arg='lhs', mode=self.lhs.dequant_mode)
assert not (
rhs_qt.bias and self.rhs.dequant_mode != DequantMode.THIS_INPUT
), msg.format(arg='rhs', mode=self.rhs.dequant_mode)

lhs_mt = MultiTensor(x=lhs, qx=lhs_qt)
lhs_res = TensorRes(mt=lhs_mt, quant_grad=lhs_quant_grad)

@@ -772,6 +788,7 @@ def _maybe_dequant(
qvalue=out,
scale=[],
scale_t=None,
bias=[],
dequant_dtype=dequant_dtype,
)
assert out.scale is not None # pytype help
@@ -869,7 +886,7 @@ def assert_config_validity(self: Self):
expected_fwd_quant = False
msg_fwd_quant = (
f'use_fwd_quant should be set to {expected_fwd_quant} when remaining'
' axis are used for calibration axis.'
' axis are used for calibration axis. '
)

if self.fwd.rhs.calibration_mode == CalibrationMode.REMAINING_AXIS:
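The two guards added above (the NotImplementedError under use_fwd_quant and the fake-quant-only assertions in __call__) reflect the same arithmetic fact: a multiplicative scale factors out of a contraction, while an additive bias does not, so an affine-quantized operand cannot simply be corrected after dot_general. A small numeric sketch in plain jnp (not AQT API; all names are illustrative):

import jax.numpy as jnp

lhs = jnp.array([[1.0, 2.0]])
rhs = jnp.array([[3.0], [4.0]])
s, b = 0.5, 1.0

# A scale-only quantized rhs can be corrected after the matmul:
assert jnp.allclose((lhs @ (rhs / s)) * s, lhs @ rhs)

# An affine-quantized rhs cannot be corrected with a per-output rescale alone;
# the bias correction depends on lhs (it sums b over the contracted axis):
q_rhs = (rhs + b) / s
assert not jnp.allclose((lhs @ q_rhs) * s - b, lhs @ rhs)
assert jnp.allclose(
    (lhs @ q_rhs) * s - b * lhs.sum(axis=1, keepdims=True), lhs @ rhs
)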
10 changes: 6 additions & 4 deletions aqt/jax/v2/aqt_quantizer.py
@@ -51,7 +51,7 @@ class Quantizer:
# The dtype of the quantization scale array. If not set, the scale array will
# be in the same dtype as the input.
scale_dtype: jnp.dtype | None = utils.static_field(default=None)
# TODO(yichizh): Factor out auxilliary dataclasses into a separate file.
# TODO(yichizh): Factor out auxiliary dataclasses into a separate file.
context: utils.Context

# we need to speed up this initialization for the backward pass to happen
@@ -87,9 +87,9 @@ def calibrate(
The tiling state is used to tile the input tensor and change the calibration
axes accordingly. When axis is tiled, it is split into multiple tiles. Each
tile shares the same quantization parameters like scale factor. On the other
hand, if the axis is not tiled, the whole axis shares the same qantization
hand, if the axis is not tiled, the whole axis shares the same quantization
parameters. This tiling will increase the granularity of calibration
reducing the numeric error from quantizaiton.
reducing the numeric error from quantization.
Args:
x: The input tensor to be calibrated.
@@ -112,6 +112,7 @@ def calibrate(
qvalue=x,
scale=[],
scale_t=None,
bias=[],
dequant_dtype=x.dtype,
tiling_state=tiling_state,
)
@@ -133,7 +134,7 @@ def calibrate(

if self.po2_scale:
# With floor the biggest value (we are using jnp.max) is in the range of
# clipping and therefore have a correct gradinet.
# clipping and therefore have a correct gradient.
scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale)))
scale = jax.lax.reciprocal(scale)
if self.scale_stop_grad:
@@ -147,6 +148,7 @@ def calibrate(
qvalue=None,
scale=[scale],
scale_t=None,
bias=[],
dequant_dtype=dequant_dtype,
tiling_state=tiling_state,
)
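The tiling behavior described in the calibrate docstring above can be illustrated with plain jnp (a sketch, not AQT's tiling API): splitting a calibration axis into tiles yields one scale per tile instead of a single scale for the whole axis, which is what increases granularity and reduces quantization error.

import jax.numpy as jnp

x = jnp.arange(8.0)                    # one calibration axis of length 8
whole_axis_scale = jnp.max(jnp.abs(x))  # 7.0, shared by all eight values
tiles = x.reshape(2, 4)                # tile the axis into 2 tiles of 4
per_tile_scales = jnp.max(jnp.abs(tiles), axis=1, keepdims=True)  # [[3.], [7.]]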
72 changes: 64 additions & 8 deletions aqt/jax/v2/aqt_tensor.py
@@ -72,6 +72,12 @@ class QTensor:
# TODO(lew): Move scale_t from QTensor to some dot-general specific type?
scale_t: Optional[list[ArrayT]]

# len(bias) == 0 means bias should not be applied.
# Quantization and dequantization are defined such that:
# quant(x) = (x + b) / s = (x + b[0] + b[1] + ...) / s[0] / s[1] / ...
# dequant(q) = (q * s) - b = (q * s[0] * s[1] * ...) - b[0] - b[1] - ...
bias: list[ArrayT]

# DType of the tensor before quantized.
# NOTE: AQT Users should use the public property, dtype, instead.
dequant_dtype: Optional[jnp.dtype] = flax.struct.field(
@@ -116,15 +122,26 @@ def quant(self, x):
"""Quantizes the QTensor."""
assert not self.is_full(), 'Already quantized QTensor.'
assert self.scale is not None, 'Missing scales to be used for quantization.'
assert isinstance(
self.scale, list
), f'QTensor.scale must be a list of arrays, but got {self.scale}'
assert isinstance(
self.bias, list
), f'QTensor.bias must be a list of arrays, but got {self.bias}'

if self.tiling_state is not None:
x = self.tiling_state.apply(x)

qvalue = x
# quant(x) = (x + b) / s
for b in self.bias:
qvalue += b

for s in self.scale:
# TODO(lew): We could store s_inv for faster activation quantization.
s_inv = jax.lax.reciprocal(s)
s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv)
qvalue = qvalue * s_inv
qvalue *= s_inv

# TODO(lew): We should apply numerics here, so that 'quant' function
# Can be considered a part of API.
@@ -133,8 +150,14 @@ def quant(self, x):
def dequant(self) -> jnp.ndarray:
"""Dequantizes the QTensor."""
assert self.scale is not None, 'Missing scales when dequantizing a QTensor.'
assert isinstance(
self.scale, list
), f'QTensor.scale must be a list of arrays, but got {self.scale}'
assert isinstance(
self.bias, list
), f'QTensor.bias must be a list of arrays, but got {self.bias}'
msg = (
'QTensor is manually created without setting a dequant_detype. It can'
'QTensor is manually created without setting a dequant_dtype. It can'
' be used in dot_general, but to dequantize you need to set its dtype.'
)
assert self.dequant_dtype is not None, msg
Expand All @@ -143,13 +166,21 @@ def dequant(self) -> jnp.ndarray:

# pytype: disable=attribute-error
ret = self.qvalue.astype(self.dequant_dtype)
for scale in self.scale:
ret = ret * scale

# dequant(q) = q * s - b
for s in self.scale:
ret *= s

# Apply bias after all rescaling is done. There may be more biases than
# scales, e.g. in native asymmetric matmul output dequantization.
for b in self.bias:
ret -= b

if self.tiling_state is not None:
ret = self.tiling_state.unapply(ret)
# In case the scale dtype is not the same as dequant_dtype, and it is a
# higher precision.

# In case the scale or bias dtypes are not the same as dequant_dtype, and it
# is a higher precision.
ret = ret.astype(self.dequant_dtype)
# pytype: enable=attribute-error
return ret # pytype: disable=bad-return-type
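For reference, a minimal numeric sketch of the affine round trip these definitions describe, quant(x) = (x + b) / s and dequant(q) = q * s - b. Values are illustrative, and it assumes quant() returns a populated QTensor, as in calibration.py further below.

import jax.numpy as jnp
from aqt.jax.v2 import aqt_tensor

x = jnp.array([[-1.0, 0.5], [2.0, 3.5]])
scale = jnp.array([[0.5]])  # single shared scale
bias = jnp.array([[1.0]])   # single shared bias (zero-point shift)

qt = aqt_tensor.QTensor(
    qvalue=None, scale=[scale], scale_t=None, bias=[bias], dequant_dtype=x.dtype
).quant(x)

# quant stores (x + 1.0) / 0.5; dequant recovers x exactly here because
# quant() itself applies no rounding (numerics are applied elsewhere).
assert jnp.allclose(qt.dequant(), x)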
@@ -170,6 +201,7 @@ def __getitem__(self, idx: jax_typing.ArrayLike) -> Self:
qvalue=qvalue,
scale=scale,
scale_t=self.scale_t,
bias=self.bias,
dequant_dtype=self.dequant_dtype,
)

@@ -204,6 +236,7 @@ def zeros(
qvalue=jnp.zeros(shape, dtype=container_dtype),
scale=[],
scale_t=None,
bias=[],
dequant_dtype=dequant_dtype,
)

@@ -228,6 +261,7 @@ def zeros_with_scale(
qvalue=jnp.zeros(shape, dtype=container_dtype),
scale=[jnp.ones(scale_shape, dtype=scale_dtype)],
scale_t=None,
bias=[],
dequant_dtype=dequant_dtype,
)

@@ -236,15 +270,30 @@ def partition_spec(
partitions: Sequence[Any],
calibration_axis: Sequence[utils.AxisIdx],
dtype: jnp.dtype,
*,
use_bias: bool,
) -> QTensor:
"""Returns a QTensor filled with partition specs."""
# This function assumes that there is a single scale, and if use_bias=True, a
# single bias. Both of which are expected to be configured using per-channel
# quantization.
scale_partitions = list(partitions)
for axis in calibration_axis:
scale_partitions[axis] = None
if use_bias:
# Assumes that the bias to be partitioned is the bias from input
# quantization, (which has singleton dimensions for the calibration_axis),
# and not the biases used in native output dequantization, of which there
# may be more than one, and which may have the same shape as the qvalue.
bias_partition = [jax.sharding.PartitionSpec(*scale_partitions)]
else:
# JAX errors upon receiving partition specs for non-existent tensors.
bias_partition = []
return QTensor(
qvalue=jax.sharding.PartitionSpec(*partitions),
scale=[jax.sharding.PartitionSpec(*scale_partitions)],
scale_t=None,
bias=bias_partition,
dequant_dtype=dtype,
)
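A hypothetical usage sketch of the partition_spec helper above (the mesh axis names and module-level call are assumptions, not taken from this commit): with use_bias=True the single bias mirrors the scale's sharding, i.e. the calibration axes are replicated.

import jax.numpy as jnp
from aqt.jax.v2 import aqt_tensor

specs = aqt_tensor.partition_spec(
    partitions=('data', 'model'),  # sharding of the qvalue
    calibration_axis=(1,),         # scale/bias are shared along axis 1
    dtype=jnp.bfloat16,
    use_bias=True,
)
# specs.qvalue == PartitionSpec('data', 'model')
# specs.scale  == [PartitionSpec('data', None)]
# specs.bias   == [PartitionSpec('data', None)]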

Expand All @@ -255,8 +304,9 @@ def dynamic_slice(
slice_sizes: Sequence[int],
) -> QTensor:
"""Dynamically slices the value at start_indices using the given shape."""
msg = 'scale_t is not supported in the dynamic_slice of a QTensor.'
assert operand.scale_t is None, msg
msg = '{attribute} is not supported in the dynamic_slice of a QTensor.'
assert operand.scale_t is None, msg.format('scale_t')
assert not operand.bias, msg.format('bias')

def get_sliced_scales(scale):
msg = 'Slice sizes must have the same length as operand dims.'
@@ -284,6 +334,7 @@ def get_sliced_scales(scale):
qvalue=jax.lax.dynamic_slice(operand.qvalue, start_indices, slice_sizes),
scale=[get_sliced_scales(s) for s in operand.scale],
scale_t=None,
bias=[],
dequant_dtype=operand.dequant_dtype,
)

@@ -332,6 +383,7 @@ def dynamic_update_slice(
qvalue=qvalues,
scale=scales,
scale_t=None,
bias=[],
dequant_dtype=operand.dequant_dtype,
)

@@ -348,5 +400,9 @@ def update_frame(operand: QTensor, frame: int, update: QTensor) -> QTensor:
for target_scale, update_scale in zip(operand.scale, update.scale)
],
scale_t=None,
bias=[
target_bias.at[frame].set(update_bias)
for target_bias, update_bias in zip(operand.bias, update.bias)
],
dequant_dtype=operand.dequant_dtype,
)
21 changes: 18 additions & 3 deletions aqt/jax/v2/aqt_tensor_test.py
@@ -40,7 +40,11 @@ def test_dynamic_slice(self):
print(x.shape, scale.shape)

q = aqt_tensor.QTensor(
qvalue=x, scale=[scale], scale_t=None, dequant_dtype=scale.dtype
qvalue=x,
scale=[scale],
scale_t=None,
bias=[],
dequant_dtype=scale.dtype,
)
y = aqt_tensor.dynamic_slice(q, start_indices=(1, 0), slice_sizes=[2, 1])
print("======")
@@ -59,7 +63,11 @@ def test_getitem(self):
print(x.shape, scale.shape)

q = aqt_tensor.QTensor(
qvalue=x, scale=[scale], scale_t=None, dequant_dtype=scale.dtype
qvalue=x,
scale=[scale],
scale_t=None,
bias=[],
dequant_dtype=scale.dtype,
)
y = q.__getitem__(2)
print("======")
@@ -77,7 +85,11 @@ def test_dynamic_update(self):
print(scale)
print(x.shape, scale.shape)
q = aqt_tensor.QTensor(
qvalue=x, scale=[scale], scale_t=None, dequant_dtype=scale.dtype
qvalue=x,
scale=[scale],
scale_t=None,
bias=[],
dequant_dtype=scale.dtype,
)

update_qvalue = jnp.zeros((3, 1), dtype=x.dtype)
@@ -86,6 +98,7 @@ def test_dynamic_update(self):
qvalue=update_qvalue,
scale=[update_scale],
scale_t=None,
bias=[],
dequant_dtype=update_scale.dtype,
)
y = aqt_tensor.dynamic_update_slice(q, update, (0, 1))
@@ -149,6 +162,7 @@ def test_tiling_state(self):
qvalue=None,
scale=[xlhs_scale],
scale_t=None,
bias=[],
dequant_dtype=xlhs_scale.dtype,
tiling_state=xlhs,
)
@@ -158,6 +172,7 @@ def test_tiling_state(self):
qvalue=None,
scale=[xrhs_scale],
scale_t=None,
bias=[],
dequant_dtype=xrhs_scale.dtype,
tiling_state=xrhs,
)
4 changes: 2 additions & 2 deletions aqt/jax/v2/calibration.py
@@ -97,7 +97,7 @@ def get_bound(
# int_numerics.IntNumerics.
abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True)
# TODO(yichizh): the zero filtering is not needed anymore because inf is
# filtered when calculating the reciprocal of scaline factor
# filtered when calculating the reciprocal of scaling factor
abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max)
if self.scale is not None:
abs_max = abs_max * self.scale
@@ -269,7 +269,7 @@ def _calculate_snr(
scale = bound / abs_max_mapped_to

q_tensor = aqt_tensor.QTensor(
qvalue=None, scale=[scale], scale_t=None, dequant_dtype=x.dtype
qvalue=None, scale=[scale], scale_t=None, bias=[], dequant_dtype=x.dtype
).quant(x)

# This actually quantizes the tensor (clips, rounds, etc).
(The remaining 7 changed files are not shown here.)

