From 2cbbb0148fe260660bc7b91940d89687bc2382ad Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Mon, 21 Oct 2024 10:46:17 +0200
Subject: [PATCH] Dispatch weight lowering between OV and Tensor backends

Route do_int_quantization and calculate_quantized_dequantized_weight
through weight_lowering_dispatcher, which selects the OpenVINO backend
when OpenVINO is available and falls back to the pure-Tensor backend
otherwise.

- Split the multi-name imports in the weight_lowering package into one
  import per line and re-export get_integer_quantization_error.
- Make OVModelParameters hashable so it can be part of the
  compiled-model cache key, and fix the caching wrapper to call the
  wrapped function without injecting the cache object into its
  arguments.
- Pass raw tensor data (weight.data, scale.data, zero_point.data) to
  compiled OV models instead of Tensor wrappers.
- Remove the direct backend imports from weight_lowering_dispatcher,
  breaking the circular dependency with ov_backend and tensor_backend.
- Reformat long signatures and calls (one argument per line or a single
  line, and consistent slice spacing).
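
A rough usage sketch of the dispatched entry point, for reviewers. This
is illustrative only: weight (an nncf Tensor), reduction_axes, and
config are assumed to already exist at the call site; the import path
follows the updated __init__.py re-exports.

    from nncf.quantization.algorithms.weight_compression.weight_lowering import do_int_quantization

    # Dispatched call: executes the compiled-OV-model backend when
    # OpenVINO is importable, otherwise the pure-Tensor implementation.
    compressed_weight, scale, zero_point = do_int_quantization(weight, reduction_axes, config)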
---
 .../quantization/compression_primitives.py    | 11 ++---
 .../weight_lowering/__init__.py               | 18 +++++---
 .../weight_lowering/common.py                 |  5 ++-
 .../weight_lowering/dispatched_functions.py   |  6 ++-
 .../weight_lowering/ov_backend.py             | 45 +++++++++----------
 .../weight_lowering/tensor_backend.py         | 16 +++++--
 .../weight_lowering_dispatcher.py             |  5 +--
 7 files changed, 57 insertions(+), 49 deletions(-)

diff --git a/nncf/openvino/quantization/compression_primitives.py b/nncf/openvino/quantization/compression_primitives.py
index d917afe089..2b0c4e9508 100644
--- a/nncf/openvino/quantization/compression_primitives.py
+++ b/nncf/openvino/quantization/compression_primitives.py
@@ -123,6 +123,7 @@ def get_compress_decompress_weight_primitive(
         zero_point_shape,
     )
 
+
 def _build_compress_decompress_model(
     config: WeightCompressionConfig,
     params: PrimitiveParameters,
@@ -131,13 +132,7 @@ def _build_compress_decompress_model(
     zero_point_shape: Optional[Tuple] = None,
 ):
     ov_parameters, ov_results = _build_compress_model(
-        config,
-        params,
-        weight_shape,
-        scale_shape,
-        zero_point_shape,
-        reduction_axes=None,
-        return_nodes=True
+        config, params, weight_shape, scale_shape, zero_point_shape, reduction_axes=None, return_nodes=True
     )
     return _get_compress_decompress_model(
         config,
@@ -196,7 +191,7 @@ def _build_compress_model(
 
         num_groups_per_channel = channel_size // group_size
         shape = list(weight.shape)  # [a1, r, a2] - "r" refers to number of channels along reduction axis
-        shape[reduction_axes: reduction_axes + 1] = (num_groups_per_channel, group_size)
+        shape[reduction_axes : reduction_axes + 1] = (num_groups_per_channel, group_size)
         weight = opset.reshape(weight, shape, special_zero=False)
         reduction_axes += 1
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering/__init__.py b/nncf/quantization/algorithms/weight_compression/weight_lowering/__init__.py
index d0d85a667d..8aa16c0ee4 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering/__init__.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering/__init__.py
@@ -10,8 +10,16 @@
 # limitations under the License.
 
-from .common import reshape_weight_for_grouped_quantization, calculate_nf4_scale, do_nf4_quantization, \
-    do_nf4_dequantization, calculate_normalized_weight_and_fp4_scale, calculate_integer_quantization_params, \
-    calculate_quantized_weight, compress_weight, do_int_dequantization
-
-from .dispatched_functions import do_int_quantization, calculate_quantized_dequantized_weight
+from .common import WeightCompressionConfig
+from .common import calculate_integer_quantization_params
+from .common import calculate_nf4_scale
+from .common import calculate_normalized_weight_and_fp4_scale
+from .common import calculate_quantized_weight
+from .common import compress_weight
+from .common import do_int_dequantization
+from .common import do_nf4_dequantization
+from .common import do_nf4_quantization
+from .common import get_integer_quantization_error
+from .common import reshape_weight_for_grouped_quantization
+from .dispatched_functions import calculate_quantized_dequantized_weight
+from .dispatched_functions import do_int_quantization
 
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering/common.py b/nncf/quantization/algorithms/weight_compression/weight_lowering/common.py
index 24842eb19b..3f9baadfd8 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering/common.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering/common.py
@@ -24,7 +24,6 @@
 
 from .dispatched_functions import do_int_quantization
 
-
 ReductionAxes = Tuple[int, ...]
 
 NF4_QUANTILES = np.array(
@@ -342,7 +341,9 @@ def get_integer_quantization_error(
     if weight.dtype != TensorDataType.float32:
         weight = weight.astype(TensorDataType.float32)
 
-    compressed_weights, scale, zero_point = do_int_quantization(weight, reduction_axes, config, invert_division=invert_division)
+    compressed_weights, scale, zero_point = do_int_quantization(
+        weight, reduction_axes, config, invert_division=invert_division
+    )
     decompressed_weight = do_int_dequantization(compressed_weights, scale, zero_point)
 
     decompressed_weight = decompressed_weight.reshape(orig_shape)
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering/dispatched_functions.py b/nncf/quantization/algorithms/weight_compression/weight_lowering/dispatched_functions.py
index 9e0e16d73a..7dbe02a010 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering/dispatched_functions.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering/dispatched_functions.py
@@ -8,11 +8,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Tuple, Optional
+from typing import Optional, Tuple
 
 from nncf.tensor import Tensor
-from .weight_lowering_dispatcher import weight_lowering_dispatcher, ov_available_backend_selector, BackendParametersContainer
+
 from ..config import WeightCompressionConfig
+from .weight_lowering_dispatcher import ov_available_backend_selector
+from .weight_lowering_dispatcher import weight_lowering_dispatcher
 
 
 @weight_lowering_dispatcher(ov_available_backend_selector)
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering/ov_backend.py b/nncf/quantization/algorithms/weight_compression/weight_lowering/ov_backend.py
index dabe3c15c7..ac27fc60a3 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering/ov_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering/ov_backend.py
@@ -34,8 +34,9 @@
 from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig
 from nncf.tensor import Tensor
 
+from .dispatched_functions import calculate_quantized_dequantized_weight
+from .dispatched_functions import do_int_quantization
 from .weight_lowering_dispatcher import WeightLoweringBackend
-from .dispatched_functions import do_int_quantization, calculate_quantized_dequantized_weight
 
 
 @dataclass
@@ -46,11 +47,14 @@ class OVModelParameters:
     share_outputs: bool = True
     input_dtype: str = "fp32"
 
+    def __hash__(self):
+        return hash((self.dynamic, self.recompile, self.release_memory, self.share_outputs, self.input_dtype))
+
 
 class CompiledModelCache:
     def __init__(self):
         self._cache = {}
-
+
     def clear(self):
         self._cache.clear()
 
@@ -72,7 +76,7 @@ def wrapper(*args, **kwargs):
         cache = COMPILED_MODEL_CACHE._cache
         if not recompile and cache_key in cache:
             return cache[cache_key]
-        result = func(cache, *args, **kwargs)
+        result = func(*args, **kwargs)
         cache[cache_key] = result
         return result
 
@@ -109,12 +113,12 @@ def _(
     )
 
     if precomputed_scale is None:
-        results = model(weight)
+        results = model(weight.data)
         compressed_weight, scale, zero_point = [Tensor(it) for it in results]
     else:
-        inputs = [weight, precomputed_scale]
+        inputs = [weight.data, precomputed_scale.data]
         if precomputed_zero_point is not None:
-            inputs += [precomputed_zero_point]
+            inputs += [precomputed_zero_point.data]
         compressed_weight = Tensor(model(inputs)[0])
         scale, zero_point = precomputed_scale, precomputed_zero_point
 
@@ -123,7 +127,12 @@ def _(
 
 @calculate_quantized_dequantized_weight.register(WeightLoweringBackend.OV)
 def _(
-    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None, ov_model_params: Optional[OVModelParameters] = None, **kwargs
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    scale: Tensor,
+    zero_point: Optional[Tensor] = None,
+    ov_model_params: Optional[OVModelParameters] = None,
+    **kwargs,
 ) -> Tensor:
     weight_shape = weight.shape
     scale_shape = scale.shape
@@ -134,17 +143,11 @@ def _(
     if config.mode in [CompressWeightsMode.INT4_ASYM, CompressWeightsMode.INT4_SYM]:
         ov_model_params.dynamic = False
 
-    model = get_compress_decompress_weight_model(
-        config,
-        weight_shape,
-        scale_shape,
-        zero_point_shape,
-        ov_model_params
-    )
+    model = get_compress_decompress_weight_model(config, weight_shape, scale_shape, zero_point_shape, ov_model_params)
 
-    inputs = [weight, scale]
+    inputs = [weight.data, scale.data]
     if zero_point is not None:
-        inputs.append(zero_point)
+        inputs.append(zero_point.data)
     results = model(inputs)
     decompressed_weight = [Tensor(it) for it in results][0]
     return decompressed_weight
@@ -218,13 +221,7 @@ def _build_compress_decompress_model(
     zero_point_shape: Optional[Tuple] = None,
 ):
     ov_parameters, ov_results = _build_compress_model(
-        config,
-        ov_model_params,
-        weight_shape,
-        scale_shape,
-        zero_point_shape,
-        reduction_axes=None,
-        return_nodes=True
+        config, ov_model_params, weight_shape, scale_shape, zero_point_shape, reduction_axes=None, return_nodes=True
     )
     return _get_compress_decompress_model(
         config,
@@ -283,7 +280,7 @@ def _build_compress_model(
 
         num_groups_per_channel = channel_size // group_size
         shape = list(weight.shape)  # [a1, r, a2] - "r" refers to number of channels along reduction axis
-        shape[reduction_axes: reduction_axes + 1] = (num_groups_per_channel, group_size)
+        shape[reduction_axes : reduction_axes + 1] = (num_groups_per_channel, group_size)
         weight = opset.reshape(weight, shape, special_zero=False)
         reduction_axes += 1
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering/tensor_backend.py b/nncf/quantization/algorithms/weight_compression/weight_lowering/tensor_backend.py
index 9f458c7ce6..47889dddb1 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering/tensor_backend.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering/tensor_backend.py
@@ -22,11 +22,14 @@
 from nncf.tensor import functions as fns
 from nncf.tensor.definitions import TensorDataType
 
-from .dispatched_functions import do_int_quantization, calculate_quantized_dequantized_weight
+from .common import calculate_integer_quantization_params
+from .common import calculate_quantized_weight
+from .common import do_int_dequantization
+from .common import reshape_weight_for_grouped_quantization
+from .dispatched_functions import calculate_quantized_dequantized_weight
+from .dispatched_functions import do_int_quantization
 from .weight_lowering_dispatcher import WeightLoweringBackend
-from .common import reshape_weight_for_grouped_quantization, calculate_quantized_weight, calculate_integer_quantization_params, do_int_dequantization
-
 
 ReductionAxes = Tuple[int, ...]
 
 
@@ -90,7 +93,12 @@ def _(
 
 @calculate_quantized_dequantized_weight.register(WeightLoweringBackend.TENSOR)
 def _(
-    weight: Tensor, config: WeightCompressionConfig, scale: Tensor, zero_point: Optional[Tensor] = None, invert_division=False, **kwargs
+    weight: Tensor,
+    config: WeightCompressionConfig,
+    scale: Tensor,
+    zero_point: Optional[Tensor] = None,
+    invert_division=False,
+    **kwargs,
 ) -> Tensor:
     compressed_weight = calculate_quantized_weight(weight, config, scale, zero_point, invert_division)
     decompressed_weight = do_int_dequantization(compressed_weight, scale, zero_point)
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering/weight_lowering_dispatcher.py b/nncf/quantization/algorithms/weight_compression/weight_lowering/weight_lowering_dispatcher.py
index d1cde91ac8..e3dd3a64d5 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering/weight_lowering_dispatcher.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering/weight_lowering_dispatcher.py
@@ -10,12 +10,9 @@
 # limitations under the License.
 
 from enum import Enum
 from functools import wraps
-from typing import Dict, Any, Callable, Optional
+from typing import Any, Callable, Dict, Optional
 
 from nncf.utils import is_openvino_available
-from .ov_backend import do_int_quantization as do_int_quantization_ov
-from .tensor_backend import do_int_quantization as do_int_quantization_tensor
-from functools import singledispatch
 
 class WeightLoweringBackend(Enum):