Fixed problem with shared weights in compression. (#2110)
### Changes

Fixed handling of shared weights during weight compression: a weight tensor shared between layers is now compressed once, and every layer that shares it reuses the same `WeightsDecompressor` pre-op (tracked via the new `compression_hist` mapping).

### Reason for changes

Some LLMs tie weights between layers (e.g. the token embedding and the LM head share one weight tensor), and weight compression previously mishandled such models.
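
As an illustration, weight tying means two modules hold the very same `nn.Parameter`, so whatever compression does to one layer's weight is also seen by the other. A minimal sketch with a hypothetical toy model (mirroring the `ShortTransformer` used in the new test below):

```python
import torch
from torch import nn


class TiedToyLM(nn.Module):
    # Hypothetical toy model with weight tying, as many LLMs use it:
    # the LM head reuses the embedding matrix instead of owning its own copy.
    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.lm_head.weight = self.wte.weight  # same Parameter object

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.lm_head(self.wte(input_ids))


model = TiedToyLM(vocab_size=10, hidden_size=5)
assert model.lm_head.weight is model.wte.weight  # one tensor, two modules
```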

### Related tickets


### Tests

Added `test_compress_shared_weights`, which compresses a `ShortTransformer` with tied embedding/LM-head weights and checks that both modules end up with the same shared pre-ops.
andreyanufr authored Sep 11, 2023
1 parent 9d8ed96 commit 25968cd
Showing 2 changed files with 44 additions and 5 deletions.
20 changes: 16 additions & 4 deletions nncf/torch/quantization/weights_compression.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import Dict, List, Optional

import torch
from torch import nn
@@ -39,20 +39,30 @@ def forward(self, layer, op_arg):


def _insert_pre_compression_operations(
module: nn.Module, allowed_types: List, level_high: int = 255
module: nn.Module, allowed_types: List, level_high: int = 255, compression_hist: Dict = None
) -> Optional[nn.Module]:
"""
Inserts weights compression with dequantization for layers in `allowed_types`.
:param module: The module to insert the weights compression.
:param allowed_types: list of allowed types for weights compression.
:param level_high: highest possible value of compressed weights (lower is 0 in assymetric quantization).
:param compression_hist: mapping between layer weight and corresponding WeightsDecompressor for finding
shared weights.
:return: The non-trainable module with inserted operations.
"""
if compression_hist is None:
compression_hist = {}
for _, layer in module.named_children():
if not type(layer) in allowed_types:
_insert_pre_compression_operations(layer, allowed_types, level_high)
_insert_pre_compression_operations(layer, allowed_types, level_high, compression_hist)
continue

if layer.weight.dtype in [torch.uint8, torch.int8]:
if layer.weight in compression_hist:
layer.register_pre_forward_operation(compression_hist[layer.weight])
continue

target_dim = layer.target_weight_dim_for_compression
stat_dim = (target_dim + 1) % 2
input_low = torch.min(layer.weight, dim=stat_dim).values.detach()
@@ -61,14 +71,16 @@

scale = scale.unsqueeze(stat_dim)
zero_point = zero_point.unsqueeze(stat_dim)
layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))
key = layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))

compressed_weight = layer.weight.data / scale + zero_point
compressed_weight = torch.clamp(torch.round(compressed_weight), 0, level_high)

layer.weight.requires_grad = False
layer.weight.data = compressed_weight.type(dtype=torch.uint8)

compression_hist[layer.weight] = layer.get_pre_op(key)


def insert_pre_compression_operations(module: nn.Module, bits: int = 8) -> Optional[nn.Module]:
"""
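
For context (not part of the change itself), a minimal sketch of the bookkeeping that `compression_hist` relies on: the dictionary is keyed by the weight `Parameter`, and tensor hashing is identity-based, so only modules that truly tie the same tensor map to one shared decompression entry. The decompressor below is a hypothetical placeholder, not the real `WeightsDecompressor`:

```python
from torch import nn

# Two modules tie the same weight Parameter; a third has its own weight.
wte = nn.Embedding(10, 5)
lm_head = nn.Linear(5, 10, bias=False)
lm_head.weight = wte.weight           # tied: the very same Parameter object
other = nn.Linear(5, 10, bias=False)  # independent weight of the same shape

# Hypothetical placeholder standing in for the registered pre-op.
shared_decompressor = object()
compression_hist = {wte.weight: shared_decompressor}  # weight -> pre-op

# The tied weight is found, the merely same-shaped one is not, so only
# genuinely shared weights reuse an existing decompressor.
assert lm_head.weight in compression_hist
assert other.weight not in compression_hist
assert compression_hist[lm_head.weight] is shared_decompressor
```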
29 changes: 28 additions & 1 deletion tests/torch/ptq/test_weights_compression.py
@@ -15,12 +15,15 @@


class ShortTransformer(torch.nn.Module):
def __init__(self, in_features, num_embeddings):
def __init__(self, in_features, num_embeddings, share_weights=False):
super().__init__()
self.wte = torch.nn.Embedding(num_embeddings, in_features)
self.linear = torch.nn.Linear(in_features, in_features)
self.lm_head = torch.nn.Linear(in_features, num_embeddings)

if share_weights:
self.lm_head.weight = self.wte.weight

def forward(self, input_ids):
x = self.wte(input_ids)
x = self.linear(x)
@@ -43,3 +46,27 @@ def test_compress_weights():
n_compressed_weights += 1

assert n_compressed_weights == n_target_modules


def test_compress_shared_weights():
model = ShortTransformer(5, 10, share_weights=True)

compressed_model = compress_weights(model)

n_compressed_weights = 0
n_target_modules = 0

for _, module in compressed_model.named_children():
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
n_target_modules += 1
if module.weight.dtype in [torch.uint8, torch.int8]:
n_compressed_weights += 1

assert n_compressed_weights == n_target_modules

assert len(compressed_model.wte.pre_ops) > 0

assert len(compressed_model.wte.pre_ops) == len(compressed_model.lm_head.pre_ops)

for key, val in compressed_model.wte.pre_ops.items():
assert compressed_model.lm_head.get_pre_op(key) is val
