Initial weights compression for PyTorch models to 8 bit. (#1941)
### Changes
Added a data-free int8 weight compression implementation for PyTorch models.

### Reason for changes
Acceleration of compression and of the subsequent conversion to IR.

### Related tickets

### Tests

---------

Co-authored-by: Nikita Malinin <[email protected]>
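A minimal usage sketch (not part of the diff), based on the tests added in this commit: the new entry point is `compress_weights` from `nncf.quantization`. The toy model is a stand-in, and `use_fake_quantize=False` is written out explicitly because the parameter's default is not visible in this commit:

```python
import torch

from nncf.quantization import compress_weights  # entry point introduced by this change


# Any torch.nn.Module containing Linear/Embedding layers qualifies; this toy
# model is a placeholder (the tests below use a similar "ShortTransformer").
model = torch.nn.Sequential(
    torch.nn.Embedding(10, 5),
    torch.nn.Linear(5, 5),
)

# Data-free: weights are compressed to int8 from their own min/max statistics;
# no calibration dataset is required.
compressed_model = compress_weights(model, use_fake_quantize=False)
```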
1 parent b45c6a5 · commit a17ec33 · Showing 9 changed files with 235 additions and 1 deletion.
@@ -0,0 +1,125 @@
```python
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional

import torch
from torch import nn

from nncf.torch.layers import NNCF_WRAPPED_USER_MODULES_DICT
from nncf.torch.layers import NNCFEmbedding
from nncf.torch.layers import NNCFLinear
from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high


class WeightsDecompressor(nn.Module):
    """Applies decompression of compressed weights in the forward pass.

    Attributes:
        zero_point: zero point in the quantization scheme
        scale: scale in the quantization scheme
    """

    def __init__(self, zero_point, scale):
        super().__init__()
        self.zero_point = zero_point
        self.scale = scale

    def forward(self, layer, op_arg):
        w = layer.weight.type(dtype=self.scale.dtype)
        layer.weight = (w - self.zero_point) * self.scale


class WeightsFQ(nn.Module):
    """Replaces weights with Torch's FakeQuantize operation on the forward pass.

    Attributes:
        zero_point: zero point in the quantization scheme
        scale: scale in the quantization scheme
        axis: channel axis for quantization
        level_high: maximal quant value in asymmetric quantization
    """

    def __init__(self, zero_point, scale, axis=0, level_high=255):
        super().__init__()
        self.zero_point = zero_point
        self.scale = scale
        self.axis = axis
        self.level_high = level_high

    def forward(self, layer, op_arg):
        layer.weight = torch.fake_quantize_per_channel_affine(
            layer.weight, self.scale, self.zero_point, self.axis, 0, self.level_high
        )


def _insert_pre_compression_operations(
    module: nn.Module, allowed_types: Dict, use_fake_quantize=False, level_high=255
) -> Optional[nn.Module]:
    """
    Inserts a weights-compression (dequantization) or quantization pre-operation for Linear and Embedding layers.

    :param module: The module to insert the weights compression into.
    :param allowed_types: List of allowed types for weights compression.
    :param use_fake_quantize: Disables real compression of weights in Linear and Embedding layers.
        If True, inserts torch.fake_quantize_per_channel_affine();
        otherwise, compresses weights to int8 and inserts a custom dequantization.
    :param level_high: Highest possible value of compressed weights (the lowest is 0 in asymmetric quantization).
    :return: The module with inserted operations. The module is not trainable if use_fake_quantize is False.
    """
    for _, layer in module.named_children():
        if type(layer) not in allowed_types:
            _insert_pre_compression_operations(layer, allowed_types, use_fake_quantize, level_high)
            continue
        target_dim = layer.target_weight_dim_for_compression
        stat_dim = (target_dim + 1) % 2
        input_low = torch.min(layer.weight, dim=stat_dim)[0].detach()
        input_high = torch.max(layer.weight, dim=stat_dim)[0].detach()
        scale, zero_point = get_scale_zp_from_input_low_input_high(0, level_high, input_low, input_high)

        if not use_fake_quantize:
            scale = scale.unsqueeze(stat_dim)
            zero_point = zero_point.unsqueeze(stat_dim)
            layer.register_pre_forward_operation(WeightsDecompressor(zero_point, scale))

            compressed_weight = layer.weight.data / scale + zero_point
            compressed_weight = torch.clamp(torch.round(compressed_weight), 0, level_high)

            layer.weight.requires_grad = False
            layer.weight.data = compressed_weight.type(dtype=torch.uint8)
        else:
            zero_point = zero_point.type(dtype=torch.int32)
            layer.register_pre_forward_operation(WeightsFQ(zero_point, scale, target_dim))


def insert_pre_compression_operations(module: nn.Module, use_fake_quantize=False, bits=8) -> Optional[nn.Module]:
    """
    Inserts a weights-compression (dequantization) or quantization pre-operation for Linear and Embedding layers.

    :param module: The module to insert the weights compression into.
    :param use_fake_quantize: Disables real compression of weights in Linear and Embedding layers.
        If True, inserts torch.fake_quantize_per_channel_affine();
        otherwise, compresses weights to int8 and inserts a custom dequantization.
    :param bits: Number of bits for compression/quantization. Note: the compressed weights dtype is
        uint8, one element per 8 bits.
    :return: The module with inserted operations. The module is not trainable if use_fake_quantize is False.
    """
    user_types = list(NNCF_WRAPPED_USER_MODULES_DICT.values())
    allowed_types = [NNCFEmbedding, NNCFLinear]
    level_high = 2**bits - 1

    assert level_high < 256

    for user_type in user_types:
        if torch.nn.Embedding in user_type.__mro__:
            allowed_types.append(user_type)

    _insert_pre_compression_operations(module, allowed_types, use_fake_quantize, level_high)
```
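To make the compression math above concrete, here is a self-contained round-trip of the scheme `_insert_pre_compression_operations` applies: per-channel asymmetric quantization to uint8, followed by the dequantization that `WeightsDecompressor` performs in its pre-forward hook. The scale/zero-point formulas are a plain local reimplementation assumed to match NNCF's `get_scale_zp_from_input_low_input_high`, not a call into it:

```python
import torch

level_high = 255  # 2**bits - 1 for bits=8, as in insert_pre_compression_operations
weight = torch.randn(16, 32)  # stand-in for a Linear weight; channel dim is 0
stat_dim = 1  # statistics are reduced over the non-channel dimension

input_low = torch.min(weight, dim=stat_dim)[0]
input_high = torch.max(weight, dim=stat_dim)[0]

# Assumed equivalent of get_scale_zp_from_input_low_input_high for the
# asymmetric range [0, level_high]:
scale = (input_high - input_low) / level_high
zero_point = torch.clamp(torch.round(-input_low / scale), 0, level_high)

scale = scale.unsqueeze(stat_dim)
zero_point = zero_point.unsqueeze(stat_dim)

# Compression step, as in the diff: scale, shift, round, clamp, store as uint8.
compressed = torch.clamp(torch.round(weight / scale + zero_point), 0, level_high).to(torch.uint8)

# Decompression step, as in WeightsDecompressor.forward: cast back and rescale.
restored = (compressed.to(scale.dtype) - zero_point) * scale

# Rounding error is at most half a quantization step per channel.
assert torch.max(torch.abs(weight - restored)) <= torch.max(scale) * 0.5 + 1e-6
```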
@@ -0,0 +1,65 @@
```python
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from nncf.quantization import compress_weights


class ShortTransformer(torch.nn.Module):
    def __init__(self, in_features, num_embeddings):
        super().__init__()
        self.wte = torch.nn.Embedding(num_embeddings, in_features)
        self.linear = torch.nn.Linear(in_features, in_features)
        self.lm_head = torch.nn.Linear(in_features, num_embeddings)

    def forward(self, input_ids):
        x = self.wte(input_ids)
        x = self.linear(x)
        res = self.lm_head(x)
        return res


def test_compress_weights():
    model = ShortTransformer(5, 10)

    compressed_model = compress_weights(model, use_fake_quantize=False)

    n_compressed_weights = 0
    n_target_modules = 0

    for _, module in compressed_model.named_children():
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            n_target_modules += 1
            if module.weight.dtype in [torch.uint8, torch.int8]:
                n_compressed_weights += 1

    assert n_compressed_weights == n_target_modules


def test_compress_weights_with_fake_quantize(mocker):
    model = ShortTransformer(5, 10)

    compressed_model = compress_weights(model, use_fake_quantize=True)

    n_target_modules = 0

    spy = mocker.spy(torch, "fake_quantize_per_channel_affine")

    for _, module in compressed_model.named_children():
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            n_target_modules += 1

    model(torch.randint(0, 2, (1, 4)))

    n_fake_quantize = spy.call_count

    assert n_fake_quantize == n_target_modules
```
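As a companion to the second test, a small sketch of what the fake-quantize path does to a weight tensor: the values stay float (and trainable), but every forward pass sees them quantized and dequantized by `torch.fake_quantize_per_channel_affine`. The per-channel parameters below are computed locally and are assumed to match what `WeightsFQ` receives from the compression pass:

```python
import torch

w = torch.randn(4, 8, requires_grad=True)  # toy per-channel weight, channels on dim 0

# Per-channel asymmetric parameters over the 8-bit range [0, 255]
# (assumed equivalent to what the compression pass would derive).
w_min = w.min(dim=1)[0].detach()
w_max = w.max(dim=1)[0].detach()
scale = (w_max - w_min) / 255
zero_point = torch.clamp(torch.round(-w_min / scale), 0, 255).to(torch.int32)

# Quantize->dequantize in one differentiable op; w itself is left untouched,
# which is why the weights remain trainable on this path.
w_fq = torch.fake_quantize_per_channel_affine(w, scale, zero_point, 0, 0, 255)
print(torch.max(torch.abs(w - w_fq)))  # small quantization noise
```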