diff --git a/nncf/torch/quantization/weights_compression.py b/nncf/torch/quantization/weights_compression.py
index 2d191333beb..9fc725fb235 100644
--- a/nncf/torch/quantization/weights_compression.py
+++ b/nncf/torch/quantization/weights_compression.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import torch
 from torch import nn
@@ -39,7 +39,7 @@ def forward(self, layer, op_arg):
 
 
 def _insert_pre_compression_operations(
-    module: nn.Module, allowed_types: List, level_high: int = 255
+    module: nn.Module, allowed_types: List, level_high: int = 255, compression_hist: Dict = None
 ) -> Optional[nn.Module]:
     """
     Inserts weights compression with dequantization for layers in `allowed_types`.
@@ -47,12 +47,22 @@ def _insert_pre_compression_operations(
     :param module: The module to insert the weights compression.
     :param allowed_types: list of allowed types for weights compression.
     :param level_high: highest possible value of compressed weights (lower is 0 in assymetric quantization).
+    :param compression_hist: mapping between layer weight and corresponding WeightsDecompressor for finding
+        shared weights.
     :return: The non-trainable module with inserted operations.
     """
+    if compression_hist is None:
+        compression_hist = {}
     for _, layer in module.named_children():
         if not type(layer) in allowed_types:
-            _insert_pre_compression_operations(layer, allowed_types, level_high)
+            _insert_pre_compression_operations(layer, allowed_types, level_high, compression_hist)
             continue
+
+        if layer.weight.dtype in [torch.uint8, torch.int8]:
+            if layer.weight in compression_hist:
+                layer.register_pre_forward_operation(compression_hist[layer.weight])
+            continue
+
         target_dim = layer.target_weight_dim_for_compression
         stat_dim = (target_dim + 1) % 2
         input_low = torch.min(layer.weight, dim=stat_dim).values.detach()
@@ -61,7 +71,7 @@ def _insert_pre_compression_operations(
         scale = scale.unsqueeze(stat_dim)
         zero_point = zero_point.unsqueeze(stat_dim)
 
-        layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))
+        key = layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))
 
         compressed_weight = layer.weight.data / scale + zero_point
         compressed_weight = torch.clamp(torch.round(compressed_weight), 0, level_high)
@@ -69,6 +79,8 @@ def _insert_pre_compression_operations(
         layer.weight.requires_grad = False
         layer.weight.data = compressed_weight.type(dtype=torch.uint8)
 
+        compression_hist[layer.weight] = layer.get_pre_op(key)
+
 
 def insert_pre_compression_operations(module: nn.Module, bits: int = 8) -> Optional[nn.Module]:
     """
diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py
index 72427191abe..e71394e9284 100644
--- a/tests/torch/ptq/test_weights_compression.py
+++ b/tests/torch/ptq/test_weights_compression.py
@@ -15,12 +15,15 @@
 
 
 class ShortTransformer(torch.nn.Module):
-    def __init__(self, in_features, num_embeddings):
+    def __init__(self, in_features, num_embeddings, share_weights=False):
         super().__init__()
         self.wte = torch.nn.Embedding(num_embeddings, in_features)
         self.linear = torch.nn.Linear(in_features, in_features)
         self.lm_head = torch.nn.Linear(in_features, num_embeddings)
 
+        if share_weights:
+            self.lm_head.weight = self.wte.weight
+
     def forward(self, input_ids):
         x = self.wte(input_ids)
         x = self.linear(x)
@@ -43,3 +46,27 @@ def test_compress_weights():
                 n_compressed_weights += 1
 
     assert n_compressed_weights == n_target_modules
+
+
+def test_compress_shared_weights():
+    model = ShortTransformer(5, 10, share_weights=True)
+
+    compressed_model = compress_weights(model)
+
+    n_compressed_weights = 0
+    n_target_modules = 0
+
+    for _, module in compressed_model.named_children():
+        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
+            n_target_modules += 1
+            if module.weight.dtype in [torch.uint8, torch.int8]:
+                n_compressed_weights += 1
+
+    assert n_compressed_weights == n_target_modules
+
+    assert len(compressed_model.wte.pre_ops) > 0
+
+    assert len(compressed_model.wte.pre_ops) == len(compressed_model.lm_head.pre_ops)
+
+    for key, val in compressed_model.wte.pre_ops.items():
+        assert compressed_model.lm_head.get_pre_op(key) is val
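
Note on the tied-weights case this patch handles (a minimal plain-PyTorch sketch, not part of the diff; module names mirror the test's ShortTransformer): when lm_head.weight is assigned from wte.weight, both modules hold the same Parameter, so a per-module compression pass visits that tensor twice. After the first visit the weight is already torch.uint8, which the new dtype check detects so the second visit only re-registers the WeightsDecompressor stored in compression_hist instead of quantizing an already-quantized tensor.

# Illustrative sketch only; uses plain PyTorch, no NNCF calls.
import torch

wte = torch.nn.Embedding(10, 5)
lm_head = torch.nn.Linear(5, 10)
lm_head.weight = wte.weight  # weight tying: both modules hold the same Parameter

assert lm_head.weight is wte.weight
# A naive per-module pass would compress this Parameter twice. After the first
# visit its dtype is torch.uint8, which the patch's
# `layer.weight.dtype in [torch.uint8, torch.int8]` check catches, so the second
# visit only attaches the decompressor already recorded in compression_hist.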