
Weight compression via Lora Correction Algorithm (#2816)
### Changes

Lora Correction algorithm for int4/nf4 weight compression. 

### Reason for changes

A method for improving accuracy by migrating quantization noise to "learnable" LoRA adapters.

### Related tickets

135863

### Tests

- [x] docstrings, proper names
- [x] results for phi3 and stablelm2-1.6b on lambada, wikitext
- [x] job/NNCF/job/manual/job/post_training_weight_compression/144/

![image](https://github.com/user-attachments/assets/93721a8f-a0c5-4852-9d79-7b281cf2fe67)

![image](https://github.com/user-attachments/assets/2e3bd797-8535-4fbb-88da-2dbd92964d50)

![image](https://github.com/user-attachments/assets/8436009f-2827-4ea1-b481-e3f89bd35aef)


![image](https://github.com/user-attachments/assets/3aa290c9-5f41-4933-b0de-160fac3cce2a)
ljaljushkin authored Aug 28, 2024
1 parent ec25a29 commit 417c2a1
Showing 22 changed files with 949 additions and 193 deletions.
@@ -61,7 +61,8 @@ nncf_dataset = nncf.Dataset(data_source, transform_fn)
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset) # model is openvino.Model object
```

- Accuracy of the 4-bit compressed models also can be improved by using AWQ, Scale Estimation or GPTQ algorithms over data-based mixed-precision algorithm. These algorithms work by equalizing a subset of weights to minimize the difference between the original precision and the 4-bit precision. The AWQ algorithm can be used in conjunction with either the Scale Estimation or GPTQ algorithm. However, Scale Estimation and GPTQ algorithms are mutually exclusive and cannot be used together. Below are examples demonstrating how to enable the AWQ, Scale Estimation or GPTQ algorithms:
- Accuracy of the 4-bit compressed models can also be improved by using the AWQ, Scale Estimation, GPTQ or Lora Correction algorithms on top of the data-based mixed-precision algorithm. These algorithms work by equalizing a subset of weights to minimize the difference between the original precision and the 4-bit precision.
Unlike the others, the Lora Correction algorithm inserts additional Linear layers to reduce quantization noise and further improve accuracy. This approach inevitably introduces some memory and runtime overhead, but it is negligible, since the inserted weights are much smaller and can be quantized to 8-bit. The AWQ, Scale Estimation (SE) and Lora Correction (LC) algorithms can be used together in any combination: AWQ + SE, AWQ + LC, SE + LC, AWQ + SE + LC. The GPTQ algorithm can be combined with AWQ only. Below are examples demonstrating how to enable the AWQ, Scale Estimation, GPTQ or Lora Correction algorithms:

Prepare the calibration dataset for data-based algorithms:

@@ -135,6 +136,16 @@ model.model = compress_weights(model.model,
gptq=True)
```

- How to compress 80% of layers to 4-bit integer with the default data-based mixed-precision algorithm and the Lora Correction algorithm. This requires setting `lora_correction` to `True` in addition to the data-based mixed-precision algorithm (a combined sketch with AWQ and Scale Estimation follows this example).

```python
model.model = compress_weights(model.model,
                               mode=CompressWeightsMode.INT4_SYM,
                               ratio=0.8,
                               dataset=nncf_dataset,
                               lora_correction=True)
```
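
Since AWQ, Scale Estimation and Lora Correction can be combined, below is a minimal sketch (not an official recipe from this change) of enabling all three on top of the mixed-precision algorithm; the `ratio` value is illustrative:

```python
model.model = compress_weights(model.model,
                               mode=CompressWeightsMode.INT4_SYM,
                               ratio=0.8,
                               dataset=nncf_dataset,
                               awq=True,
                               scale_estimation=True,
                               lora_correction=True)
```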

- `NF4` mode can be considered for improving accuracy, but currently models quantized to nf4 should not be faster than models
quantized to 8-bit asymmetric integer. Here's an example of how to compress weights to the nf4 data type with group size = 128.
Different `group_size` and `ratio` are also supported.
5 changes: 5 additions & 0 deletions nncf/__init__.py
@@ -45,8 +45,13 @@
from nncf.quantization.advanced_parameters import (
AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters,
)
from nncf.quantization.advanced_parameters import AdvancedAWQParameters as AdvancedAWQParameters
from nncf.quantization.advanced_parameters import AdvancedBiasCorrectionParameters as AdvancedBiasCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import AdvancedGPTQParameters as AdvancedGPTQParameters
from nncf.quantization.advanced_parameters import AdvancedLoraCorrectionParameters as AdvancedLoraCorrectionParameters
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters as AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import AdvancedScaleEstimationParameters as AdvancedScaleEstimationParameters
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters as AdvancedSmoothQuantParameters
from nncf.quantization.advanced_parameters import OverflowFix as OverflowFix
from nncf.scopes import IgnoredScope as IgnoredScope
2 changes: 2 additions & 0 deletions nncf/openvino/quantization/quantize_model.py
@@ -437,6 +437,7 @@ def compress_weights_impl(
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
) -> ov.Model:
"""
@@ -455,6 +456,7 @@
subset_size,
scale_estimation,
gptq,
lora_correction,
advanced_parameters,
)
graph = NNCFGraphFactory.create(model)
30 changes: 30 additions & 0 deletions nncf/quantization/advanced_parameters.py
@@ -314,6 +314,33 @@ class AdvancedGPTQParameters:
subset_size: int = 128


@api()
@dataclass
class AdvancedLoraCorrectionParameters:
"""
Contains advanced parameters for the Lora Correction algorithm.

:param adapter_rank: Rank of the LoRA adapters. Defaults to 8.
:type adapter_rank: int
:param num_iterations: Number of correction iterations. Defaults to 3.
:type num_iterations: int
:param apply_regularization: Whether to add regularization during the correction process. Defaults to True.
Helpful for large rank values to avoid overfitting.
:type apply_regularization: bool
:param subset_size: Number of data samples for the Lora Correction algorithm. Defaults to 128.
:type subset_size: int
:param use_int8_adapters: Whether to quantize the LoRA adapters to 8-bit, otherwise they are kept in the original
weight precision. Defaults to True.
:type use_int8_adapters: bool
"""

adapter_rank: int = 8
num_iterations: int = 3
apply_regularization: bool = True
subset_size: int = 128
use_int8_adapters: bool = True


@api()
@dataclass
class AdvancedCompressionParameters:
@@ -337,6 +364,9 @@ class AdvancedCompressionParameters:
# Advanced GPTQ algorithm parameters
gptq_params: AdvancedGPTQParameters = field(default_factory=AdvancedGPTQParameters)

# Advanced Lora Correction algorithm parameters
lora_correction_params: AdvancedLoraCorrectionParameters = field(default_factory=AdvancedLoraCorrectionParameters)
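
For illustration (not part of the diff), a hedged sketch of how these new parameters could be passed to `compress_weights` via `advanced_parameters`; the concrete values are assumptions, only the class and field names come from the code above:

```python
from nncf import AdvancedCompressionParameters, AdvancedLoraCorrectionParameters
from nncf import CompressWeightsMode, compress_weights

# model and nncf_dataset are assumed to be prepared as in the usage documentation above
compressed = compress_weights(
    model,
    mode=CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    dataset=nncf_dataset,
    lora_correction=True,
    advanced_parameters=AdvancedCompressionParameters(
        lora_correction_params=AdvancedLoraCorrectionParameters(
            adapter_rank=16,            # larger than the default rank of 8
            num_iterations=3,
            apply_regularization=True,  # helps larger ranks avoid overfitting
            use_int8_adapters=True,     # keep the inserted adapters in 8-bit
        )
    ),
)
```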


@api()
@dataclass
44 changes: 44 additions & 0 deletions nncf/quantization/algorithms/weight_compression/activation_stats.py
@@ -0,0 +1,44 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, TypeVar

from nncf.tensor import functions as fns

TTensor = TypeVar("TTensor")


def process_stats(stats: List[TTensor], subset_size: int) -> Tuple[TTensor, TTensor]:
"""
Processes activation statistics that are shared between the AWQ, Scale Estimation and LoRA Correction algorithms.

:param stats: List of activation statistics for a layer that contains N tensors with shape [SeqLen, HiddenDim].
:type stats: List[TTensor]
:param subset_size: The number of samples that the activation statistics are subsampled to.
:type subset_size: int
:return: tuple of the following tensors:
s - maximum channel magnitude across samples [HiddenDim]
X - average channel magnitude across tokens in the sequence [HiddenDim, SampleSize]
:rtype: Tuple[TTensor, TTensor]
"""
X = fns.stack([fns.mean(stat, axis=0) for stat in stats]) # [Batch, HiddenDim]
X_full = fns.transpose(X) # [HiddenDim, Batch]

# prevent high memory and time consumption
if X_full.shape[1] > subset_size:
lens = [stat.shape[0] for stat in stats]
step = X_full.shape[1] // subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X_full[:, idxs] # [HiddenDim, SampleSize]
else:
X = X_full
s = fns.max(fns.abs(X_full), axis=1) # [HiddenDim]
return s, X
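
A small usage sketch of the helper above, assuming NumPy-backed NNCF tensors; the shapes and values are invented for illustration:

```python
import numpy as np

from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.tensor import Tensor

# three activation samples with different sequence lengths and a hidden dimension of 8
stats = [Tensor(np.random.rand(seq_len, 8).astype(np.float32)) for seq_len in (4, 6, 5)]

s, X = process_stats(stats, subset_size=128)
print(s.shape)  # (8,)    maximum magnitude per hidden channel
print(X.shape)  # (8, 3)  per-sample mean magnitudes, one column per sample
```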
17 changes: 16 additions & 1 deletion nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -29,10 +29,12 @@
from nncf.parameters import CompressWeightsMode
from nncf.parameters import SensitivityMetric
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters
from nncf.quantization.advanced_parameters import convert_to_dict_recursively
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.awq import AWQ
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.gptq import GPTQ
from nncf.quantization.algorithms.weight_compression.lora_correction import LoraCorrectionAlgorithm
from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA
from nncf.quantization.algorithms.weight_compression.scale_estimation import ScaleEstimation
from nncf.quantization.algorithms.weight_compression.weight_lowering import WeightCompressionConfig
@@ -65,6 +67,7 @@ def __init__(
subset_size: int,
scale_estimation: bool,
gptq: bool,
lora_correction: bool,
advanced_parameters: Optional[AdvancedCompressionParameters] = None,
):
"""
@@ -97,6 +100,7 @@
quantization precision.
:param scale_estimation: determines whether to use or not scale estimation for 4 bit layers.
:param gptq: determines whether to use or not GPTQ algorithm.
:param lora_correction: determines whether to use or not LoRA Correction algorithm.
:param advanced_parameters: advanced parameters for algorithms in compression pipeline.
"""
super().__init__()
@@ -113,6 +117,7 @@ def __init__(
self._subset_size = subset_size
self._scale_estimation = scale_estimation
self._gptq = gptq
self._lora_correction = lora_correction
self._advanced_parameters = (
advanced_parameters if advanced_parameters is not None else AdvancedCompressionParameters()
)
@@ -403,6 +408,13 @@ def apply(
backend_entity=self._backend_entity,
)

lora_correction_algo = None
description = "Applying Weight Compression"
if self._lora_correction:
lora_correction_params = self._advanced_parameters.lora_correction_params
lora_correction_algo = LoraCorrectionAlgorithm(activations, lora_correction_params)
description += " with correction of low-rank adapters"

# Sort weight params to start compression with the bigger constants. This lowers peak memory footprint.
all_weight_params = sorted(all_weight_params, key=lambda wp: wp.num_weights, reverse=True)
all_weight_sizes = [wp.num_weights for wp in all_weight_params]
@@ -411,9 +423,10 @@
transformed_model = self._backend_entity.transform_model(
model,
graph,
track(all_weight_params, description="Applying Weight Compression", weights=all_weight_sizes),
track(all_weight_params, description=description, weights=all_weight_sizes),
scales,
zero_points,
lora_correction_algo,
)

self._backend_entity.dump_parameters(
@@ -428,6 +441,8 @@
"awq": self._awq,
"scale_estimation": self._scale_estimation,
"gptq": self._gptq,
"lora_correction": self._lora_correction,
"advanced_parameters": convert_to_dict_recursively(self._advanced_parameters),
},
algo_name="weight_compression",
)
24 changes: 6 additions & 18 deletions nncf/quantization/algorithms/weight_compression/awq.py
@@ -25,9 +25,10 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
from nncf.quantization.passes import transform_to_inference_graph
from nncf.tensor import functions as fns

@@ -101,9 +102,6 @@ def _set_backend_entity(self, model: TModel) -> None:
Creates a helper class with a backend-specific logic of the algorithm.
:param model: Backend-specific input model.
:param all_weight_params: List of all weight parameters.
:param nodes_to_compress: List of nodes for processing.
:param activations: The input activations of the layers considered for compression.
"""

model_backend = get_backend(model)
@@ -197,17 +195,7 @@

config = wp.compression_config

stats = self._activations[k]
X = fns.stack([fns.mean(stat, axis=0) for stat in stats])
X = fns.transpose(X)

s = fns.max(fns.abs(X), axis=1)

if X.shape[1] > self._subset_size:
lens = [stat.shape[0] for stat in stats]
step = X.shape[1] // self._subset_size
idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step]
X = X[:, idxs]
s, X = process_stats(self._activations[k], self._subset_size)

top_k = max(int(s.shape[0] * self._percent_to_apply), 1)
topk_idxs = fns.argsort(-s)[:top_k]
@@ -257,10 +245,10 @@ def apply(
for _ in range(self._steps):
cur_scale = gscale**alpha

g_compressed_weighs, g_c_scale, g_c_zp = do_integer_quantization(
g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
gweight * cur_scale, reduction_axis, awq_config
)
g_decompressed_weighs = do_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
sacts = gacts / fns.unsqueeze(cur_scale, 1)

cur_out = fns.matmul(g_decompressed_weighs, sacts)
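
A simplified NumPy sketch of the per-channel scale search in the loop above; the int4 round-trip below is only a stand-in approximation of `do_int_quantization`/`do_int_dequantization`, and all shapes are invented for illustration:

```python
import numpy as np


def fake_int4_roundtrip(w: np.ndarray) -> np.ndarray:
    # symmetric per-output-channel int4 quantize/dequantize (approximation only)
    scale = np.max(np.abs(w), axis=1, keepdims=True) / 7.0
    return np.clip(np.round(w / scale), -8, 7) * scale


rng = np.random.default_rng(0)
gweight = rng.standard_normal((16, 32))  # [OutCh, HiddenDim]
gacts = rng.standard_normal((32, 8))     # [HiddenDim, Samples]
gscale = np.max(np.abs(gacts), axis=1)   # per-channel activation magnitude
fp_out = gweight @ gacts                 # reference full-precision output

best_err, best_scale = np.inf, None
for alpha in np.linspace(0.0, 1.0, 8):
    cur_scale = gscale**alpha
    deq = fake_int4_roundtrip(gweight * cur_scale)  # scale weights up, then fake-quantize
    sacts = gacts / cur_scale[:, None]              # scale activations down to compensate
    err = np.mean((fp_out - deq @ sacts) ** 2)
    if err < best_err:
        best_err, best_scale = err, cur_scale
```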
28 changes: 28 additions & 0 deletions nncf/quantization/algorithms/weight_compression/backend.py
@@ -127,6 +127,34 @@ def transform_model(
:return: The transformed model.
"""

@abstractmethod
def insert_adapters(
self, wc_params: WeightCompressionParameters, lora_A: Tensor, lora_B: Tensor, int8_lora: bool
) -> None:
"""
Expands a model's execution graph following the Low-Rank Adaptation (LoRA) concept.
It inserts two additional Linear layers with weight matrices of low rank that are executed in parallel to the
target Linear layer.
Before insertion:

----INPUT
        \
       orig.MM--------------------------------OUTPUT

After insertion:

----INPUT ----lora_A.MM----lora_B.MM----\
        \                                add----OUTPUT
       orig.MM--------------------------/
:param wc_params: Parameters for weight compression.
:param lora_A: weights for the first LoRA matrix.
:param lora_B: weights for the second LoRA matrix.
:param int8_lora: indicates whether the LoRA matrices should be compressed to 8-bit.
"""

@staticmethod
@abstractmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
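
To make the `insert_adapters` diagram concrete, a tiny NumPy sketch of the parallel low-rank path it describes; this illustrates the math only and is not a backend implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, out_dim, rank = 32, 16, 8

x = rng.standard_normal(hidden_dim)                   # INPUT
w_orig = rng.standard_normal((out_dim, hidden_dim))   # weight of orig.MM (stand-in for the compressed weight)
lora_A = rng.standard_normal((rank, hidden_dim))      # first low-rank matrix
lora_B = rng.standard_normal((out_dim, rank))         # second low-rank matrix

# OUTPUT = orig.MM(x) + lora_B.MM(lora_A.MM(x))
y = w_orig @ x + lora_B @ (lora_A @ x)
print(y.shape)  # (16,)
```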
14 changes: 8 additions & 6 deletions nncf/quantization/algorithms/weight_compression/gptq.py
@@ -27,10 +27,10 @@
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_integer_quantization_params
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import decompress_nf4_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
from nncf.tensor import Tensor
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorDataType
@@ -266,13 +266,15 @@ def _quantize_weights(
scales.append(scale)
zero_points.append(zero_point)
if block_compression_config.mode == CompressWeightsMode.NF4:
compressed_weights = calculate_nf4_weight(fns.unsqueeze(weight_col, 1), scales[-1])
quantized_col = decompress_nf4_weight(compressed_weights, scales[-1])
compressed_weights = do_nf4_quantization(
fns.unsqueeze(weight_col, 1), scales[-1], is_normalized_weight=False
)
quantized_col = do_nf4_dequantization(compressed_weights, scales[-1], reduction_axis=-1)
else:
compressed_weights = calculate_quantized_weight(
fns.unsqueeze(weight_col, 1), block_compression_config, scales[-1], zero_points[-1]
)
quantized_col = do_dequantization(compressed_weights, scales[-1], zero_points[-1])
quantized_col = do_int_dequantization(compressed_weights, scales[-1], zero_points[-1])
quantized_col = fns.flatten(quantized_col)
quantized_block[:, i] = quantized_col
loss_block[:, i] = (weight_col - quantized_col) ** 2 / hessian_diag_val**2

