diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
index 342725c0237..38e5979e355 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -18,6 +18,7 @@
 from nncf.parameters import CompressWeightsMode
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.fake_quantize import calculate_scale_zero_point
+from nncf.quantization.fake_quantize import calculate_zero_point
 from nncf.tensor import Tensor
 from nncf.tensor import functions as fns
 from nncf.tensor.definitions import TensorDataType
@@ -249,7 +250,10 @@ def calculate_normalized_weight_and_fp4_scale(
 
 
 def calculate_integer_quantization_params(
-    weight: Tensor, reduction_axes: ReductionAxes, config: WeightCompressionConfig
+    weight: Tensor,
+    reduction_axes: ReductionAxes,
+    config: WeightCompressionConfig,
+    precompute_scale: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor]:
     """
     Calculates the scale and zero point for uniform quantization (INT4, INT8), when the range of values is divided into
@@ -258,6 +262,7 @@ def calculate_integer_quantization_params(
     :param weight: Weight array to compress.
     :param reduction_axes: Axes, along which to reduce (collect) different statistics (e.g. min, max).
     :param config: Weight compression configuration.
+    :param precompute_scale: Optional precomputed scale.
     :return: Scale and zero point tensors.
     """
     mode = config.mode
@@ -271,10 +276,14 @@ def calculate_integer_quantization_params(
         level_low = 0
         level_high = 2**num_bits - 1
         min_values = fns.min(weight, axis=reduction_axes, keepdims=True)  # [a1, r, a2] -> [a1, 1, a2]
-        max_values = fns.max(weight, axis=reduction_axes, keepdims=True)  # [a1, r, a2] -> [a1, 1, a2]
-        scale, zero_point = calculate_scale_zero_point(
-            min_values, max_values, level_low, level_high, narrow_range=False
-        )
+        if precompute_scale is None:
+            max_values = fns.max(weight, axis=reduction_axes, keepdims=True)  # [a1, r, a2] -> [a1, 1, a2]
+            scale, zero_point = calculate_scale_zero_point(
+                min_values, max_values, level_low, level_high, narrow_range=False
+            )
+        else:
+            scale = precompute_scale
+            zero_point = calculate_zero_point(scale, min_values, level_low, level_high, narrow_range=False)
         return scale, zero_point
 
     scale = calculate_signed_scale(weight, reduction_axes, num_bits)
@@ -366,8 +375,10 @@ def do_int_quantization(
         # weights are reshaped from [a1, r, a2] to [a1, r//gs, gs, a2]
         weight, reduction_axes = reshape_weight_for_grouped_quantization(weight, reduction_axes, group_size)
 
-    if precomputed_zero_point is None or precomputed_zero_point is None:
-        scale, zero_point = calculate_integer_quantization_params(weight, reduction_axes, config)
+    is_asym = config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]
+    zero_point = None
+    if precomputed_scale is None or (is_asym and precomputed_zero_point is None):
+        scale, zero_point = calculate_integer_quantization_params(weight, reduction_axes, config, precomputed_scale)
     if precomputed_scale is not None:
         scale = precomputed_scale
     if precomputed_zero_point is not None:
diff --git a/nncf/quantization/fake_quantize.py b/nncf/quantization/fake_quantize.py
index d5a3e96ae64..017c34124a3 100644
--- a/nncf/quantization/fake_quantize.py
+++ b/nncf/quantization/fake_quantize.py
@@ -359,7 +359,27 @@ def calculate_scale_zero_point(
     eps = fns.finfo(scale).eps
     # NOTE: adding machine epsilon to avoid division by zero
     scale = fns.where(fns.abs(scale) < eps, eps, scale)
+    zero_point = calculate_zero_point(scale, input_low, level_low, level_high, narrow_range)
+    return scale, zero_point
+
+
+def calculate_zero_point(
+    scale: Tensor, input_low: Tensor, level_low: int, level_high: int, narrow_range: bool
+) -> Tensor:
+    """
+    Calculates zero_point values for the quantizer.
+
+    :param scale: Pre-calculated scale value.
+    :param input_low: The minimum limit for an input value based on collected statistics.
+    :param level_low: The minimum level in the integer range to quantize.
+        The default is "0" for an unsigned range, and "-2^(bit-1)" for a signed one.
+    :param level_high: The maximum level in the integer range to quantize.
+        The default is "2^bits-1" for an unsigned range, and "2^(bit-1)-1" for a signed one.
+    :param narrow_range: True if the range of quantized values is narrowed as compared to the
+        naive case, False otherwise.
+    :return: Zero point value.
+    """
     expected_level_low = level_low + 1 if narrow_range else level_low
     zero_point = expected_level_low - fns.round(input_low / scale)
     zero_point = fns.clip(zero_point.astype(TensorDataType.int32), level_low, level_high)
-    return scale, zero_point
+    return zero_point
diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py
index edc50652710..d34d899241a 100644
--- a/tests/openvino/native/quantization/test_weights_compression.py
+++ b/tests/openvino/native/quantization/test_weights_compression.py
@@ -11,6 +11,7 @@
 
 import inspect
 import os
+import unittest.mock
 from typing import Callable, List
 
 import numpy as np
@@ -36,10 +37,12 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA
 from nncf.quantization.algorithms.weight_compression.openvino_backend import OVWeightCompressionAlgoBackend
+from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import get_integer_quantization_error
 from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
+from nncf.quantization.fake_quantize import calculate_zero_point
 from nncf.scopes import IgnoredScope
 from nncf.tensor import Tensor
 from nncf.tensor import TensorDataType
@@ -1072,6 +1075,42 @@ def test_compressed_weighs_range(mode, data):
     assert np.allclose(np.abs(compressed_weighs.data), np.abs(w.data))
 
 
+@pytest.mark.parametrize("mode", INT4_MODES + INT8_MODES)
+def test_compress_weights_with_precomputed_scale(mode):
+    weight = ((np.arange(11) - 5) / 10).astype(np.float32)
+    precomputed_scale = -((np.arange(11) - 5) / 100).astype(np.float32)
+    weight, precomputed_scale = Tensor(weight[:, None]), Tensor(precomputed_scale[:, None])
+
+    config = WeightCompressionConfig(mode=mode)
+    is_asym = config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT8_ASYM]
+    with unittest.mock.patch(
+        "nncf.quantization.algorithms.weight_compression.weight_lowering.calculate_integer_quantization_params",
+        side_effect=calculate_integer_quantization_params,
+    ) as mock_calc_params:
+        with unittest.mock.patch(
+            "nncf.quantization.algorithms.weight_compression.weight_lowering.calculate_zero_point",
+            side_effect=calculate_zero_point,
+        ) as mock_calc_zp:
+            _, scale, zp_from_weight_and_scale = do_int_quantization(weight, -1, config, precomputed_scale)
+            if is_asym:
+                # For asymmetric quantization we should calculate only the new zero point
+                mock_calc_params.assert_called_once()
+                mock_calc_zp.assert_called_once()
+            else:
+                # For symmetric nothing needs to be computed
+                mock_calc_params.assert_not_called()
+                mock_calc_zp.assert_not_called()
+
+    _, _, zp_from_weight = do_int_quantization(weight, -1, config)
+
+    if is_asym:
+        # Zero points obtained with pre-computed scale and without it must differ
+        assert not np.allclose(zp_from_weight_and_scale.data, zp_from_weight.data)
+    else:
+        assert zp_from_weight_and_scale is None and zp_from_weight is None
+    assert np.allclose(scale.data, precomputed_scale.data)
+
+
 @pytest.mark.parametrize("mode", INT4_NF4_MODES)
 def test_call_max_var_criterion_with_dataset_gptq_neg_group_size(mode):
     model = AWQMatmulModel().ov_model