From 31b8fa4bf1293b591f187c14f53d4231c571fcbc Mon Sep 17 00:00:00 2001 From: Phoenix Meadowlark Date: Tue, 6 Aug 2024 15:48:55 -0700 Subject: [PATCH] Add fake-quantization support for quantization biases to AQTv2. PiperOrigin-RevId: 660127656 --- aqt/jax/v2/aqt_conv_general.py | 5 +- aqt/jax/v2/aqt_dot_general.py | 21 +++++- aqt/jax/v2/aqt_quantizer.py | 10 +-- aqt/jax/v2/aqt_tensor.py | 72 ++++++++++++++++--- aqt/jax/v2/aqt_tensor_test.py | 21 +++++- aqt/jax/v2/calibration.py | 4 +- aqt/jax/v2/examples/flax_e2e_model_test.py | 17 ++++- .../gptq/examples/gptq_flax_e2e_model_test.py | 3 + aqt/jax/v2/flax/aqt_flax.py | 17 ++++- aqt/jax/v2/flax/aqt_flax_test.py | 2 + aqt/jax/v2/pallas/dot_general.py | 14 ++-- aqt/jax/v2/pallas/pallas_tensor.py | 1 + aqt/jax/v2/pallas/pallas_tensor_test.py | 1 + 13 files changed, 158 insertions(+), 30 deletions(-) diff --git a/aqt/jax/v2/aqt_conv_general.py b/aqt/jax/v2/aqt_conv_general.py index be041628..b0a4b632 100644 --- a/aqt/jax/v2/aqt_conv_general.py +++ b/aqt/jax/v2/aqt_conv_general.py @@ -60,7 +60,7 @@ def my_conv_general_dilated( # In Flax, lhs is the inputs, rhs is the kernel. # lhs layout is B, spatials..., Ci # rhs layout is: spatials..., Ci, Co - # out layous it: B, spatials..., Co + # out layout is: B, spatials..., Co # # we need to share these axes: lhs[1:] , rhs[:-1] # we have a scale/invscale per: lhs[0] / out[0] and rhs[-1] / out[-1] @@ -103,6 +103,7 @@ def my_conv_general_dilated( qvalue=out, scale=[], scale_t=None, + bias=[], dequant_dtype=jnp.promote_types(lhs, rhs), ) assert out.scale is not None # pytype help @@ -112,7 +113,7 @@ def my_conv_general_dilated( # # Future scale granularity optimization. # In 1x1 conv, each pixel (spatial location) can have different scales - # in 1xN (rows x colums) conv each row can have different scale, but + # in 1xN (rows x columns) conv each row can have different scale, but # columns need to share the scales , because we are adding pixels across. # # For patch convs we could have separate scales per patch. diff --git a/aqt/jax/v2/aqt_dot_general.py b/aqt/jax/v2/aqt_dot_general.py index 56f7849d..d042e8df 100644 --- a/aqt/jax/v2/aqt_dot_general.py +++ b/aqt/jax/v2/aqt_dot_general.py @@ -114,7 +114,7 @@ def dot_general_raw_make( else: dg_accumulator_dtype = None - # DotGeneralRaw should create a this quantizer on defualt. + # DotGeneralRaw should create a this quantizer on default. # Then setter can change it. # initialize_calibration=False because that has to be delayed to be called @@ -574,6 +574,11 @@ def _maybe_use_fwd_quant( ) if use_fwd_quant: assert fwd_quantized, msg + if rhs.qx.bias: + raise NotImplementedError( + 'Quantization biases are not supported in forward quantization.' + ) + scale_t = transpose.rhs_scale_transpose_for_lhs_input( rhs.qx.scale[0], dimension_numbers, lhs.shape ) @@ -664,6 +669,17 @@ def __call__( self.allow_dummy_gradient_into_qtensor ) + msg = ( + 'biases are only supported in fake quant mode, but got a {arg} bias ' + 'and self.{arg}.dequant_mode == {mode} != DequantMode.THIS_INPUT' + ) + assert not ( + lhs_qt.bias and self.lhs.dequant_mode != DequantMode.THIS_INPUT + ), msg.format(arg='lhs', mode=self.lhs.dequant_mode) + assert not ( + rhs_qt.bias and self.rhs.dequant_mode != DequantMode.THIS_INPUT + ), msg.format(arg='rhs', mode=self.rhs.dequant_mode) + lhs_mt = MultiTensor(x=lhs, qx=lhs_qt) lhs_res = TensorRes(mt=lhs_mt, quant_grad=lhs_quant_grad) @@ -772,6 +788,7 @@ def _maybe_dequant( qvalue=out, scale=[], scale_t=None, + bias=[], dequant_dtype=dequant_dtype, ) assert out.scale is not None # pytype help @@ -869,7 +886,7 @@ def assert_config_validity(self: Self): expected_fwd_quant = False msg_fwd_quant = ( f'use_fwd_quant should be set to {expected_fwd_quant} when remaining' - ' axis are used for calibration axis.' + ' axis are used for calibration axis. ' ) if self.fwd.rhs.calibration_mode == CalibrationMode.REMAINING_AXIS: diff --git a/aqt/jax/v2/aqt_quantizer.py b/aqt/jax/v2/aqt_quantizer.py index 80e40d8b..d79f3441 100644 --- a/aqt/jax/v2/aqt_quantizer.py +++ b/aqt/jax/v2/aqt_quantizer.py @@ -51,7 +51,7 @@ class Quantizer: # The dtype of the quantization scale array. If not set, the scale array will # be in the same dtype as the input. scale_dtype: jnp.dtype | None = utils.static_field(default=None) - # TODO(yichizh): Factor out auxilliary dataclasses into a separate file. + # TODO(yichizh): Factor out auxiliary dataclasses into a separate file. context: utils.Context # we need to speed up this initialization for the backward pass to happen @@ -87,9 +87,9 @@ def calibrate( The tiling state is used to tile the input tensor and change the calibration axes accordingly. When axis is tiled, it is split into multiple tiles. Each tile shares the same quantization parameters like scale factor. On the other - hand, if the axis is not tiled, the whole axis shares the same qantization + hand, if the axis is not tiled, the whole axis shares the same quantization parameters. This tiling will increase the granularity of calibration - reducing the numeric error from quantizaiton. + reducing the numeric error from quantization. Args: x: The input tensor to be calibrated. @@ -112,6 +112,7 @@ def calibrate( qvalue=x, scale=[], scale_t=None, + bias=[], dequant_dtype=x.dtype, tiling_state=tiling_state, ) @@ -133,7 +134,7 @@ def calibrate( if self.po2_scale: # With floor the biggest value (we are using jnp.max) is in the range of - # clipping and therefore have a correct gradinet. + # clipping and therefore have a correct gradient. scale = 2 ** jnp.floor(jnp.log2(jax.lax.reciprocal(scale))) scale = jax.lax.reciprocal(scale) if self.scale_stop_grad: @@ -147,6 +148,7 @@ def calibrate( qvalue=None, scale=[scale], scale_t=None, + bias=[], dequant_dtype=dequant_dtype, tiling_state=tiling_state, ) diff --git a/aqt/jax/v2/aqt_tensor.py b/aqt/jax/v2/aqt_tensor.py index 94f3abf2..ba82194a 100644 --- a/aqt/jax/v2/aqt_tensor.py +++ b/aqt/jax/v2/aqt_tensor.py @@ -72,6 +72,12 @@ class QTensor: # TODO(lew): Move scale_t from QTensor to some dot-general specific type? scale_t: Optional[list[ArrayT]] + # len(bias) == 0 means bias should not be applied. + # Quantization and dequantization are defined such that: + # quant(x) = (x + b) / s = (x + b[0] + b[1] + ...) / s[0] / s[1] / ... + # dequant(q) = (q * s) - b = (q * s[0] * s[1] * ...) - b[0] - b[1] - ... + bias: list[ArrayT] + # DType of the tensor before quantized. # NOTE: AQT Users should use the public property, dtype, instead. dequant_dtype: Optional[jnp.dtype] = flax.struct.field( @@ -116,15 +122,26 @@ def quant(self, x): """Quantizes the QTensor.""" assert not self.is_full(), 'Already quantized QTensor.' assert self.scale is not None, 'Missing scales to be used for quantization.' + assert isinstance( + self.scale, list + ), f'QTensor.scale must be a list of arrays, but got {self.scale}' + assert isinstance( + self.bias, list + ), f'QTensor.bias must be a list of arrays, but got {self.bias}' if self.tiling_state is not None: x = self.tiling_state.apply(x) + qvalue = x + # quant(x) = (x + b) / s + for b in self.bias: + qvalue += b + for s in self.scale: # TODO(lew): We could store s_inv for faster activation quantization. s_inv = jax.lax.reciprocal(s) s_inv = jnp.where(jnp.isinf(s_inv), jnp.ones_like(s_inv), s_inv) - qvalue = qvalue * s_inv + qvalue *= s_inv # TODO(lew): We should apply numerics here, so that 'quant' function # Can be considered a part of API. @@ -133,8 +150,14 @@ def quant(self, x): def dequant(self) -> jnp.ndarray: """Dequantizes the QTensor.""" assert self.scale is not None, 'Missing scales when dequantizing a QTensor.' + assert isinstance( + self.scale, list + ), f'QTensor.scale must be a list of arrays, but got {self.scale}' + assert isinstance( + self.bias, list + ), f'QTensor.bias must be a list of arrays, but got {self.bias}' msg = ( - 'QTensor is manually created without setting a dequant_detype. It can' + 'QTensor is manually created without setting a dequant_dtype. It can' ' be used in dot_general, but to dequantize you need to set its dtype.' ) assert self.dequant_dtype is not None, msg @@ -143,13 +166,21 @@ def dequant(self) -> jnp.ndarray: # pytype: disable=attribute-error ret = self.qvalue.astype(self.dequant_dtype) - for scale in self.scale: - ret = ret * scale + + # dequant(q) = q * s - b + for s in self.scale: + ret *= s + + # Apply bias after all rescaling is done. There may be more biases than + # scales, e.g. in native asymmetric matmul output dequantization. + for b in self.bias: + ret -= b if self.tiling_state is not None: ret = self.tiling_state.unapply(ret) - # In case the scale dtype is not the same as dequant_dtype, and it is a - # higher precision. + + # In case the scale or bias dtypes are not the same as dequant_dtype, and it + # is a higher precision. ret = ret.astype(self.dequant_dtype) # pytype: enable=attribute-error return ret # pytype: disable=bad-return-type @@ -170,6 +201,7 @@ def __getitem__(self, idx: jax_typing.ArrayLike) -> Self: qvalue=qvalue, scale=scale, scale_t=self.scale_t, + bias=self.bias, dequant_dtype=self.dequant_dtype, ) @@ -204,6 +236,7 @@ def zeros( qvalue=jnp.zeros(shape, dtype=container_dtype), scale=[], scale_t=None, + bias=[], dequant_dtype=dequant_dtype, ) @@ -228,6 +261,7 @@ def zeros_with_scale( qvalue=jnp.zeros(shape, dtype=container_dtype), scale=[jnp.ones(scale_shape, dtype=scale_dtype)], scale_t=None, + bias=[], dequant_dtype=dequant_dtype, ) @@ -236,15 +270,30 @@ def partition_spec( partitions: Sequence[Any], calibration_axis: Sequence[utils.AxisIdx], dtype: jnp.dtype, + *, + use_bias: bool, ) -> QTensor: """Returns a QTensor filled with partition specs.""" + # This function assumes that there is a single scale, and if use_bias=True, a + # single bias. Both of which are expected to be configured using per-channel + # quantization. scale_partitions = list(partitions) for axis in calibration_axis: scale_partitions[axis] = None + if use_bias: + # Assumes that the bias to be partitioned is the bias from input + # quantization, (which has singleton dimensions for the calibration_axis), + # and not the biases used in native output dequantization, of which there + # may be more than one, and which may have the same shape as the qvalue. + bias_partition = [jax.sharding.PartitionSpec(*scale_partitions)] + else: + # JAX errors upon receiving partition specs for non-existent tensors. + bias_partition = [] return QTensor( qvalue=jax.sharding.PartitionSpec(*partitions), scale=[jax.sharding.PartitionSpec(*scale_partitions)], scale_t=None, + bias=bias_partition, dequant_dtype=dtype, ) @@ -255,8 +304,9 @@ def dynamic_slice( slice_sizes: Sequence[int], ) -> QTensor: """Dynamically slices the value at start_indices using the given shape.""" - msg = 'scale_t is not supported in the dynamic_slice of a QTensor.' - assert operand.scale_t is None, msg + msg = '{attribute} is not supported in the dynamic_slice of a QTensor.' + assert operand.scale_t is None, msg.format('scale_t') + assert not operand.bias, msg.format('bias') def get_sliced_scales(scale): msg = 'Slice sizes must have the same length as operand dims.' @@ -284,6 +334,7 @@ def get_sliced_scales(scale): qvalue=jax.lax.dynamic_slice(operand.qvalue, start_indices, slice_sizes), scale=[get_sliced_scales(s) for s in operand.scale], scale_t=None, + bias=[], dequant_dtype=operand.dequant_dtype, ) @@ -332,6 +383,7 @@ def dynamic_update_slice( qvalue=qvalues, scale=scales, scale_t=None, + bias=[], dequant_dtype=operand.dequant_dtype, ) @@ -348,5 +400,9 @@ def update_frame(operand: QTensor, frame: int, update: QTensor) -> QTensor: for target_scale, update_scale in zip(operand.scale, update.scale) ], scale_t=None, + bias=[ + target_bias.at[frame].set(update_bias) + for target_bias, update_bias in zip(operand.bias, update.bias) + ], dequant_dtype=operand.dequant_dtype, ) diff --git a/aqt/jax/v2/aqt_tensor_test.py b/aqt/jax/v2/aqt_tensor_test.py index 434db7d9..af6a48b5 100644 --- a/aqt/jax/v2/aqt_tensor_test.py +++ b/aqt/jax/v2/aqt_tensor_test.py @@ -40,7 +40,11 @@ def test_dynamic_slice(self): print(x.shape, scale.shape) q = aqt_tensor.QTensor( - qvalue=x, scale=[scale], scale_t=None, dequant_dtype=scale.dtype + qvalue=x, + scale=[scale], + scale_t=None, + bias=[], + dequant_dtype=scale.dtype, ) y = aqt_tensor.dynamic_slice(q, start_indices=(1, 0), slice_sizes=[2, 1]) print("======") @@ -59,7 +63,11 @@ def test_getitem(self): print(x.shape, scale.shape) q = aqt_tensor.QTensor( - qvalue=x, scale=[scale], scale_t=None, dequant_dtype=scale.dtype + qvalue=x, + scale=[scale], + scale_t=None, + bias=[], + dequant_dtype=scale.dtype, ) y = q.__getitem__(2) print("======") @@ -77,7 +85,11 @@ def test_dynamic_update(self): print(scale) print(x.shape, scale.shape) q = aqt_tensor.QTensor( - qvalue=x, scale=[scale], scale_t=None, dequant_dtype=scale.dtype + qvalue=x, + scale=[scale], + scale_t=None, + bias=[], + dequant_dtype=scale.dtype, ) update_qvalue = jnp.zeros((3, 1), dtype=x.dtype) @@ -86,6 +98,7 @@ def test_dynamic_update(self): qvalue=update_qvalue, scale=[update_scale], scale_t=None, + bias=[], dequant_dtype=update_scale.dtype, ) y = aqt_tensor.dynamic_update_slice(q, update, (0, 1)) @@ -149,6 +162,7 @@ def test_tiling_state(self): qvalue=None, scale=[xlhs_scale], scale_t=None, + bias=[], dequant_dtype=xlhs_scale.dtype, tiling_state=xlhs, ) @@ -158,6 +172,7 @@ def test_tiling_state(self): qvalue=None, scale=[xrhs_scale], scale_t=None, + bias=[], dequant_dtype=xrhs_scale.dtype, tiling_state=xrhs, ) diff --git a/aqt/jax/v2/calibration.py b/aqt/jax/v2/calibration.py index 4310fcae..b183ca4e 100644 --- a/aqt/jax/v2/calibration.py +++ b/aqt/jax/v2/calibration.py @@ -97,7 +97,7 @@ def get_bound( # int_numerics.IntNumerics. abs_max = jnp.max(jnp.abs(x), axis=shared_axes, keepdims=True) # TODO(yichizh): the zero filtering is not needed anymore because inf is - # filtered when calculating the reciprocal of scaline factor + # filtered when calculating the reciprocal of scaling factor abs_max = jnp.where(abs_max == 0.0, jnp.ones_like(abs_max), abs_max) if self.scale is not None: abs_max = abs_max * self.scale @@ -269,7 +269,7 @@ def _calculate_snr( scale = bound / abs_max_mapped_to q_tensor = aqt_tensor.QTensor( - qvalue=None, scale=[scale], scale_t=None, dequant_dtype=x.dtype + qvalue=None, scale=[scale], scale_t=None, bias=[], dequant_dtype=x.dtype ).quant(x) # This actually quantizes the tensor (clips, rounds, etc). diff --git a/aqt/jax/v2/examples/flax_e2e_model_test.py b/aqt/jax/v2/examples/flax_e2e_model_test.py index c4d6d0b0..f1b33ae8 100644 --- a/aqt/jax/v2/examples/flax_e2e_model_test.py +++ b/aqt/jax/v2/examples/flax_e2e_model_test.py @@ -41,7 +41,7 @@ def _dummy_dataset(ds_size, image_rng, label_rng): class MnistTest(parameterized.TestCase): # Unable to use config_v4() in parameters since it needs jax.device info. - # TODO(aqt): Move confiv_v4() into parameters once int4 works for cpu. + # TODO(aqt): Move config_v4() into parameters once int4 works for cpu. @parameterized.parameters([ ( { @@ -138,6 +138,7 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (2, 5, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) } @@ -157,6 +158,7 @@ def forward(model, apply_fn): # After tiling the scale shape is (2, 1, 256), # then transposed to (2, 1, 256). scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) } @@ -176,6 +178,7 @@ def forward(model, apply_fn): # After tiling the scale shape is (2, 1, 10), # then transposed to (2, 1, 10). scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) } @@ -406,6 +409,7 @@ def assert_array_not_equal(x, y): qvalue=(expected_dtype, (2, 5, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -414,6 +418,7 @@ def assert_array_not_equal(x, y): qvalue=None, scale=[(dtype("float32"), (1, 1, 1))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -426,6 +431,7 @@ def assert_array_not_equal(x, y): qvalue=None, scale=[(dtype("float32"), (1, 1, 1))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -434,6 +440,7 @@ def assert_array_not_equal(x, y): qvalue=(expected_dtype, (2, 1568, 256)), scale=[(dtype("float32"), (2, 1, 256))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -446,6 +453,7 @@ def assert_array_not_equal(x, y): qvalue=None, scale=[(dtype("float32"), (1, 1, 1))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -454,6 +462,7 @@ def assert_array_not_equal(x, y): qvalue=(expected_dtype, (2, 128, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -624,6 +633,7 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (2, 5, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -632,6 +642,7 @@ def forward(model, apply_fn): qvalue=None, scale=[(dtype("float32"), (1, 1, 1))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -644,6 +655,7 @@ def forward(model, apply_fn): qvalue=None, scale=[(dtype("float32"), (1, 1, 1))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -652,6 +664,7 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (2, 1568, 256)), scale=[(dtype("float32"), (2, 1, 256))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -664,6 +677,7 @@ def forward(model, apply_fn): qvalue=None, scale=[(dtype("float32"), (1, 1, 1))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -672,6 +686,7 @@ def forward(model, apply_fn): qvalue=(expected_dtype, (2, 128, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, diff --git a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py index 0f39d335..12256d75 100644 --- a/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py +++ b/aqt/jax/v2/extensions/gptq/examples/gptq_flax_e2e_model_test.py @@ -129,6 +129,7 @@ def test_gptq(self): qvalue=(expected_dtype, (2, 5, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) }, @@ -141,6 +142,7 @@ def test_gptq(self): qvalue=(expected_dtype, (2, 1568, 256)), scale=[(dtype("float32"), (2, 1, 256))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) } @@ -153,6 +155,7 @@ def test_gptq(self): qvalue=(expected_dtype, (2, 128, 10)), scale=[(dtype("float32"), (2, 1, 10))], scale_t=None, + bias=[], dequant_dtype=dtype("float32"), ) } diff --git a/aqt/jax/v2/flax/aqt_flax.py b/aqt/jax/v2/flax/aqt_flax.py index 92019fde..247bfcb9 100644 --- a/aqt/jax/v2/flax/aqt_flax.py +++ b/aqt/jax/v2/flax/aqt_flax.py @@ -60,7 +60,7 @@ class Freezer(nn.Module): On default it is an identity function that saves the input in a variable. In 'quant_mode=QuantMode.Serve' mode, ignores the input and returns the frozen - value. It is usefult to implement 'constant folding' and put quantized weights + value. It is useful to implement 'constant folding' and put quantized weights and scales in the checkpoint for serving. Specifically: self.get() returns None when quant_mode=TRAIN or CONVERT, returns variable @@ -106,6 +106,7 @@ def get(self) -> Optional[aqt_tensor.QTensor]: qvalue=self.qvalue.value, scale=None, scale_t=[self.scale_t.value], + bias=[], # TODO(lew): Ideal solution: To find out this dequant_dtype one should # use the dtype of inputs of the quant function. We should store it as # a dtype of small-sized scale tensor. @@ -120,6 +121,10 @@ def set(self, inputs: aqt_tensor.QTensor) -> None: # f'Freezer got a QTensor of type {inputs.qvalue.dtype} but expected' # f' {self.q_dtype}.' # ) + if inputs.bias: + raise NotImplementedError( + 'Quantization biases are not supported in AQT Flax Legacy Freezer.' + ) if self.quant_mode == QuantMode.TRAIN: pass elif self.quant_mode == QuantMode.CALIBRATE: @@ -310,6 +315,9 @@ def init_wrapper( scale_non_shard_axis_all = list(range(qt.ndim)) scale_non_shard_axis_contracting = list(contracting_axis) + def _get_singleton_axes(x: jnp.ndarray) -> list[utils.AxisIdx]: + return [axis for axis, dim in enumerate(x.shape) if dim == 1] + qt = qt.replace( qvalue=axis_metadata_wrapper( qt.qvalue, @@ -331,6 +339,13 @@ def init_wrapper( ), qt.scale_t, ), + # Set the non-sharding axes for bias to the singleton dimensions. + bias=jax.tree.map( + lambda x: axis_metadata_wrapper( + x, tile_map, _get_singleton_axes(x) + ), + qt.bias, + ), ) return qt diff --git a/aqt/jax/v2/flax/aqt_flax_test.py b/aqt/jax/v2/flax/aqt_flax_test.py index 5288beab..6bf9c4ce 100644 --- a/aqt/jax/v2/flax/aqt_flax_test.py +++ b/aqt/jax/v2/flax/aqt_flax_test.py @@ -314,6 +314,7 @@ def __call__(self, lhs): qvalue=(dtype('int8'), (2, 3, 2, 5, 2)), scale=[(dtype('float32'), (2, 3, 1, 5, 1))], scale_t=None, + bias=[], dequant_dtype=dtype('float32'), ) } @@ -386,6 +387,7 @@ def __call__(self, lhs): qvalue=(dtype('int8'), (3, 6, 5, 2)), scale=[(dtype('float32'), (3, 1, 5, 1))], scale_t=None, + bias=[], dequant_dtype=dtype('float32'), ) } diff --git a/aqt/jax/v2/pallas/dot_general.py b/aqt/jax/v2/pallas/dot_general.py index ae780e7c..fd8aeaff 100644 --- a/aqt/jax/v2/pallas/dot_general.py +++ b/aqt/jax/v2/pallas/dot_general.py @@ -72,11 +72,11 @@ def dot_general( precision: this is ignored, but added to match the signature of jax.lax.dot_general. preferred_element_type: preferred output element type after dequantization. - lhs_dequant_mode: This decicdes where lhs is dequantized. Default is OUTPUT + lhs_dequant_mode: This decides where lhs is dequantized. Default is OUTPUT where dequantization is applied to the output of dot_general. OTHER_INPUT applies dequantization to the other input before dot_general and THIS_INPUT applies dequantization to the current input before dot_general. - rhs_dequant_mode: This decicdes where rhs is dequantized. Default is OUTPUT. + rhs_dequant_mode: This decides where rhs is dequantized. Default is OUTPUT. Returns: Dequantized output of dot_general. @@ -86,9 +86,9 @@ def dot_general( # The code below is supposed to be executed inside pallas kernel. - # When only one of operands is quantized, Jax impliclty cast int8 into float - # and performs dot_general. However, pallas requires explict casting when only - # one of operands is quantized. + # When only one of operands is quantized, Jax implicitly cast int8 into float + # and performs dot_general. However, pallas requires explicit casting when + # only one of operands is quantized. is_both_quantized = isinstance(lhs, QTensor) and isinstance(rhs, QTensor) if isinstance(lhs, QTensor) and not is_both_quantized: promoted_dtype = jnp.promote_types(lhs.dequant_dtype, rhs) @@ -99,11 +99,11 @@ def dot_general( if isinstance(lhs, jax.Array): lhs = QTensor( - qvalue=lhs, scale=[], scale_t=None, dequant_dtype=lhs.dtype + qvalue=lhs, scale=[], scale_t=None, bias=[], dequant_dtype=lhs.dtype ) if isinstance(rhs, jax.Array): rhs = QTensor( - qvalue=rhs, scale=[], scale_t=None, dequant_dtype=rhs.dtype + qvalue=rhs, scale=[], scale_t=None, bias=[], dequant_dtype=rhs.dtype ) if preferred_element_type is None: diff --git a/aqt/jax/v2/pallas/pallas_tensor.py b/aqt/jax/v2/pallas/pallas_tensor.py index 7aae4f27..f43601b7 100644 --- a/aqt/jax/v2/pallas/pallas_tensor.py +++ b/aqt/jax/v2/pallas/pallas_tensor.py @@ -160,5 +160,6 @@ def scale_index_map( qvalue=block_spec, scale=[_make_scale_block_spec(s, block_spec) for s in qtensor.scale], scale_t=None, + bias=[], dequant_dtype=qtensor.dequant_dtype, ) diff --git a/aqt/jax/v2/pallas/pallas_tensor_test.py b/aqt/jax/v2/pallas/pallas_tensor_test.py index f97fd4a1..1998b004 100644 --- a/aqt/jax/v2/pallas/pallas_tensor_test.py +++ b/aqt/jax/v2/pallas/pallas_tensor_test.py @@ -72,6 +72,7 @@ def test_qtensor_blockspec_correctness( qvalue=jnp.ones(qvalue_shape, dtype=jnp.int8), scale=[jnp.ones(scale_shape, dtype=jnp.float32)], scale_t=None, + bias=[], dequant_dtype=jnp.float32, ) block_spec = pl.BlockSpec(block_shape, lambda *args: args)