Computation of compression parameters via OpenVINO models #2727

Open
wants to merge 77 commits into
base: develop
77 commits
10d1ddb
Initial draft. Rebased.
nikita-savelyevv Jul 3, 2024
bd2629b
Unstage helper scripts
nikita-savelyevv Oct 22, 2024
3e69252
WIP
nikita-savelyevv Oct 23, 2024
166dd04
Reshape weights beforehand
nikita-savelyevv Oct 24, 2024
edbe913
BF16 support
nikita-savelyevv Oct 25, 2024
b636c66
Tweak lora type hint
nikita-savelyevv Oct 25, 2024
f0129ef
Tweaks
nikita-savelyevv Oct 25, 2024
e887e70
Added share_inputs
nikita-savelyevv Oct 25, 2024
9141a8a
Modeling tweaks
nikita-savelyevv Oct 25, 2024
a43c514
Move results_cache into separate file
nikita-savelyevv Oct 25, 2024
1216f65
Implement astype for ov backend for bf16, u4, i4
nikita-savelyevv Oct 25, 2024
8611b75
Experiments
nikita-savelyevv Oct 26, 2024
0718668
Support case of (weight, scale) -> (c_weight, zp)
nikita-savelyevv Oct 26, 2024
283a821
SE improvements
nikita-savelyevv Oct 28, 2024
6964844
Accelerate AWQ
nikita-savelyevv Oct 28, 2024
80e2c92
SE changes
nikita-savelyevv Oct 29, 2024
fc82866
Add access counts to caching decorator
nikita-savelyevv Oct 29, 2024
f3891cd
Comment out env vars
nikita-savelyevv Oct 29, 2024
353aac1
Fix existing tests
nikita-savelyevv Oct 29, 2024
d20e593
Unstage helper scripts
nikita-savelyevv Oct 30, 2024
dc30d8d
Tests WIP
nikita-savelyevv Oct 31, 2024
c5606ce
Invert Tensor division
nikita-savelyevv Nov 1, 2024
e6a9d56
Add fns.divide
nikita-savelyevv Nov 4, 2024
ab90a08
Adopt misalignment test to check the degree of misalignment
nikita-savelyevv Nov 6, 2024
2e308b7
Merge branch 'develop' into compress-via-openvino
nikita-savelyevv Nov 7, 2024
6289c5c
Merge-related fixes
nikita-savelyevv Nov 7, 2024
f60fd17
Tweaks
nikita-savelyevv Nov 7, 2024
57a0931
Strict input/output data types
nikita-savelyevv Nov 11, 2024
1010fcf
Add dynamic shapes test
nikita-savelyevv Nov 11, 2024
6e54fba
ov modeling tests
nikita-savelyevv Nov 13, 2024
8ac0fe2
Move cache_results decorator
nikita-savelyevv Nov 13, 2024
ded66f3
Tests reorgantization
nikita-savelyevv Nov 13, 2024
69ae5fa
cache_results decorator test
nikita-savelyevv Nov 13, 2024
d0f49ae
get_const_value test
nikita-savelyevv Nov 13, 2024
a282976
OVModelParameters minor refactor
nikita-savelyevv Nov 13, 2024
b13f186
Added OV tensor tests
nikita-savelyevv Nov 14, 2024
9e90d5a
Minor file reorg
nikita-savelyevv Nov 14, 2024
5f46593
Tweaks
nikita-savelyevv Nov 14, 2024
e7617f1
Tweaks
nikita-savelyevv Nov 14, 2024
925f830
Switch to OV 2024.5 rc2
nikita-savelyevv Nov 15, 2024
5831fcd
Additional tests for ov_modeling
nikita-savelyevv Nov 15, 2024
9160de3
Type hints
nikita-savelyevv Nov 15, 2024
c7c63eb
Ignore mypy
nikita-savelyevv Nov 15, 2024
764f722
Reuse DTYPE_MAP_REV
nikita-savelyevv Nov 15, 2024
4a448e1
Added docstrings
nikita-savelyevv Nov 18, 2024
73f61fc
Remove inverted NP division. Add non-convertable OV division.
nikita-savelyevv Dec 11, 2024
16ccf50
Merge branch 'develop' into compress-via-openvino
nikita-savelyevv Dec 11, 2024
cd884eb
Remove OV 2024.5 RC installation
nikita-savelyevv Dec 11, 2024
608cfe9
Add a test for non-convertable division
nikita-savelyevv Dec 11, 2024
9569e1e
Make the test more strict
nikita-savelyevv Dec 11, 2024
f962bd1
Remove unnecessary lines
nikita-savelyevv Dec 11, 2024
5dcd83d
Update get_integer_quantization_error implementation
nikita-savelyevv Dec 11, 2024
6e22ef5
Remove unnecessary convert
nikita-savelyevv Dec 11, 2024
b45e788
Move create_ov_const_from_tensor to node_utils
nikita-savelyevv Dec 11, 2024
b2cebd0
Separate checking logic into standalone methods
nikita-savelyevv Dec 11, 2024
3a71141
Add debug conditions
nikita-savelyevv Dec 11, 2024
eeadf1d
Move ov model cache clearing to ov backend destructor
nikita-savelyevv Dec 12, 2024
40aef54
Update default ov model parameters
nikita-savelyevv Dec 12, 2024
ab3d35f
Revert debug logic
nikita-savelyevv Dec 12, 2024
d48c748
Update reference
nikita-savelyevv Dec 12, 2024
9a56fae
Add debug conditions
nikita-savelyevv Dec 11, 2024
e10d806
Disable dynamic shapes by default
nikita-savelyevv Dec 12, 2024
b372dc7
Revert "Add debug conditions"
nikita-savelyevv Dec 12, 2024
63858d3
Linters
nikita-savelyevv Dec 12, 2024
87b5c10
Fix lora correction
nikita-savelyevv Dec 13, 2024
7134e6d
Remove not used argument
nikita-savelyevv Dec 13, 2024
5a1866f
Remove static shapes testing because it is not needed with non-conver…
nikita-savelyevv Dec 13, 2024
6a2c9fc
Set dynamic shapes by default
nikita-savelyevv Dec 13, 2024
204fb21
Merge branch 'develop' into compress-via-openvino
nikita-savelyevv Dec 13, 2024
dca5376
Merge branch 'develop' into compress-via-openvino
nikita-savelyevv Dec 16, 2024
92fbba5
Guarantee call order
nikita-savelyevv Dec 16, 2024
b27c720
Add convertable_division parameter
nikita-savelyevv Dec 16, 2024
6ab1c08
Cleanup
nikita-savelyevv Dec 16, 2024
a0fe91a
Add convertable division test
nikita-savelyevv Dec 16, 2024
97bd61d
Add explicit inference precision
nikita-savelyevv Dec 16, 2024
58963ab
Fix import
nikita-savelyevv Dec 16, 2024
ec21996
Update tests/post_training/data/wc_reference_data.yaml
nikita-savelyevv Dec 16, 2024
1 change: 1 addition & 0 deletions docs/api/source/conf.py
@@ -145,6 +145,7 @@ def collect_api_entities() -> APIInfo:
        "nncf.tensor.functions.torch_linalg",
        "nncf.tensor.functions.torch_io",
        "nncf.tensor.functions.numpy_io",
        "nncf.tensor.functions.ov",
    ]

    with mock(mock_modules):
11 changes: 11 additions & 0 deletions nncf/common/logging/logger.py
@@ -12,6 +12,7 @@
import logging
import sys
from contextlib import contextmanager
from functools import lru_cache

NNCF_LOGGER_NAME = "nncf"

@@ -86,3 +87,13 @@ def warn_bkc_version_mismatch(backend: str, bkc_version: str, current_version: s
        f"while current {backend} version is {current_version}. "
        f"If you encounter issues, consider switching to {backend}{bkc_version}"
    )


@lru_cache(None)
def log_once(level: int, message: str) -> None:
    """
    Logs a message only once.
    :param level: Logging level, e.g. logging.WARNING.
    :param message: The message to log.
    """
    nncf_logger.log(level, message)
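The `log_once` helper relies on `functools.lru_cache` memoizing on the `(level, message)` pair, so a repeated call with identical arguments never reaches the logger again. A minimal standalone sketch of that behavior follows. The logger name `demo` and the `ListHandler` class are assumptions made for illustration only; NNCF itself logs through `nncf_logger`.

```python
import logging
from functools import lru_cache

# Hypothetical logger standing in for nncf_logger.
logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.propagate = False


class ListHandler(logging.Handler):
    """Collects emitted messages in a list so they can be counted."""

    def __init__(self) -> None:
        super().__init__()
        self.messages = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


handler = ListHandler()
logger.addHandler(handler)


@lru_cache(None)
def log_once(level: int, message: str) -> None:
    # The body runs only on a cache miss, i.e. once per (level, message) pair.
    logger.log(level, message)


log_once(logging.WARNING, "fallback to NumPy")
log_once(logging.WARNING, "fallback to NumPy")  # cache hit: nothing is logged
log_once(logging.INFO, "fallback to NumPy")  # different level, so a new cache key
print(handler.messages)
```

Because the cache key includes the level, the same text logged at two levels is emitted twice. The cache is also unbounded (`maxsize=None`), so every distinct message stays in memory for the lifetime of the process.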
60 changes: 60 additions & 0 deletions nncf/common/utils/decorators.py
@@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from importlib import import_module
from typing import Any, Callable, Dict, List

@@ -51,3 +52,62 @@ def wrapped_f(*args: Any, **kwargs: Any): # type: ignore
        return wrapped_f

    return wrap


class ResultsCacheContainer:
    """
    A container for results decorated with @cache_results decorator.
    """

    def __init__(self) -> None:
        # Stores the results of the decorated function
        self._cache: Dict[Any, Any] = {}
        # Stores the number of times the cached result was accessed
        self._access_count: Dict[Any, int] = {}

    def clear(self) -> None:
        self._cache.clear()
        self._access_count.clear()

    def is_empty(self) -> bool:
        return len(self._cache) == 0

    def __getitem__(self, item: Any) -> Any:
        self._access_count[item] += 1
        return self._cache[item]

    def __setitem__(self, key: Any, value: Any) -> None:
        self._access_count[key] = 0
        self._cache[key] = value

    def __contains__(self, item: Any) -> bool:
        return item in self._cache


def cache_results(cache: ResultsCacheContainer) -> Callable:  # type: ignore
    """
    Decorator to cache the results of a function.

    Decorated function additionally accepts a `disable_caching` argument to disable caching if needed. If it is True,
    the result will not be saved to the cache. Also, if there is a corresponding result in the cache, it will be
    recomputed.
    :param cache: A cache container where results will be stored.
    """

    def decorator(func: Callable) -> Callable:  # type: ignore
        def wrapper(*args, disable_caching: bool = False, **kwargs) -> Any:  # type: ignore
            if disable_caching:
                return func(*args, **kwargs)
            sig = inspect.signature(func)
            new_kwargs = {name: arg for name, arg in zip(sig.parameters, args)}
            new_kwargs.update(kwargs)
            cache_key = (func.__name__, frozenset(new_kwargs.items()))
            if cache_key in cache:
                return cache[cache_key]
            result = func(*args, **kwargs)
            cache[cache_key] = result
            return result

        return wrapper

    return decorator
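The cache key in `cache_results` is built by binding positional arguments to parameter names via `inspect.signature`, so the same logical call maps to one entry however the arguments are passed. The sketch below illustrates this with a simplified copy of the decorator that uses a plain `dict` in place of `ResultsCacheContainer`. The function `build_model` and the `calls` list are hypothetical and exist only to count real invocations.

```python
import inspect
from typing import Any, Callable, Dict, List, Tuple

calls: List[Tuple[int, int]] = []  # records every real invocation (illustration only)


def cache_results(cache: Dict[Any, Any]) -> Callable:
    # Simplified copy of the decorator from the diff, backed by a plain dict.
    def decorator(func: Callable) -> Callable:
        def wrapper(*args: Any, disable_caching: bool = False, **kwargs: Any) -> Any:
            if disable_caching:
                return func(*args, **kwargs)
            sig = inspect.signature(func)
            # Bind positional arguments to their parameter names.
            new_kwargs = {name: arg for name, arg in zip(sig.parameters, args)}
            new_kwargs.update(kwargs)
            cache_key = (func.__name__, frozenset(new_kwargs.items()))
            if cache_key in cache:
                return cache[cache_key]
            result = func(*args, **kwargs)
            cache[cache_key] = result
            return result

        return wrapper

    return decorator


demo_cache: Dict[Any, Any] = {}


@cache_results(demo_cache)
def build_model(num_bits: int, group_size: int) -> str:
    # Hypothetical expensive function.
    calls.append((num_bits, group_size))
    return f"model_int{num_bits}_g{group_size}"


build_model(4, 128)                        # computed and stored
build_model(4, group_size=128)             # same key: served from the cache
build_model(num_bits=4, group_size=128)    # same key: served from the cache
build_model(4, 128, disable_caching=True)  # bypasses the cache and recomputes
print(len(calls), len(demo_cache))
```

Two consequences of this key design: every argument must be hashable, since the items go into a `frozenset`, and two different functions that share a `__name__` and a cache container would collide.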
36 changes: 36 additions & 0 deletions nncf/import_utils.py
@@ -0,0 +1,36 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib

_openvino_available = importlib.util.find_spec("openvino") is not None
_openvino_version = "N/A"
if _openvino_available:
    try:
        from openvino.runtime import get_version

        version = get_version()
        # avoid invalid format
        if "-" in version:
            ov_major_version, dev_info = version.split("-", 1)
            commit_id = dev_info.split("-")[0]
            version = f"{ov_major_version}-{commit_id}"
        _openvino_version = version
    except ImportError:
        _openvino_available = False


def is_openvino_available():
    """
    Check if OpenVINO is available.
    :return: True if openvino package is installed, False otherwise.
    """
    return _openvino_available
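The version handling above keeps the part before the first dash plus the first dash-separated token after it. A small sketch of the same string logic as a pure function follows. The name `normalize_ov_version` and the sample version string are assumptions for illustration, not values taken from this PR.

```python
def normalize_ov_version(version: str) -> str:
    """Mirror the trimming logic from nncf/import_utils.py on a plain string."""
    if "-" in version:
        ov_major_version, dev_info = version.split("-", 1)
        # First token after the release number; the diff names it commit_id.
        commit_id = dev_info.split("-")[0]
        version = f"{ov_major_version}-{commit_id}"
    return version


# Illustrative development-build string (hypothetical).
print(normalize_ov_version("2024.5.0-17288-7975fa5da0c-refs/pull/3856/head"))
# A plain release string passes through unchanged.
print(normalize_ov_version("2024.5.0"))
```

For a string shaped like the first sample, the token kept after the dash is the one immediately following the release number, whatever it represents in the actual OpenVINO build string.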
48 changes: 45 additions & 3 deletions nncf/openvino/graph/node_utils.py
@@ -14,6 +14,7 @@
import numpy as np
import openvino.runtime as ov
import openvino.runtime.opset13 as opset
from openvino._pyopenvino.op import Constant

import nncf
from nncf.common.graph.graph import NNCFGraph
@@ -41,6 +42,8 @@
from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVOpMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype
from nncf.tensor import Tensor
from nncf.tensor import TensorBackend

InplaceInsertionFnType = Callable[[ov.Node, int, str], ov.Node]

@@ -107,16 +110,17 @@ def cnt_if_op(model: ov.Model, cnt: int) -> int:
     return cnt_if_op(model, 0)


-def get_const_value(const_node: ov.Node) -> np.ndarray:
+def get_const_value(const_node: ov.Node, cast_bf16_to_fp32: Optional[bool] = True) -> np.ndarray:
     """
     Returns the constant tensor for the node.
     This method is applicable only for the floating-point constant data.

     :param const_node: OpenVINO node.
+    :param cast_bf16_to_fp32: Whether to cast bf16 node data to fp32 or not. If False and the node contains bf16 data,
+        the resulting bf16 value will be returned encoded inside a numpy.float16 array.
     :return: The constant value.
     """
-    if const_node.get_element_type() == ov.Type.bf16:
-        # Fixed FP32 data type as the result for BF16 constant
+    if const_node.get_element_type() == ov.Type.bf16 and cast_bf16_to_fp32:
         return const_node.get_data(dtype=np.float32)
     return const_node.data
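The new docstring notes that, with `cast_bf16_to_fp32=False`, bf16 data comes back "encoded inside a numpy.float16 array": the 16-bit patterns are bf16, even though the NumPy dtype says float16. The sketch below shows in NumPy alone why such an array must not be read as float16, and how the bits can be widened to fp32. Both helper names are hypothetical, and the truncating encoder is a simplification used only to produce test data.

```python
import numpy as np


def fp32_to_bf16_bits(x: np.ndarray) -> np.ndarray:
    """Keep the upper 16 bits of each fp32 value (truncation, no rounding) and
    expose them through a float16-typed array, as the docstring describes."""
    bits = x.astype(np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16).view(np.float16)


def bf16_bits_to_fp32(encoded: np.ndarray) -> np.ndarray:
    """Reinterpret float16-typed storage as bf16 bit patterns and widen to fp32."""
    bits = encoded.view(np.uint16).astype(np.uint32) << 16
    return bits.view(np.float32)


# Values chosen to be exactly representable in bf16.
values = np.array([1.0, -2.5, 0.15625, 256.0], dtype=np.float32)
encoded = fp32_to_bf16_bits(values)
decoded = bf16_bits_to_fp32(encoded)

print(encoded)  # misleading numbers: bf16 bits interpreted as float16
print(decoded)  # the original values
```

bf16 shares the fp32 exponent layout, so placing the 16 bits in the upper half of a 32-bit word recovers the value exactly. Interpreting the same bits as float16 gives unrelated numbers.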

@@ -631,3 +635,41 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple
channel_axis = activations_layout.index(OVLayoutElem.C_IN)

return channel_axis


def convert_if_needed(node: ov.Node, target_dtype: ov.Type) -> ov.Node:
    """
    Converts the input node to the target data type if it is not already in the target data type.

    :param node: The input node to convert.
    :param target_dtype: The target data type to convert the input node to.
    :return: The converted node.
    """
    if node.get_element_type() == target_dtype:
        return node
    return opset.convert(node, target_dtype)


def non_convertable_divide(a: ov.Node, b: ov.Node) -> ov.Node:
    """
    Creates a "non-convertable" divide operation. It won't be converted to a*(1/b).
    """
    divide_node = a / b
    divide_node.get_rt_info()["nonconvertable_divide_0"] = True
    return divide_node


def create_ov_const_from_tensor(x: Tensor, dtype: ov.Type, name: Optional[str] = None) -> Constant:
    """
    Create an OpenVINO Constant node from the given tensor.
    :param x: Data tensor. Supports NumPy and OV tensor backends. If x backend is OV, the constant node is created
        directly from underlying OV tensor.
    :param dtype: Data type of the constant.
    :param name: Optional name of the constant.
    :return: OpenVINO Constant node.
    """
    if x.backend == TensorBackend.ov:
        assert x.data.get_element_type() == dtype
        return opset.constant(x.data, name=name)
    const = opset.constant(x.data, dtype=dtype, name=name)
    return const
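`non_convertable_divide` exists because `a / b` and `a * (1 / b)` are not numerically identical in floating point: the reciprocal is rounded once, and the product is rounded again. The NumPy sketch below, which does not use OpenVINO, shows the effect in float32. It illustrates the arithmetic only and says nothing about how OpenVINO schedules the operation.

```python
import numpy as np

a = np.float32(5.0)
b = np.float32(3.0)

direct = a / b                             # one rounding step
via_reciprocal = a * (np.float32(1.0) / b)  # two rounding steps

print(direct, via_reciprocal, direct == via_reciprocal)

# Count how often the two forms disagree over a small grid of operands.
xs = np.arange(1, 101, dtype=np.float32).reshape(-1, 1)
ys = np.arange(1, 101, dtype=np.float32).reshape(1, -1)
mismatches = int(np.count_nonzero(xs / ys != xs * (np.float32(1.0) / ys)))
print(f"{mismatches} of {xs.size * ys.size} operand pairs differ")
```

A one-ulp difference in `weight / scale` can flip the result of the subsequent rounding to an integer level. That is presumably why results computed through an OpenVINO model are kept aligned with the NumPy reference by forbidding the rewrite.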
8 changes: 3 additions & 5 deletions nncf/quantization/algorithms/weight_compression/awq.py
@@ -31,8 +31,7 @@
 from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
 from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_nf4_scale
-from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
-from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization
+from nncf.quantization.algorithms.weight_compression.weight_lowering import calculate_quantized_dequantized_weight
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
 from nncf.quantization.passes import transform_to_inference_graph
@@ -262,10 +261,9 @@ def apply(
     g_compressed_weighs = do_nf4_quantization(weights_to_fake_quantize, g_c_scale)
     g_decompressed_weighs = do_nf4_dequantization(g_compressed_weighs, g_c_scale)
 else:
-    g_compressed_weighs, g_c_scale, g_c_zp = do_int_quantization(
-        weights_to_fake_quantize, reduction_axis, awq_config
+    g_decompressed_weighs = calculate_quantized_dequantized_weight(
+        weights_to_fake_quantize, awq_config, reduction_axis
     )
-    g_decompressed_weighs = do_int_dequantization(g_compressed_weighs, g_c_scale, g_c_zp)
 sacts = gacts / fns.unsqueeze(cur_scale, 1)

 cur_out = fns.matmul(g_decompressed_weighs, sacts)
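In this hunk the separate quantize and dequantize calls are fused into a single `calculate_quantized_dequantized_weight` call that returns the "fake-quantized" weight directly. The NumPy sketch below shows what such a round trip generally computes for asymmetric integer quantization. It is a generic textbook formulation under stated assumptions (per-row min/max scale, unsigned levels) and is not NNCF's implementation. The function name `quantize_dequantize_asym` is hypothetical.

```python
import numpy as np


def quantize_dequantize_asym(weight: np.ndarray, num_bits: int = 4, axis: int = -1) -> np.ndarray:
    """Generic asymmetric quantize -> dequantize round trip (illustrative only)."""
    level_low, level_high = 0, 2**num_bits - 1
    w_min = weight.min(axis=axis, keepdims=True)
    w_max = weight.max(axis=axis, keepdims=True)
    # One scale per reduction group; guard against a zero range.
    scale = (w_max - w_min) / (level_high - level_low)
    scale = np.maximum(scale, np.finfo(np.float32).eps)
    zero_point = np.clip(np.round(-w_min / scale), level_low, level_high)
    # Quantize to integer levels, then map straight back to floating point.
    quantized = np.clip(np.round(weight / scale + zero_point), level_low, level_high)
    return ((quantized - zero_point) * scale).astype(weight.dtype)


rng = np.random.default_rng(0)
w = rng.normal(size=(8, 32)).astype(np.float32)
w_dq = quantize_dequantize_asym(w)
max_err = float(np.abs(w_dq - w).max())
print(f"max abs round-trip error: {max_err:.4f}")
```

Fusing the two steps means the intermediate integer tensor, scale and zero point never have to be returned to the caller. This fits the AWQ loop above, which only needs the decompressed weight.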
11 changes: 11 additions & 0 deletions nncf/quantization/algorithms/weight_compression/config.py
@@ -40,12 +40,23 @@ def num_bits(self):
        """
        return 8 if self.mode in [CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8_ASYM] else 4

    @property
    def is_int_asym(self):
        return self.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT8_ASYM]

    @property
    def is_integer(self):
        """
        :return: True if compression type in integer, else False.
        """
        return self.mode not in [CompressWeightsMode.NF4, CompressWeightsMode.E2M1]

    def __hash__(self):
        return hash((self.mode.value, self.group_size))

    def __str__(self):
        return f"{self.mode.value}_{self.group_size}"
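The added `__hash__` makes a config object usable as a dictionary key or `frozenset` member, which a cache keyed on call arguments needs. The link to the `cache_results` decorator is an assumption here, not something this hunk states. A small sketch with hypothetical stand-ins (`DemoMode` for `CompressWeightsMode`, `DemoConfig` for `WeightCompressionConfig`):

```python
from dataclasses import dataclass
from enum import Enum


class DemoMode(Enum):  # hypothetical stand-in for CompressWeightsMode
    INT4_ASYM = "int4_asym"
    INT8_SYM = "int8_sym"


@dataclass
class DemoConfig:  # hypothetical stand-in for WeightCompressionConfig
    mode: DemoMode = DemoMode.INT8_SYM
    group_size: int = -1

    # An explicitly defined __hash__ is kept by @dataclass; without it, a
    # non-frozen dataclass with eq=True would be unhashable.
    def __hash__(self) -> int:
        return hash((self.mode.value, self.group_size))

    def __str__(self) -> str:
        return f"{self.mode.value}_{self.group_size}"


a = DemoConfig(DemoMode.INT4_ASYM, 128)
b = DemoConfig(DemoMode.INT4_ASYM, 128)

# Equal configs hash equally, so a separately constructed config finds the entry.
lookup = {a: "compiled model"}
print(lookup[b], str(a))
```

Since the dataclass stays mutable, changing `mode` or `group_size` after the object has been used as a key would silently break lookups. The sketch assumes configs are treated as immutable once created.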


@dataclass
class WeightCompressionParameters:
1 change: 0 additions & 1 deletion nncf/quantization/algorithms/weight_compression/gptq.py
@@ -267,7 +267,6 @@ def _quantize_weights(
 activations = [inp[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
 wc_statistics = ScaleEstimation.activations_to_wc_statistics(activations)
 scale, zero_point = ScaleEstimation.calculate_quantization_params(
-    self._backend_entity,
     wc_statistics,
     weight_tensor[:, (i1 + i) : (i1 + i + group_size)],
     reduction_axes,
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
 from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters
+from nncf.quantization.algorithms.weight_compression.weight_lowering import CompressedWeight
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_dequantization
 from nncf.quantization.algorithms.weight_compression.weight_lowering import do_nf4_quantization
@@ -105,7 +106,7 @@ def is_applicable(self, wc_params: WeightCompressionParameters):
         return wc_params.compression_config.num_bits == 4

     def calculate_adapters(
-        self, weight: Tensor, compressed_weight: Tensor, wc_params: WeightCompressionParameters
+        self, weight: Tensor, compressed_weight: CompressedWeight, wc_params: WeightCompressionParameters
     ) -> Tuple[Tensor, Tensor, List[float]]:
         """
         Calculates low rank matrices for a given original and compressed weights.
@@ -134,7 +135,7 @@ def calculate_adapters(
     @staticmethod
     def calculate_low_rank_matrices(
         weight: Tensor,
-        compressed_weight: Tensor,
+        compressed_weight: CompressedWeight,
         compression_config: WeightCompressionConfig,
         reduction_axes: Tuple[int, ...],
         lora_correction_params: AdvancedLoraCorrectionParameters,
Original file line number Diff line number Diff line change
@@ -354,7 +354,7 @@ def _calc_weight_sensitivity(
         if weight.dtype != TensorDataType.float32:
             weight = weight.astype(TensorDataType.float32)

-        compressed_weights, scale, zero_point = do_int_quantization(weight, reduction_axes, backup_config)
+        compressed_weights, scale, zero_point = do_int_quantization(weight, backup_config, reduction_axes)
         decompressed_weight = do_int_dequantization(compressed_weights, scale, zero_point)
         decompressed_weight = decompressed_weight.reshape(orig_shape)
         return fns.linalg.norm(decompressed_weight - weight, ord="fro").item()