Skip to content

Commit

Permalink
Initial commit. Rebased.
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Jul 3, 2024
1 parent f5ad4ea commit 55cafaa
Show file tree
Hide file tree
Showing 6 changed files with 219 additions and 99 deletions.
6 changes: 6 additions & 0 deletions nncf/common/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import logging
import sys
from contextlib import contextmanager
from functools import lru_cache

NNCF_LOGGER_NAME = "nncf"

Expand Down Expand Up @@ -86,3 +87,8 @@ def warn_bkc_version_mismatch(backend: str, bkc_version: str, current_version: s
f"while current {backend} version is {current_version}. "
f"If you encounter issues, consider switching to {backend}{bkc_version}"
)


@lru_cache(None)
def log_once(level, message):
nncf_logger.log(level, message)
116 changes: 116 additions & 0 deletions nncf/openvino/quantization/compression_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import openvino as ov
from openvino.runtime import opset13 as opset

from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig


class OVCompressionPrimitiveCache:
def __init__(self):
self._compress_weight_model_cache = {}
self._compress_decompress_weight_model_cache = {}

def get_compress_weight_primitive(
self,
config: WeightCompressionConfig,
weight_shape: Tuple,
scale_shape: Tuple,
zero_point_shape: Optional[Tuple] = None,
):
key = (config.mode, config.num_bits, weight_shape, scale_shape)
if zero_point_shape is not None:
key += (zero_point_shape,)
if key not in self._compress_weight_model_cache:
self._compress_weight_model_cache[key] = self._build_compress_model(
config, weight_shape, scale_shape, zero_point_shape
)
return self._compress_weight_model_cache[key]

def get_compress_decompress_weight_primitive(
self,
config: WeightCompressionConfig,
weight_shape: Tuple,
scale_shape: Tuple,
zero_point_shape: Optional[Tuple] = None,
):
key = (config.mode, config.num_bits, weight_shape, scale_shape)
if zero_point_shape is not None:
key += (zero_point_shape,)
if key not in self._compress_decompress_weight_model_cache:
self._compress_decompress_weight_model_cache[key] = self._build_compress_decompress_model(
config, weight_shape, scale_shape, zero_point_shape
)
return self._compress_decompress_weight_model_cache[key]

@staticmethod
def _build_compress_model(
config: WeightCompressionConfig,
weight_shape: Tuple,
scale_shape: Tuple,
zero_point_shape: Optional[Tuple] = None,
return_nodes: bool = False,
):
w = opset.parameter(weight_shape, name="w")
s = opset.parameter(scale_shape, name="s")
parameters = [w, s]
compressed_w = w / s
num_bits = config.num_bits
if zero_point_shape is not None:
level_low = 0
level_high = 2**num_bits - 1

zp = opset.parameter(zero_point_shape, name="zp")
parameters.append(zp)
compressed_w += zp
else:
level_low = -(2 ** (num_bits - 1))
level_high = 2 ** (num_bits - 1) - 1

result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights")

if return_nodes:
return parameters, result

model = ov.Model([result], parameters)

compiled_model = ov.compile_model(model)

return lambda parameters: compiled_model(parameters)[0]

@staticmethod
def _build_compress_decompress_model(
config: WeightCompressionConfig,
weight_shape: Tuple,
scale_shape: Tuple,
zero_point_shape: Optional[Tuple] = None,
):
parameters, clamp = OVCompressionPrimitiveCache._build_compress_model(
config, weight_shape, scale_shape, zero_point_shape, return_nodes=True
)

if len(parameters) == 3:
_, s, zp = parameters
result = (clamp - zp) * s
else:
s = parameters[1]
result = clamp * s

model = ov.Model([result], parameters)
compiled_model = ov.compile_model(model)

return lambda parameters: compiled_model(parameters)[0]


OV_COMPRESSION_PRIMITIVE_CACHE = OVCompressionPrimitiveCache()
Original file line number Diff line number Diff line change
Expand Up @@ -221,59 +221,6 @@ def dump_parameters(
) -> None:
dump_parameters(model, parameters, algo_name, path)

@staticmethod
def get_compress_decompress_pipeline(
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None
):
parameters, clamp = OVWeightCompressionAlgoBackend.get_compress_pipeline(
weight_compression_parameter, w_shape, s_shape, z_p_shape, True
)

if len(parameters) == 3:
_, s, zp = parameters
result = (clamp - zp) * s
else:
s = parameters[1]
result = clamp * s

model = ov.Model([result], parameters)

compiled_model = ov.compile_model(model)

return lambda parameters: compiled_model(parameters)[0]

@staticmethod
def get_compress_pipeline(
weight_compression_parameter: WeightCompressionParameters, w_shape, s_shape, z_p_shape=None, return_nodes=False
):
config = weight_compression_parameter.compression_config
mode = config.mode
assert mode in [CompressWeightsMode.INT4_SYM, CompressWeightsMode.INT4_ASYM]
num_bits = config.num_bits

level_low = 0
level_high = 2**num_bits - 1

w = opset.parameter(w_shape, name="w")
s = opset.parameter(s_shape, name="s")
parameters = [w, s]
compressed_w = w / s
if z_p_shape is not None:
zp = opset.parameter(z_p_shape, name="zp")
parameters.append(zp)
compressed_w += zp

result = opset.clamp(opset.round(compressed_w), level_low, level_high, name="compressed_weights")

if return_nodes:
return parameters, result

model = ov.Model([result], parameters)

compiled_model = ov.compile_model(model)

return lambda parameters: compiled_model(parameters)[0]


class OVAWQAlgoAlgoBackend(AWQAlgoBackend, OVWeightCompressionAlgoBackend):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_dequantized_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_weight
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_dequantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import do_integer_quantization
from nncf.quantization.algorithms.weight_compression.weight_lowering import reshape_weight_for_grouped_quantization
Expand Down Expand Up @@ -117,7 +119,6 @@ def apply(
:return: Dict with pairs (weight name, estimated scale).
"""

compress_decompress_cache = {}
res = dict()

for wp in track(self._all_weight_params, description="Applying Scale Estimation"):
Expand Down Expand Up @@ -201,36 +202,14 @@ def apply(
if self._weight_penalty > 0.0:
min_max_scale_diffs += self._weight_penalty * fns.mean((q_weights - original_weight) ** 2, axis=-1)

zp_shape = zp.shape if zp is not None else None
key = [(wp.compression_config.mode, wp.compression_config.num_bits) + q_weights.shape + scale.shape]
if zp is not None:
key += zp_shape
key = tuple(key)
if key in compress_decompress_cache:
compress_decompress_model = compress_decompress_cache[key]["compress_decompress_model"]
compress_model = compress_decompress_cache[key]["compress_model"]
else:
compress_decompress_model = self._backend_entity.get_compress_decompress_pipeline(
wp, q_weights.shape, scale.shape, zp_shape
)
compress_model = self._backend_entity.get_compress_pipeline(wp, q_weights.shape, scale.shape, zp_shape)
compress_decompress_cache[key] = {
"compress_decompress_model": compress_decompress_model,
"compress_model": compress_model,
}

zero_scale = 0.001
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)

input_tensors = [original_weight.data, None]
if zp is not None:
input_tensors.append(zp.data)
# iterative rectification of initial scale
for i in range(self._initial_steps):
near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance)
input_tensors[1] = near_to_ideal_scale.data

out = compress_decompress_model(input_tensors)
out = calculate_quantized_dequantized_weight(original_weight, config, near_to_ideal_scale, zp)
q_weights_ = fns.zeros_like(original_weight) + out
q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)

Expand All @@ -253,10 +232,9 @@ def apply(
else:
near_to_ideal_scale = mask * result_scale + (1.0 - mask) * near_to_ideal_scale
result_scale = near_to_ideal_scale
input_tensors[1] = near_to_ideal_scale.data

if i < self._initial_steps - 1:
out = compress_model(input_tensors)
out = calculate_quantized_weight(original_weight, config, near_to_ideal_scale, zp)
compressed_weights = fns.zeros_like(original_weight) + out
target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
Expand All @@ -266,16 +244,14 @@ def apply(
factor = 1.0 - 0.05 * scale_steps
scaled_scale = factor * scale

input_tensors[1] = scaled_scale.data
out = compress_model(input_tensors)
out = calculate_quantized_weight(original_weight, config, scaled_scale, zp)
compressed_weights = fns.zeros_like(original_weight) + out

target, zero_mask = get_target_zero_mask(compressed_weights, zp)
zero_mask = zero_scale * zero_mask.astype(original_weight.dtype)
near_to_ideal_scale = estimate_scales(original_weight, target, zero_mask, importance)

input_tensors[1] = near_to_ideal_scale.data
out = compress_decompress_model(input_tensors)
out = calculate_quantized_dequantized_weight(original_weight, config, near_to_ideal_scale, zp)
q_weights_ = fns.zeros_like(original_weight) + out

q_outs = fns.matmul(fns.transpose(q_weights_, (1, 0, 2)), X)
Expand Down
75 changes: 59 additions & 16 deletions nncf/quantization/algorithms/weight_compression/weight_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np

import nncf
from nncf.common.logging.logger import log_once
from nncf.parameters import CompressWeightsMode
from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
from nncf.quantization.fake_quantize import calculate_scale_zero_point
from nncf.tensor import Tensor
from nncf.tensor import functions as fns
from nncf.tensor.definitions import TensorBackend
from nncf.tensor.definitions import TensorDataType
from nncf.utils import is_openvino_available

ReductionAxes = Tuple[int, ...]

Expand Down Expand Up @@ -279,25 +282,65 @@ def calculate_quantized_weight(
:param zero_point: Zero point tensor used for quantization.
:return: Quantized weight tensor of uint8 or int8 type.
"""
if weight.dtype != TensorDataType.float32:
weight = weight.astype(TensorDataType.float32)
if scale.dtype != TensorDataType.float32:
scale = scale.astype(TensorDataType.float32)

num_bits = config.num_bits
asym_quant = config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]
dtype = TensorDataType.uint8 if asym_quant else TensorDataType.int8
level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1

compressed_weights = weight / scale
if zero_point is not None:
compressed_weights += zero_point.astype(weight.dtype)
compressed_weights = fns.round(compressed_weights)
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype)
if weight.backend == TensorBackend.numpy and not is_openvino_available():
log_once(logging.INFO, "Compression time may improve after installing OpenVINO")

if weight.backend == TensorBackend.numpy and is_openvino_available():
from nncf.openvino.quantization.compression_primitives import OV_COMPRESSION_PRIMITIVE_CACHE

zero_point_shape = None if zero_point is None else zero_point.shape
compress_weight_primitive = OV_COMPRESSION_PRIMITIVE_CACHE.get_compress_weight_primitive(
config, weight.shape, scale.shape, zero_point_shape
)
input_tensors = weight.data, scale.data
if zero_point is not None:
input_tensors += (zero_point.data,)
compressed_weights = Tensor(compress_weight_primitive(input_tensors))
else:
if weight.dtype != TensorDataType.float32:
weight = weight.astype(TensorDataType.float32)
if scale.dtype != TensorDataType.float32:
scale = scale.astype(TensorDataType.float32)

num_bits = config.num_bits
asym_quant = config.mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT4_ASYM]
dtype = TensorDataType.uint8 if asym_quant else TensorDataType.int8
level_low = 0 if asym_quant else -(2 ** (num_bits - 1))
level_high = 2**num_bits - 1 if asym_quant else 2 ** (num_bits - 1) - 1

compressed_weights = weight / scale
if zero_point is not None:
compressed_weights += zero_point.astype(weight.dtype)
compressed_weights = fns.round(compressed_weights)
compressed_weights = fns.clip(compressed_weights, level_low, level_high).astype(dtype)
return compressed_weights


def calculate_quantized_dequantized_weight(
weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None
) -> Tensor:

if weight.backend == TensorBackend.numpy and not is_openvino_available():
log_once(logging.INFO, "Compression time may improve after installing OpenVINO")

if weight.backend == TensorBackend.numpy and is_openvino_available():
from nncf.openvino.quantization.compression_primitives import OV_COMPRESSION_PRIMITIVE_CACHE

zero_point_shape = None if zero_point is None else zero_point.shape
compress_decompress_weight_primitive = OV_COMPRESSION_PRIMITIVE_CACHE.get_compress_decompress_weight_primitive(
config, weight.shape, scale.shape, zero_point_shape
)
input_tensors = weight.data, scale.data
if zero_point is not None:
input_tensors += (zero_point.data,)
decompressed_weight = Tensor(compress_decompress_weight_primitive(input_tensors))
else:
compressed_weight = calculate_quantized_weight(weight, config, scale, zero_point)
decompressed_weight = do_dequantization(compressed_weight, scale, zero_point)
return decompressed_weight


def do_integer_quantization(
weight: Tensor,
reduction_axes: ReductionAxes,
Expand Down
Loading

0 comments on commit 55cafaa

Please sign in to comment.