From 021c7e5a7917e488d72179fb5fcb6d570af9d6e4 Mon Sep 17 00:00:00 2001 From: Liubov Talamanova Date: Fri, 12 Jul 2024 16:31:21 +0100 Subject: [PATCH] Support per-channel mode in AWQ, GPTQ and Scale Estimation algos (#2785) ### Changes * Support `group_size=-1` in AWQ, GPTQ and Scale Estimation algorithms ### Reason for changes * To enable more accurate per-channel quantization for devices that do not support group quantization ### Related tickets * 145725 ### Tests tests/openvino/native/quantization/test_weights_compression.py --- .../weights_compression/Usage.md | 2 +- .../algorithms/weight_compression/gptq.py | 17 ++++++++++++--- .../weight_compression/scale_estimation.py | 21 ++++++++++--------- nncf/quantization/quantize_model.py | 8 +++---- .../quantization/test_weights_compression.py | 20 ++++++++++++++---- .../post_training/data/wc_reference_data.yaml | 4 ++++ tests/post_training/model_scope.py | 12 +++++++++++ 7 files changed, 62 insertions(+), 22 deletions(-) diff --git a/docs/usage/post_training_compression/weights_compression/Usage.md b/docs/usage/post_training_compression/weights_compression/Usage.md index fbbc6a23305..faa8ad43473 100644 --- a/docs/usage/post_training_compression/weights_compression/Usage.md +++ b/docs/usage/post_training_compression/weights_compression/Usage.md @@ -40,7 +40,7 @@ compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM) # - Generally, `INT4_SYM` mode is the fastest mixed-precision mode, but it may lead to a significant accuracy degradation or perplexity increase. Compressing weights asymmetrically (`INT4_ASYM` mode) is the way to increase accuracy, however in turns it slows down inference a bit. If the accuracy or perplexity is still not satisfying, there are 2 more hyper-parameters to tune: `group_size` and `ratio`. Please refer to the [example](https://github.com/openvinotoolkit/nncf/blob/develop/examples/llm_compression/openvino/tiny_llama_find_hyperparams) how to automatically tune these parameters. - Lower group size and less ratio of 4-bit layers usually improve accuracy at the sacrifice of inference speed. + Lower group size and less ratio of 4-bit layers usually improve accuracy at the sacrifice of inference speed. To disable grouped quantization and quantize weights per-channel, set `group_size = -1`. Below is the example how to compress weights of 90% of layers to 4-bit integer asymmetrically with the group size 64, and the rest of layers to 8-bit asymmetric integer data type. The same parametrization is applicable for `INT4_SYM` mode. 
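For context, here is a minimal usage sketch of the per-channel mode described in the Usage.md change above, combined with the data-aware algorithms this patch enables for `group_size=-1`. The IR path `model.xml`, the `[8, 8]` calibration array, and the chosen `mode`/`ratio` values are illustrative assumptions, not part of this change; `compress_weights`, `CompressWeightsMode`, `Dataset`, `awq` and `scale_estimation` are the existing NNCF API touched by this PR.

```python
import numpy as np
import openvino as ov

import nncf
from nncf import CompressWeightsMode

# Placeholder IR path; any OpenVINO model with MatMul weights would do.
model = ov.Core().read_model("model.xml")

# Tiny illustrative calibration set; real inputs must match the model's signature.
dataset = nncf.Dataset([np.ones([8, 8], dtype=np.float32)])

# group_size=-1 disables grouped quantization: each output channel gets a single
# scale (and zero point for asymmetric modes). With this patch, AWQ, GPTQ and
# Scale Estimation no longer reject this setting.
compressed_model = nncf.compress_weights(
    model,
    mode=CompressWeightsMode.INT4_ASYM,
    ratio=0.8,
    group_size=-1,
    dataset=dataset,
    awq=True,
    scale_estimation=True,
)
```

Both data-aware algorithms still require a dataset; `gptq=True` could be used instead of `scale_estimation=True`, but not together with it, since the patch keeps that mutual-exclusion check.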
diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index 721db4fbf29..fe179aa1957 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -119,7 +119,10 @@ def apply( ) for node, inputs in track(target_node_iterator, total=len(target_nodes), description="Applying GPTQ"): wc_params = target_nodes_wc_params_map[node] - if wc_params.compression_config.group_size == -1: + if wc_params.compression_config.mode in [ + CompressWeightsMode.INT8_ASYM, + CompressWeightsMode.INT8_SYM, + ]: continue assert len(inputs) == 1 _, input_tensors = next(iter(inputs.items())) @@ -222,7 +225,11 @@ def _quantize_weights( quantized_tensor = fns.zeros_like(weight_tensor) columns = hessian.shape[0] - group_size = wc_params.compression_config.group_size + group_size = ( + wc_params.compression_config.group_size + if wc_params.compression_config.group_size != -1 + else weight_tensor.shape[1] + ) reduction_axes = wc_params.reduction_axes block_compression_config = WeightCompressionConfig(mode=wc_params.compression_config.mode) @@ -248,7 +255,7 @@ def _quantize_weights( weight_col = weight_block[:, i] hessian_diag_val = hessian_inv_block[i, i] - if group_size != -1 and (i1 + i) % group_size == 0: + if (i1 + i) % group_size == 0: if block_compression_config.mode == CompressWeightsMode.NF4: scale = calculate_nf4_scale(weight_tensor[:, (i1 + i) : (i1 + i + group_size)], reduction_axes) scales.append(scale) @@ -287,11 +294,15 @@ def _quantize_weights( ) scales = fns.stack(scales, axis=1) + if wc_params.compression_config.group_size == -1: + scales = fns.squeeze(scales, axis=-1) if wc_params.compression_config.mode in [ CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM, ]: zero_points = fns.stack(zero_points, axis=1) + if wc_params.compression_config.group_size == -1: + zero_points = fns.squeeze(zero_points, axis=-1) else: zero_points = None return scales, zero_points diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 7ac5eb2aef1..557db0886ca 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -132,9 +132,6 @@ def apply( stats = self._activations[node_name] reduction_axis = wp.reduction_axes[0] - cur_config = deepcopy(config) - cur_config.group_size = -1 - weight_data = self._backend_entity.get_weight_names_and_port_ids(wp.node_with_weight, graph) if len(weight_data) != 1: # not supported by the algorithm continue @@ -162,19 +159,21 @@ def apply( weight = fns.transpose(weight) reduction_axis = 1 + group_size = config.group_size if config.group_size != -1 else weight.shape[reduction_axis] + cur_config = deepcopy(config) + cur_config.group_size = group_size + original_weight = fns.zeros_like(weight) + weight - compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, config) + compressed_weights, scale, zp = do_integer_quantization(original_weight, reduction_axis, cur_config) if zp is not None: zp = zp.astype(scale.dtype) q_weights = do_dequantization(compressed_weights, scale, zp, reduction_axis) s = fns.unsqueeze(s, 0) - s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, config.group_size) + s, _ = reshape_weight_for_grouped_quantization(s, reduction_axis, group_size) - original_weight, _ = 
reshape_weight_for_grouped_quantization( - original_weight, reduction_axis, config.group_size - ) + original_weight, _ = reshape_weight_for_grouped_quantization(original_weight, reduction_axis, group_size) # all weight in group has importance based on corresponding input activations importance = fns.ones_like(original_weight) @@ -187,8 +186,8 @@ def apply( denum = fns.sum(importance, axis=2, keepdims=True) importance = importance / (denum + eps) - X, _ = reshape_weight_for_grouped_quantization(X, 0, config.group_size) - q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, config.group_size) + X, _ = reshape_weight_for_grouped_quantization(X, 0, group_size) + q_weights, _ = reshape_weight_for_grouped_quantization(q_weights, reduction_axis, group_size) best_diffs = None result_scale = None @@ -298,6 +297,8 @@ def apply( near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale result_scale = near_to_ideal_scale + if config.group_size == -1: + result_scale = fns.squeeze(result_scale, axis=1) res[weight_name] = result_scale return res diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index a395233a0fb..65b804f10a8 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -457,13 +457,13 @@ def compress_weights( from nncf.openvino.quantization.quantize_model import compress_weights_impl as ov_compress_weights_impl if any((awq, scale_estimation)) and ( - dataset is None or mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] or group_size == -1 + dataset is None or mode in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1] ): raise AttributeError( - "Scale estimation or AWQ algorithm defined, but dataset is None or mode is NF4 or group_size < 0." + "Scale estimation or AWQ algorithm defined, but dataset is None or mode is (NF4 or E2M1)." 
) - if gptq and (dataset is None or group_size == -1 or mode == CompressWeightsMode.E2M1): - raise AttributeError("GPTQ algorithm defined, but dataset is None or group_size < 0 or mode is E2M1.") + if gptq and (dataset is None or mode == CompressWeightsMode.E2M1): + raise AttributeError("GPTQ algorithm defined, but dataset is None or mode is E2M1.") if gptq and scale_estimation: raise AttributeError( diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index c9547d38493..f313a4f2a46 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -747,8 +747,7 @@ def test_call_max_var_criterion_with_dataset_awq_for_compressed_model(mode): def test_call_max_var_criterion_with_dataset_awq_neg_group_size(mode): model = AWQMatmulModel().ov_model dataset = Dataset([np.ones([8, 8])]) - with pytest.raises(AttributeError): - compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, awq=True) + compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, awq=True) def test_data_type_for_num_weights(mocker): @@ -857,8 +856,7 @@ def test_call_max_var_criterion_with_dataset_scale_estimation_neg_group_size(mod model = AWQMatmulModel().ov_model dataset = Dataset([np.ones([8, 8])]) - with pytest.raises(AttributeError): - compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, scale_estimation=True) + compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, scale_estimation=True) @pytest.mark.parametrize("mode", INT4_NF4_MODES) @@ -943,3 +941,17 @@ def test_np_ov_compression_decompression(mode): assert np.allclose(compressed_weighs, compressed_weighs_ov) assert np.allclose(decompressed_weighs, decompressed_weighs_ov) + + +@pytest.mark.parametrize("mode", INT4_NF4_MODES) +def test_call_max_var_criterion_with_dataset_gptq_neg_group_size(mode): + model = AWQMatmulModel().ov_model + sz = 8 + dataset = Dataset([np.ones([sz, sz])]) + + compressed_model = compress_weights(model, mode=mode, ratio=1.0, group_size=-1, dataset=dataset, gptq=True) + + for op in compressed_model.get_ordered_ops(): + op_name = op.get_friendly_name() + if op.get_type_name() == "Constant" and ("/zero_point" in op_name or "/scale" in op_name): + assert op.get_shape() == [sz, 1] diff --git a/tests/post_training/data/wc_reference_data.yaml b/tests/post_training/data/wc_reference_data.yaml index a3813e84ff6..75f704663d4 100644 --- a/tests/post_training/data/wc_reference_data.yaml +++ b/tests/post_training/data/wc_reference_data.yaml @@ -26,3 +26,7 @@ tinyllama_data_aware_gptq_backend_OV: metric_value: 0.83706 num_int4: 94 num_int8: 124 +tinyllama_scale_estimation_per_channel_backend_OV: + metric_value: 0.7435 + num_int4: 188 + num_int8: 124 diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py index 6527e62f43c..e4d2694e7d3 100644 --- a/tests/post_training/model_scope.py +++ b/tests/post_training/model_scope.py @@ -403,6 +403,18 @@ }, "backends": [BackendType.OV], }, + { + "reported_name": "tinyllama_scale_estimation_per_channel", + "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b", + "pipeline_cls": LMWeightCompression, + "compression_params": { + "group_size": -1, + "ratio": 0.8, + "mode": CompressWeightsMode.INT4_ASYM, + "scale_estimation": True, + }, + "backends": [BackendType.OV], + }, ]
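As a sanity check on the `[sz, 1]` constant shape asserted by the new GPTQ test, the NumPy-only toy below (not NNCF internals; all names are made up for illustration) shows why `group_size = -1` collapses to per-channel quantization with exactly one scale per output channel, mirroring the squeeze added for this case in `gptq.py` and `scale_estimation.py`.

```python
import numpy as np

# Toy illustration: with group_size == -1 the effective group size becomes the
# whole reduction axis, so each output channel ends up with a single scale.
out_ch, in_ch = 8, 8
weight = np.random.randn(out_ch, in_ch).astype(np.float32)

group_size = -1
effective_group = in_ch if group_size == -1 else group_size

# [out_ch, in_ch] -> [out_ch, n_groups, effective_group]; n_groups == 1 here.
grouped = weight.reshape(out_ch, in_ch // effective_group, effective_group)

# Symmetric 4-bit scale per group (level_high = 2**(4 - 1) - 1 = 7).
scales = np.abs(grouped).max(axis=-1, keepdims=True) / 7
print(scales.shape)  # (8, 1, 1)

# The same kind of squeeze the patch applies when group_size == -1.
scales = np.squeeze(scales, axis=-1)
print(scales.shape)  # (8, 1) -- matches the [sz, 1] shape checked in the test above
```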