Initial weights compression for PyTorch models to 8 bit. (#1941)
### Changes

Added a data-free int8 weight compression implementation for PyTorch models.
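
A minimal usage sketch of the new API (the toy model below is illustrative and not part of this change; `nncf.compress_weights` is the entry point added here):

```python
import torch

import nncf


class TinyLM(torch.nn.Module):
    """Toy model with the layer types targeted by the compression (Embedding and Linear)."""

    def __init__(self):
        super().__init__()
        self.wte = torch.nn.Embedding(10, 5)
        self.lm_head = torch.nn.Linear(5, 10)

    def forward(self, input_ids):
        return self.lm_head(self.wte(input_ids))


model = TinyLM()

# Data-free compression: no dataset or calibration loop is needed. Weights of
# Linear/Embedding layers are stored as 8-bit integers and a dequantization
# pre-operation is attached to each compressed layer.
compressed_model = nncf.compress_weights(model)

# Per this implementation, compressed weights are kept as torch.uint8.
print(compressed_model.wte.weight.dtype)
```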

### Reason for changes

Accelerates compression and the subsequent conversion to IR.

### Related tickets

### Tests

Added `tests/torch/ptq/test_weights_compression.py` covering both compression modes.
---------

Co-authored-by: Nikita Malinin <[email protected]>
andreyanufr and KodiaqQ authored Jul 24, 2023
1 parent b45c6a5 commit a17ec33
Showing 9 changed files with 235 additions and 1 deletion.
1 change: 1 addition & 0 deletions nncf/__init__.py
@@ -21,6 +21,7 @@
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization import QuantizationPreset
from nncf.quantization import compress_weights
from nncf.quantization import quantize
from nncf.quantization import quantize_with_accuracy_control
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
1 change: 1 addition & 0 deletions nncf/quantization/__init__.py
@@ -10,5 +10,6 @@
# limitations under the License.
"""Post-training quantization APIs."""
from nncf.common.quantization.structs import QuantizationPreset
from nncf.quantization.quantize_model import compress_weights
from nncf.quantization.quantize_model import quantize
from nncf.quantization.quantize_model import quantize_with_accuracy_control
21 changes: 21 additions & 0 deletions nncf/quantization/quantize_model.py
@@ -217,3 +217,24 @@ def quantize_with_accuracy_control(
)

raise RuntimeError(f"Unsupported type of backend: {backend}")


@api(canonical_alias="nncf.compress_weights")
def compress_weights(model: TModel, use_fake_quantize: bool = False) -> TModel:
"""
Compress model weights.
:param model: A model to be compressed.
:param use_fake_quantize: Disables real compression of weights in Linear and Embedding layers.
If True, inserts fake quantization operations;
otherwise, compresses weights to int8 and inserts custom dequantization operations.
:return: The model with compressed weights and dequantization operations, or the model with
the original weights and fake quantization operations. The returned model is not trainable.
"""
backend = get_backend(model)
if backend == BackendType.TORCH:
import nncf.torch

return nncf.torch.compress_weights(model, use_fake_quantize)

raise RuntimeError(f"Unsupported type of backend: {backend}")
1 change: 1 addition & 0 deletions nncf/torch/__init__.py
@@ -57,6 +57,7 @@
from nncf.torch.dynamic_graph.context import no_nncf_trace
from nncf.torch.dynamic_graph.context import forward_nncf_trace
from nncf.torch.strip import strip
from nncf.torch.quantization.quantize_model import compress_weights

# NNCF relies on tracing PyTorch operations. Any code that uses NNCF
# should be executed with PyTorch operators wrapped via a call to "patch_torch_operators",
9 changes: 8 additions & 1 deletion nncf/torch/layers.py
@@ -349,6 +349,7 @@ def from_module(module):

class NNCFEmbedding(_NNCFModuleMixin, nn.Embedding):
op_func_name = "embedding"
target_weight_dim_for_compression = 1

# Note that this does not require activation quantization because it's basically a lookup.
@staticmethod
@@ -449,7 +450,9 @@ def from_module(module):


@api(canonical_alias="nncf.torch.register_module")
def register_module(*quantizable_field_names: str, ignored_algorithms: list = None):
def register_module(
*quantizable_field_names: str, ignored_algorithms: list = None, target_weight_dim_for_compression: int = 0
):
# quantizable_field_names will work for `weight` attributes only. Should later be extended to registering
# custom-named attributes if it becomes necessary.
def wrap(cls):
@@ -462,6 +465,10 @@ def wrap(cls):
setattr(NNCF_WRAPPED_USER_MODULES_DICT[cls], "get_weight_shape", get_base_attributes_fn)
if ignored_algorithms:
setattr(NNCF_WRAPPED_USER_MODULES_DICT[cls], "ignored_algorithms", ignored_algorithms)

setattr(
NNCF_WRAPPED_USER_MODULES_DICT[cls], "target_weight_dim_for_compression", target_weight_dim_for_compression
)
return cls

return wrap
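
For illustration, a user-defined module could opt into the new parameter as sketched below. `ScaledEmbedding` is hypothetical and not part of the PR; since it derives from `torch.nn.Embedding`, the weight compression pass in `weights_compression.py` would also pick it up, and its statistics would be taken along the same dimension as for the built-in `NNCFEmbedding`:

```python
import torch

from nncf.torch.layers import register_module


# Hypothetical user layer: an Embedding variant that scales its output.
# target_weight_dim_for_compression=1 matches the layout of Embedding weights,
# the same value NNCFEmbedding declares above.
@register_module(target_weight_dim_for_compression=1)
class ScaledEmbedding(torch.nn.Embedding):
    def forward(self, input_ids):
        return 0.5 * super().forward(input_ids)
```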
1 change: 1 addition & 0 deletions nncf/torch/quantization/__init__.py
@@ -14,3 +14,4 @@

# Required for correct QUANTIZATION_MODULES registry functioning
from . import layers
from . import weights_compression
12 changes: 12 additions & 0 deletions nncf/torch/quantization/quantize_model.py
@@ -32,6 +32,8 @@
from nncf.torch.initialization import PTInitializingDataLoader
from nncf.torch.model_creation import create_compressed_model
from nncf.torch.nested_objects_traversal import objwalk
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.weights_compression import insert_pre_compression_operations
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_tensor

@@ -256,3 +258,13 @@ def send_to_device(tensor):
compressed_model.nncf.disable_dynamic_graph_building()

return compressed_model


def compress_weights(model: torch.nn.Module, use_fake_quantize: bool = False) -> torch.nn.Module:
"""
Implementation of the `compress_weights()` method for the PyTorch backend.
"""
compressed_model, _ = replace_modules_by_nncf_modules(model)
insert_pre_compression_operations(compressed_model, use_fake_quantize)

return compressed_model
125 changes: 125 additions & 0 deletions nncf/torch/quantization/weights_compression.py
@@ -0,0 +1,125 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from torch import nn

from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT
from nncf.torch.layers import NNCFEmbedding
from nncf.torch.layers import NNCFLinear
from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high


class WeightsDecompressor(nn.Module):
"""Applies decompression of compressed weights in forward pass
Attributes:
zero_point: zero point in quantization scheme
scale: scale in quantizatin scheme
"""

def __init__(self, zero_point, scale):
super().__init__()
self.zero_point = zero_point
self.scale = scale

def forward(self, layer, op_arg):
w = layer.weight.type(dtype=self.scale.dtype)
layer.weight = (w - self.zero_point) * self.scale


class WeightsFQ(nn.Module):
"""Replaces weights with Torch's FakeQuantize operation on forward pass
Attributes:
zero_point: zero point in quantization scheme
scale: scale in quantizatin scheme
axis: channel for quantization
level_high: maximal quant value in assymetric quantization
"""

def __init__(self, zero_point, scale, axis=0, level_high=255):
super().__init__()
self.zero_point = zero_point
self.scale = scale
self.axis = axis
self.level_high = level_high

def forward(self, layer, op_arg):
layer.weight = torch.fake_quantize_per_channel_affine(
layer.weight, self.scale, self.zero_point, self.axis, 0, self.level_high
)


def _insert_pre_compression_operations(
module: nn.Module, allowed_types: List, use_fake_quantize=False, level_high=255
) -> Optional[nn.Module]:
"""
Inserts a weight compression pre-operation (custom dequantization or fake quantization) for Linear and Embedding layers.

:param module: The module to insert the weight compression into.
:param allowed_types: List of module types allowed for weight compression.
:param use_fake_quantize: Disables real compression of weights in Linear and Embedding layers.
If True, inserts torch.fake_quantize_per_channel_affine();
otherwise, compresses weights to int8 and inserts custom dequantization.
:param level_high: The highest possible value of compressed weights (the lowest is 0 in asymmetric quantization).
:return: The module with inserted operations. The module is not trainable if use_fake_quantize is False.
"""
for _, layer in module.named_children():
if type(layer) not in allowed_types:
_insert_pre_compression_operations(layer, allowed_types, use_fake_quantize, level_high)
continue
target_dim = layer.target_weight_dim_for_compression
stat_dim = (target_dim + 1) % 2
input_low = torch.min(layer.weight, dim=stat_dim)[0].detach()
input_high = torch.max(layer.weight, dim=stat_dim)[0].detach()
scale, zero_point = get_scale_zp_from_input_low_input_high(0, level_high, input_low, input_high)

if not use_fake_quantize:
scale = scale.unsqueeze(stat_dim)
zero_point = zero_point.unsqueeze(stat_dim)
layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))

compressed_weight = layer.weight.data / scale + zero_point
compressed_weight = torch.clamp(torch.round(compressed_weight), 0, level_high)

layer.weight.requires_grad = False
layer.weight.data = compressed_weight.type(dtype=torch.uint8)
else:
zero_point = zero_point.type(dtype=torch.int32)
layer.register_pre_forward_operation(WeightsFQ(zero_point, scale, target_dim))


def insert_pre_compression_operations(module: nn.Module, use_fake_quantize=False, bits=8) -> Optional[nn.Module]:
"""
Inserts a weight compression pre-operation (custom dequantization or fake quantization) for Linear and Embedding layers.

:param module: The module to insert the weight compression into.
:param use_fake_quantize: Disables real compression of weights in Linear and Embedding layers.
If True, inserts torch.fake_quantize_per_channel_affine();
otherwise, compresses weights to int8 and inserts custom dequantization.
:param bits: The number of bits for compression/quantization. Note: the compressed weights are stored
as uint8, one value per 8 bits (no sub-byte packing).
:return: The module with inserted operations. The module is not trainable if use_fake_quantize is False.
"""
user_types = list(NNCF_WRAPPED_USER_MODULES_DICT.values())
allowed_types = [NNCFEmbedding, NNCFLinear]
level_high = 2**bits - 1

assert level_high < 256

for user_type in user_types:
if torch.nn.Embedding in user_type.__mro__:
allowed_types.append(user_type)

_insert_pre_compression_operations(module, allowed_types, use_fake_quantize, level_high)
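
To make the arithmetic above easier to follow, here is a self-contained sketch of the per-channel asymmetric round trip that `_insert_pre_compression_operations` and `WeightsDecompressor` implement together. The scale/zero-point derivation is a simplified stand-in for `get_scale_zp_from_input_low_input_high()`; its exact rounding is not reproduced here:

```python
import torch

# A Linear-like weight: target_weight_dim_for_compression = 0, so statistics are
# taken over stat_dim = (0 + 1) % 2 = 1, i.e. per output channel.
weight = torch.randn(16, 32)
stat_dim = 1
level_high = 255  # 2**8 - 1

input_low = weight.min(dim=stat_dim)[0]
input_high = weight.max(dim=stat_dim)[0]

# Simplified asymmetric quantization parameters (stand-in for the NNCF helper).
scale = (input_high - input_low) / level_high
zero_point = torch.round(-input_low / scale)

scale = scale.unsqueeze(stat_dim)
zero_point = zero_point.unsqueeze(stat_dim)

# Compression: what insert_pre_compression_operations stores in layer.weight ...
compressed = torch.clamp(torch.round(weight / scale + zero_point), 0, level_high).to(torch.uint8)

# ... and decompression: what WeightsDecompressor applies in the forward pass.
restored = (compressed.float() - zero_point) * scale
print((restored - weight).abs().max())  # small quantization error
```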
65 changes: 65 additions & 0 deletions tests/torch/ptq/test_weights_compression.py
@@ -0,0 +1,65 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nncf.quantization import compress_weights


class ShortTransformer(torch.nn.Module):
def __init__(self, in_features, num_embeddings):
super().__init__()
self.wte = torch.nn.Embedding(num_embeddings, in_features)
self.linear = torch.nn.Linear(in_features, in_features)
self.lm_head = torch.nn.Linear(in_features, num_embeddings)

def forward(self, input_ids):
x = self.wte(input_ids)
x = self.linear(x)
res = self.lm_head(x)
return res


def test_compress_weights():
model = ShortTransformer(5, 10)

compressed_model = compress_weights(model, use_fake_quantize=False)

n_compressed_weights = 0
n_target_modules = 0

for _, module in compressed_model.named_children():
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
n_target_modules += 1
if module.weight.dtype in [torch.uint8, torch.int8]:
n_compressed_weights += 1

assert n_compressed_weights == n_target_modules


def test_compress_weights_with_fake_quantize(mocker):
model = ShortTransformer(5, 10)

compressed_model = compress_weights(model, use_fake_quantize=True)

n_target_modules = 0

spy = mocker.spy(torch, "fake_quantize_per_channel_affine")

for _, module in compressed_model.named_children():
if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
n_target_modules += 1

model(torch.randint(0, 2, (1, 4)))

n_fake_quantize = spy.call_count

assert n_fake_quantize == n_target_modules
