diff --git a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py index 7ab0176eb85..081e99125b4 100644 --- a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py +++ b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py @@ -24,6 +24,7 @@ import nncf from nncf.common.logging import nncf_logger +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters DataItem = TypeVar("DataItem") ModelInput = TypeVar("ModelInput") @@ -63,6 +64,7 @@ def compress_model( group_size=group_size, awq=awq, sensitivity_metric=nncf.parameters.SensitivityMetric.MAX_ACTIVATION_VARIANCE, + advanced_parameters=AdvancedCompressionParameters(statistics_path="statistics"), ) return optimized_ov_model diff --git a/nncf/common/tensor_statistics/aggregator.py b/nncf/common/tensor_statistics/aggregator.py index 56691e5abde..2b1b21aa7f3 100644 --- a/nncf/common/tensor_statistics/aggregator.py +++ b/nncf/common/tensor_statistics/aggregator.py @@ -8,22 +8,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from abc import ABC from abc import abstractmethod from itertools import islice from typing import Any, Dict, Optional, TypeVar import nncf +import nncf.common.tensor_statistics.statistics_serializer as statistics_serializer +import nncf.common.tensor_statistics.statistics_validator as statistics_validator from nncf.common import factory from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType from nncf.data.dataset import DataItem from nncf.data.dataset import Dataset from nncf.data.dataset import ModelInput +from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic TensorType = TypeVar("TensorType") TModel = TypeVar("TModel") @@ -38,6 +44,8 @@ class StatisticsAggregator(ABC): Base class for statistics collection. """ + BACKEND: BackendType + def __init__(self, dataset: Dataset[DataItem, ModelInput]): self.dataset = dataset self.stat_subset_size = None @@ -88,6 +96,56 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None: f"smaller than the requested subset size {self.stat_subset_size}." ) + def load_statistics_from_dir(self, dir_path: str) -> None: + """ + Loads statistics from a directory and populates the statistic points with the loaded data. + + :param dir_path: The name of the directory from which to load the statistics. + """ + loaded_data, metadata = statistics_serializer.load_from_dir(dir_path) + statistics_validator.validate_backend(metadata, self.BACKEND) + self._load_statistics(loaded_data) + nncf_logger.info(f"Statistics were successfully loaded from a directory {dir_path}.") + + def _load_statistics(self, data: Dict[str, Any]) -> None: + """ + Loads statistics into the registered statistic points from the given data. + + :param data: A dictionary containing the statistics loaded from a file. + """ + for _, statistic_point, tensor_collector in self.statistic_points.get_tensor_collectors(): + statistics = tensor_collector.get_statistics() + statistics_key = self._get_statistics_key(statistics, statistic_point.target_point) + if statistics_key not in data: + raise nncf.ValidationError(f"Not found statistics for {statistics_key}") + statistics_container = tensor_collector.create_statistics_container(data[statistics_key]) + tensor_collector.set_cache(statistics_container) + + def dump_statistics(self, dir_path: str) -> None: + """ + Dumps the current statistics to a directory in a compressed format. + + :param dir_path: The path of the directory where the statistics will be saved. + """ + data_to_dump = self._prepare_statistics() + metadata = {"backend": self.BACKEND.value, "subset_size": self.stat_subset_size} + statistics_serializer.dump_to_dir(data_to_dump, dir_path, metadata) + nncf_logger.info(f"Statistics were successfully saved to a directory {dir_path}.") + + def _prepare_statistics(self) -> Dict[str, Any]: + """ + Prepares the statistics data for dumping into a directory. + + :return: A dictionary containing the statistics data to be dumped. + """ + data_to_dump = {} + for _, statistic_point, tensor_collector in self.statistic_points.get_tensor_collectors(): + statistics = tensor_collector.get_statistics() + statistics_key = self._get_statistics_key(statistics, statistic_point.target_point) + data = statistics.get_data() + data_to_dump[statistics_key] = data + return data_to_dump + def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None: """ Register statistic points for statistics collection and recalculates the maximum number samples @@ -154,3 +212,13 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]: :param outputs: raw model outputs :return: processed model outputs in Dict[str, Tensor] format """ + + @abstractmethod + def _get_statistics_key(self, statistics: TensorStatistic, target_point: TargetPoint) -> str: + """ + Returns key of statistics. + + :param statistics: Statistics value. + :param target_point: Statistics target point. + :return: Statistics key. + """ diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py index 2dd681b8e36..38a3d1c3f89 100644 --- a/nncf/common/tensor_statistics/collectors.py +++ b/nncf/common/tensor_statistics/collectors.py @@ -59,14 +59,14 @@ def register_input(self, x: TensorType) -> TensorType: def _register_input(self, x: TensorType) -> None: pass - def get_statistics(self) -> None: + def get_statistics(self) -> Any: """Returns collected statistics, if present.""" if self._collected_samples == 0: raise StatisticsNotCollectedError() return self._get_statistics() @abstractmethod - def _get_statistics(self) -> None: + def _get_statistics(self) -> Any: pass def enable(self) -> None: diff --git a/nncf/common/tensor_statistics/statistic_point.py b/nncf/common/tensor_statistics/statistic_point.py index b728f1c1353..2e41d96261d 100644 --- a/nncf/common/tensor_statistics/statistic_point.py +++ b/nncf/common/tensor_statistics/statistic_point.py @@ -13,8 +13,7 @@ from typing import Any, Callable, Generator, Optional, Tuple, cast from nncf.common.graph.transformations.commands import TargetPoint -from nncf.common.tensor import NNCFTensor -from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector class StatisticPoint: @@ -25,7 +24,7 @@ class StatisticPoint: algorithm implies on what algorithm nedeed this statistics. """ - def __init__(self, target_point: TargetPoint, tensor_collector: TensorStatisticCollectorBase, algorithm: str): + def __init__(self, target_point: TargetPoint, tensor_collector: TensorCollector, algorithm: str): self.target_point = target_point self.algorithm_to_tensor_collectors = {algorithm: [tensor_collector]} @@ -36,11 +35,6 @@ def __eq__(self, other: Any) -> bool: and self.algorithm_to_tensor_collectors == other.self.algorithm_to_tensor_collectors, ) - def register_tensor(self, x: NNCFTensor) -> None: - for tensor_collectors in self.algorithm_to_tensor_collectors.values(): - for tensor_collector in tensor_collectors: - tensor_collector.register_input(x) - class StatisticPointsContainer(UserDict): # type: ignore """ @@ -88,7 +82,7 @@ def iter_through_statistic_points_in_target_node( def get_tensor_collectors( self, filter_fn: Optional[Callable[[StatisticPoint], bool]] = None - ) -> Generator[Tuple[str, StatisticPoint, TensorStatisticCollectorBase], None, None]: + ) -> Generator[Tuple[str, StatisticPoint, TensorCollector], None, None]: """ Returns iterable through all tensor collectors. @@ -115,7 +109,7 @@ def get_algo_statistics_for_node( target_node_name: str, filter_fn: Callable[[StatisticPoint], bool], algorithm: str, - ) -> Generator[TensorStatisticCollectorBase, None, None]: + ) -> Generator[TensorCollector, None, None]: """ Returns iterable through all statistic collectors in node with target_node_name. diff --git a/nncf/common/tensor_statistics/statistics.py b/nncf/common/tensor_statistics/statistics.py index dbb4cb7861a..68204a2ad94 100644 --- a/nncf/common/tensor_statistics/statistics.py +++ b/nncf/common/tensor_statistics/statistics.py @@ -12,7 +12,7 @@ from abc import ABC from abc import abstractmethod from collections import Counter -from typing import Any, Dict, List, Tuple, TypeVar, cast +from typing import Any, Dict, TypeVar, cast from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -120,26 +120,3 @@ def __eq__(self, other: Any) -> bool: if not isinstance(other, RawTensorStatistic): return False return self.tensor_eq(self.values, other.values) - - -class WCTensorStatistic(TensorStatistic): - MEAN_STAT = "mean_values" - SHAPE_STAT = "shape_values" - - def __init__(self, mean_values: List[Tensor], shapes: List[Tuple[int, ...]]): - """ - :param mean_values: List of N tensors of shape [HiddenDim] obtained by reducing activations along batch and - sequence length dimensions. - :param shapes: List of N tuples containing original shapes of activations before reduction. - """ - self.mean_values = mean_values - self.shapes = shapes - - def __eq__(self, other: Any) -> bool: - shapes_equal = all(self.shapes[i] == other.shapes[i] for i in range(len(self.mean_values))) - if not shapes_equal: - return False - mean_values_equal = all( - self.tensor_eq(self.mean_values[i], other.mean_values[i]) for i in range(len(self.mean_values)) - ) - return mean_values_equal diff --git a/nncf/common/tensor_statistics/statistics_serializer.py b/nncf/common/tensor_statistics/statistics_serializer.py new file mode 100644 index 00000000000..03a1d72bc3c --- /dev/null +++ b/nncf/common/tensor_statistics/statistics_serializer.py @@ -0,0 +1,115 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import gzip +import json +import pickle +import re +from pathlib import Path +from typing import Any, Dict, Optional, Tuple, cast + +import nncf + +METADATA_FILE = "statistics_metadata.json" + + +def sanitize_filename(filename: str) -> str: + """ + Replaces any forbidden characters with an underscore. + """ + return re.sub(r'[\/:*?"<>|]', "_", filename) + + +def load_metadata(dir_path: Path) -> Dict[str, Any]: + """ + Loads the metadata, including the mapping and any other metadata information from the metadata file. + :param dir_path: The directory where the metadata file is stored. + :return: A dictionary containing the mapping and metadata. + """ + metadata_file = dir_path / METADATA_FILE + if metadata_file.exists(): + with open(metadata_file, "r") as f: + return cast(Dict[str, Any], json.load(f)) + return {"mapping": {}, "metadata": {}} + + +def save_metadata(metadata: Dict[str, Any], dir_path: Path) -> None: + """ + Saves the mapping and metadata to the metadata file. + :param metadata: The dictionary containing both the mapping and other metadata. + :param dir_path: The directory where the metadata file will be stored. + """ + metadata_file = dir_path / METADATA_FILE + with open(metadata_file, "w") as f: + json.dump(metadata, f, indent=4) + + +def load_from_dir(dir_path: str) -> Tuple[Dict[str, Any], Dict[str, str]]: + """ + Loads statistics from gzip-compressed files in the given directory. + :param dir_path: The path to the directory from which to load the statistics. + :return: 1) A dictionary with the original statistic names as keys and the loaded statistics as values. + 2) Metadata dictionary. + """ + statistics = {} + path = Path(dir_path) + if not path.exists(): + raise nncf.ValidationError("The provided directory path does not exist.") + metadata = load_metadata(path) + mapping = metadata.get("mapping", {}) + + for statistics_file in path.iterdir(): + if statistics_file.name == METADATA_FILE: + continue # Skip the metadata file + + try: + with gzip.open(statistics_file, "rb") as f: + sanitized_name = statistics_file.name + original_name = mapping.get(sanitized_name, sanitized_name) + statistics[original_name] = pickle.load(f) + except (pickle.UnpicklingError, IOError) as e: + raise nncf.InternalError(f"Error loading statistics from {statistics_file.name}: {e}") + return statistics, metadata.get("metadata", {}) + + +def dump_to_dir( + statistics: Dict[str, Any], dir_path: str, additional_metadata: Optional[Dict[str, Any]] = None +) -> None: + """ + Dumps statistics to gzip-compressed files in the specified directory, while maintaining a mapping file. + :param statistics: A dictionary with statistic names as keys and the statistic data as values. + :param dir_path: The path to the directory where the statistics will be dumped. + :param additional_metadata: A dictionary containing any additional metadata to be saved with the mapping. + """ + path = Path(dir_path) + path.mkdir(parents=True, exist_ok=True) + + metadata, mapping = {}, {} + + for original_name, statistics_value in statistics.items(): + sanitized_name = sanitize_filename(original_name) + file_path = path / sanitized_name + + # Update the mapping + mapping[sanitized_name] = original_name + + try: + with gzip.open(file_path, "wb") as f: + pickle.dump(statistics_value, f) + except (IOError, pickle.PicklingError) as e: + raise nncf.InternalError(f"Failed to write data to file {file_path}: {e}") + + # Add additional metadata if provided + if additional_metadata: + metadata["metadata"] = additional_metadata + + # Update the mapping in the metadata file + metadata["mapping"] = mapping + save_metadata(metadata, path) diff --git a/nncf/common/tensor_statistics/statistics_validator.py b/nncf/common/tensor_statistics/statistics_validator.py new file mode 100644 index 00000000000..fe926aaa3f2 --- /dev/null +++ b/nncf/common/tensor_statistics/statistics_validator.py @@ -0,0 +1,30 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict + +import nncf +from nncf.common.utils.backend import BackendType + + +def validate_backend(data: Dict[str, Any], backend: BackendType) -> None: + """ + Checks whether backend in loaded data is equal to a provided backend. + + :param data: Loaded statistics. + :param backend: Provided backend. + """ + if "backend" not in data: + raise nncf.ValidationError("The provided metadata has no information about backend.") + data_backend = data["backend"] + if data_backend != backend.value: + raise nncf.ValidationError( + f"Backend in loaded statistics {data_backend} does not match to an expected backend {backend.value}." + ) diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 89d85c8d2d4..1ffd53dceab 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -13,19 +13,15 @@ from abc import abstractmethod from collections import defaultdict from collections import deque +from copy import deepcopy from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union import nncf import nncf.tensor.functions as fns from nncf.common.tensor import TensorType from nncf.common.tensor_statistics.collectors import ReductionAxes -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.experimental.common.tensor_statistics.statistical_functions import mean_per_channel -from nncf.experimental.common.tensor_statistics.statistics import MeanTensorStatistic from nncf.experimental.common.tensor_statistics.statistics import MedianMADTensorStatistic -from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.experimental.common.tensor_statistics.statistics import PercentileTensorStatistic -from nncf.experimental.common.tensor_statistics.statistics import RawTensorStatistic from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic from nncf.quantization.advanced_parameters import AggregatorType from nncf.tensor import Tensor @@ -199,7 +195,8 @@ def __init__(self, statistic_container: Optional[Type[TensorStatistic]] = None) self._aggregators: Dict[Tuple[int, int, int], AggregatorBase] = {} self._stat_container_kwargs_map: Dict[str, Tuple[int, int, int]] = {} self._stat_container = statistic_container - self._enabled = True + self.enable() + self.clear_cache() @property def num_samples(self) -> Optional[int]: @@ -268,9 +265,8 @@ def register_inputs(self, inputs: Dict[int, List[Tensor]]) -> None: :param inputs: Tensor inputs in format of dict where keys are reducer names and values are correspondent input tensors """ - if not self._enabled: + if not self.enabled: return - reduced_inputs = {} for reducer in self._reducers: reducer_hash = hash(reducer) @@ -304,22 +300,47 @@ def _aggregate(self) -> None: result[key] = val return result - def get_statistics(self) -> Union[TensorStatistic, Dict[str, Any]]: + def set_cache(self, statistics: TensorStatistic) -> None: + """ + Sets cached statistics from given config and disable TensorCollector. + :param statistics: TensorStatistic. + """ + self._cached_statistics = statistics + self.reset() + self.disable() + + def create_statistics_container(self, config: Dict[str, Any]) -> TensorStatistic: + """ + Returns a TensorStatistic instance with aggregated values. + + :param config: Aggregated values. + :return: TensorStatistic instance. + """ + if not self._stat_container: # TODO(kshpv): need to remove an ability to return a Dict. + return config + return self._stat_container.from_config(config) + + def clear_cache(self) -> None: + """ + Clears the cached statistics and enables TensorCollector. + """ + self._cached_statistics = None + + def get_statistics(self) -> TensorStatistic: """ Returns aggregated values in format of a TensorStatistic instance or a dict. - :returns: Aggregated values. + :return: Aggregated values. """ + if self._cached_statistics is not None: + return deepcopy(self._cached_statistics) aggregated_values = self._aggregate() - kwargs = {} + statistics_config = {} for container_key, branch_key in self._stat_container_kwargs_map.items(): - kwargs[container_key] = aggregated_values[branch_key] - - if not self._stat_container: - return kwargs - return self._build_statistic_container(self._stat_container, kwargs) + statistics_config[container_key] = aggregated_values[branch_key] + return self.create_statistics_container(statistics_config) def replace_aggregator(self, key: Tuple[int, int, int], aggregator: AggregatorBase) -> None: """ @@ -356,43 +377,6 @@ def get_tensor_collector_inputs( target_inputs[reducer] = [outputs[name] for name in names] return target_inputs - @staticmethod - def _build_statistic_container(statistic_container_cls: Type[TensorStatistic], kwargs: Dict[Any, Any]): - if issubclass(statistic_container_cls, MinMaxTensorStatistic): - return statistic_container_cls( - min_values=kwargs[MinMaxTensorStatistic.MIN_STAT], max_values=kwargs[MinMaxTensorStatistic.MAX_STAT] - ) - if issubclass(statistic_container_cls, MeanTensorStatistic): - return statistic_container_cls( - mean_values=kwargs[MeanTensorStatistic.MEAN_STAT], shape=kwargs[MeanTensorStatistic.SHAPE_STAT] - ) - if issubclass(statistic_container_cls, RawTensorStatistic): - return statistic_container_cls(values=kwargs[RawTensorStatistic.VALUES_STATS]) - if issubclass(statistic_container_cls, MedianMADTensorStatistic): - return statistic_container_cls( - median_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][ - MedianMADTensorStatistic.MEDIAN_VALUES_STAT - ], - mad_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][ - MedianMADTensorStatistic.MAD_VALUES_STAT - ], - ) - if issubclass(statistic_container_cls, PercentileTensorStatistic): - if PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY in kwargs: - percentile_vs_values_dict = kwargs[PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY] - else: - percentile_vs_values_dict = {} - for (_, percentile), value in kwargs.items(): - percentile_vs_values_dict[percentile] = value - return statistic_container_cls(percentile_vs_values_dict=percentile_vs_values_dict) - if issubclass(statistic_container_cls, WCTensorStatistic): - mean_values = [fns.squeeze(it) for it in kwargs[WCTensorStatistic.MEAN_STAT]] - shapes = [tuple(it.data) for it in kwargs[WCTensorStatistic.SHAPE_STAT]] - return statistic_container_cls(mean_values=mean_values, shapes=shapes) - raise nncf.InternalError( - f"Statistic collector class {statistic_container_cls} is not supported by the TensorCollector class." - ) - class MergedTensorCollector(TensorCollector): """ diff --git a/nncf/experimental/common/tensor_statistics/statistics.py b/nncf/experimental/common/tensor_statistics/statistics.py index 5711ca900c7..9396674edb3 100644 --- a/nncf/experimental/common/tensor_statistics/statistics.py +++ b/nncf/experimental/common/tensor_statistics/statistics.py @@ -13,7 +13,7 @@ from collections import Counter from dataclasses import dataclass -from typing import ClassVar, Dict, Tuple +from typing import Any, ClassVar, Dict, List, Tuple from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -24,6 +24,22 @@ class TensorStatistic: TENSOR_STATISTIC_OUTPUT_KEY = "tensor_statistic_output" + def get_data(self) -> Dict[str, Any]: + return {key: getattr(self, key) for key in self.keys()} + + def load_data(self, data: Dict[str, Any]): + for key in self.keys(): + setattr(self, key, data.get(key)) + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> TensorStatistic: + args = {key: config[key] for key in cls.keys()} # noqa: SIM118 + return cls(**args) + + @classmethod + def keys(cls) -> Tuple[str]: + return () + @dataclass class MinMaxTensorStatistic(TensorStatistic): @@ -33,12 +49,32 @@ class MinMaxTensorStatistic(TensorStatistic): min_values: Tensor max_values: Tensor + @classmethod + def keys(cls): + return (cls.MIN_STAT, cls.MAX_STAT) + def __eq__(self, other: TensorStatistic): if isinstance(other, MinMaxTensorStatistic): return fns.allclose(self.min_values, other.min_values) and fns.allclose(self.max_values, other.max_values) return False +@dataclass +class AbsMaxTensorStatistic(TensorStatistic): + ABS_MAX_STAT: ClassVar[str] = "abs_max" + + abs_max: Tensor + + @classmethod + def keys(cls): + return (cls.ABS_MAX_STAT,) + + def __eq__(self, other: TensorStatistic): + if isinstance(other, AbsMaxTensorStatistic): + return fns.allclose(self.abs_max, other.abs_max) + return False + + @dataclass class MeanTensorStatistic(TensorStatistic): MEAN_STAT: ClassVar[str] = "mean_values" @@ -47,6 +83,10 @@ class MeanTensorStatistic(TensorStatistic): mean_values: Tensor shape: Tuple[int, ...] + @classmethod + def keys(cls): + return (cls.MEAN_STAT, cls.SHAPE_STAT) + def __eq__(self, other: TensorStatistic): if isinstance(other, MeanTensorStatistic): return self.shape == other.shape and fns.allclose(self.mean_values, other.mean_values) @@ -61,6 +101,10 @@ class MedianMADTensorStatistic(TensorStatistic): median_values: Tensor mad_values: Tensor + @classmethod + def keys(cls): + return (cls.MEDIAN_VALUES_STAT, cls.MAD_VALUES_STAT) + def __eq__(self, other: TensorStatistic): if isinstance(other, MedianMADTensorStatistic): return fns.allclose(self.median_values, other.median_values) and fns.allclose( @@ -68,6 +112,13 @@ def __eq__(self, other: TensorStatistic): ) return False + @classmethod + def from_config(cls, config: Dict[str, Any]) -> TensorStatistic: + return cls( + median_values=config[cls.TENSOR_STATISTIC_OUTPUT_KEY][cls.MEDIAN_VALUES_STAT], + mad_values=config[cls.TENSOR_STATISTIC_OUTPUT_KEY][cls.MAD_VALUES_STAT], + ) + @dataclass class PercentileTensorStatistic(TensorStatistic): @@ -75,6 +126,10 @@ class PercentileTensorStatistic(TensorStatistic): percentile_vs_values_dict: Dict[str, Tensor] + @classmethod + def keys(cls): + return (cls.PERCENTILE_VS_VALUE_DICT,) + def __eq__(self, other: TensorStatistic): if isinstance(other, PercentileTensorStatistic): if Counter(self.percentile_vs_values_dict.keys()) != Counter(other.percentile_vs_values_dict.keys()): @@ -85,6 +140,16 @@ def __eq__(self, other: TensorStatistic): return True return False + @classmethod + def from_config(cls, config: Dict[str, Any]) -> TensorStatistic: + if cls.TENSOR_STATISTIC_OUTPUT_KEY in config: + percentile_vs_values_dict = config[cls.TENSOR_STATISTIC_OUTPUT_KEY] + else: + percentile_vs_values_dict = {} + for (_, percentile), value in config.items(): + percentile_vs_values_dict[percentile] = value + return cls(percentile_vs_values_dict=percentile_vs_values_dict) + @dataclass class RawTensorStatistic(TensorStatistic): @@ -92,7 +157,106 @@ class RawTensorStatistic(TensorStatistic): values: Tensor + @classmethod + def keys(cls): + return (cls.VALUES_STATS,) + def __eq__(self, other: RawTensorStatistic) -> bool: - if isinstance(other, PercentileTensorStatistic): + if isinstance(other, RawTensorStatistic): return fns.allclose(self.values, other.values) return False + + +@dataclass +class HessianTensorStatistic(TensorStatistic): + HESSIAN_INPUT_ACTIVATION_STATS: ClassVar[str] = "hessian" + + hessian: Tensor + + @classmethod + def keys(cls): + return (cls.HESSIAN_INPUT_ACTIVATION_STATS,) + + def __eq__(self, other: TensorStatistic): + if isinstance(other, HessianTensorStatistic): + return fns.allclose(self.hessian, other.hessian) + return False + + +@dataclass +class MeanVarianceTensorStatistic(TensorStatistic): + MEAN_VARIANCE_STAT: ClassVar[str] = "mean_variance" + + mean_variance: Tensor + + @classmethod + def keys(cls): + return (cls.MEAN_VARIANCE_STAT,) + + def __eq__(self, other: TensorStatistic): + if isinstance(other, MeanVarianceTensorStatistic): + return fns.allclose(self.mean_variance, other.mean_variance) + return False + + +@dataclass +class MaxVarianceTensorStatistic(TensorStatistic): + MAX_VARIANCE_STAT: ClassVar[str] = "max_variance" + + max_variance: Tensor + + @classmethod + def keys(cls): + return (cls.MAX_VARIANCE_STAT,) + + def __eq__(self, other: TensorStatistic): + if isinstance(other, MaxVarianceTensorStatistic): + return fns.allclose(self.max_variance, other.max_variance) + return False + + +@dataclass +class MeanMagnitudeTensorStatistic(TensorStatistic): + MEAN_MAGNITUDE_STAT: ClassVar[str] = "mean_magnitude" + + mean_magnitude: Tensor + + @classmethod + def keys(cls): + return (cls.MEAN_MAGNITUDE_STAT,) + + def __eq__(self, other: TensorStatistic): + if isinstance(other, MeanMagnitudeTensorStatistic): + return fns.allclose(self.mean_magnitude, other.mean_magnitude) + return False + + +@dataclass +class WCTensorStatistic(TensorStatistic): + MEAN_STAT = "mean_values" + SHAPE_STAT = "shape_values" + + mean_values: List[Tensor] + shape_values: List[Tuple[int, ...]] + + @classmethod + def keys(cls): + return (cls.MEAN_STAT, cls.SHAPE_STAT) + + def __eq__(self, other: Any) -> bool: + shapes_equal = all(self.shapes[i] == other.shapes[i] for i in range(len(self.mean_values))) + if not shapes_equal: + return False + mean_values_equal = all( + self.tensor_eq(self.mean_values[i], other.mean_values[i]) for i in range(len(self.mean_values)) + ) + return mean_values_equal + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> TensorStatistic: + mean_values, shape_values = None, None + if cls.MEAN_STAT in config and config[cls.MEAN_STAT] is not None: + mean_values = [fns.squeeze(it) for it in config[cls.MEAN_STAT]] + if cls.SHAPE_STAT in config and config[cls.SHAPE_STAT] is not None: + shape_values = [tuple(it) for it in config[cls.SHAPE_STAT]] + return cls(mean_values=mean_values, shape_values=shape_values) diff --git a/nncf/experimental/torch/fx/statistics/aggregator.py b/nncf/experimental/torch/fx/statistics/aggregator.py index 9f109147d83..05af7f19bf5 100644 --- a/nncf/experimental/torch/fx/statistics/aggregator.py +++ b/nncf/experimental/torch/fx/statistics/aggregator.py @@ -20,7 +20,9 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer from nncf.common.tensor_statistics.aggregator import StatisticsAggregator +from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder from nncf.tensor import Tensor @@ -51,6 +53,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FXStatisticsAggregator(StatisticsAggregator): + BACKEND: BackendType = BackendType.TORCH_FX HOOKS_GROUP_NAME = "statistics_hooks" def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: @@ -121,3 +124,14 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]: return outputs + + def _get_statistics_key(self, statistics: TensorStatistic, target_point: PTTargetPoint) -> str: + """ + Returns key of statistics. + + :param statistics: Statistics value. + :param target_point: Statistics target point. + :return: Statistics key. + """ + target_point_id = f"{target_point.target_node_name}_{target_point.type}_{target_point.input_port_id}" + return f"{statistics.__class__.__name__}_{target_point_id}" diff --git a/nncf/onnx/statistics/aggregator.py b/nncf/onnx/statistics/aggregator.py index 0ae00618eb1..1240f5ee1d0 100644 --- a/nncf/onnx/statistics/aggregator.py +++ b/nncf/onnx/statistics/aggregator.py @@ -20,15 +20,20 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor_statistics.aggregator import StatisticsAggregator from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic from nncf.onnx.graph.node_utils import get_input_edge from nncf.onnx.graph.node_utils import get_input_edges_mapping from nncf.onnx.graph.onnx_helper import get_name_to_node_map from nncf.onnx.graph.transformations.commands import ONNXOutputInsertionCommand +from nncf.onnx.graph.transformations.commands import ONNXTargetPoint from nncf.tensor import Tensor class ONNXStatisticsAggregator(StatisticsAggregator): + BACKEND: BackendType = BackendType.ONNX + def collect_statistics(self, model: onnx.ModelProto, graph: NNCFGraph) -> None: self.input_edges_mapping = get_input_edges_mapping(graph) self.node_mapping = get_name_to_node_map(model) @@ -87,3 +92,14 @@ def _get_merged_statistic_points( @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]: return {n: Tensor(v) for n, v in outputs.items()} + + def _get_statistics_key(self, statistics: TensorStatistic, target_point: ONNXTargetPoint) -> str: + """ + Returns key of statistics. + + :param statistics: Statistics value. + :param target_point: Statistics target point. + :return: Statistics key. + """ + target_point_id = f"{target_point.target_node_name}_{target_point.type}_{target_point.port_id}" + return f"{statistics.__class__.__name__}_{target_point_id}" diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index 37666a0980b..cbaf9ffb62d 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -10,12 +10,14 @@ # limitations under the License. from copy import deepcopy +from pathlib import Path from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union import openvino.runtime as ov from openvino._offline_transformations import compress_quantize_weights_transformation from nncf.common.factory import NNCFGraphFactory +from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.logging import nncf_logger from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset @@ -49,6 +51,8 @@ from nncf.quantization.quantize_model import is_model_no_batchwise_support from nncf.quantization.quantize_model import quantize_with_tune_hyperparams from nncf.quantization.quantize_model import warning_model_no_batchwise_support +from nncf.quantization.statistics_caching import cache_weight_compression_statistics +from nncf.quantization.statistics_caching import register_statistics_for_algorithm from nncf.scopes import IgnoredScope from nncf.scopes import validate_ignored_scope @@ -378,8 +382,8 @@ def compress_weights_impl( """ Implementation of the `compress_weights()` method for the OpenVINO backend. """ - model = remove_friendly_name_duplicates(model) + graph = NNCFGraphFactory.create(model) compression_algorithm = WeightCompression( mode, ratio, @@ -395,5 +399,24 @@ def compress_weights_impl( backup_mode, advanced_parameters, ) - graph = NNCFGraphFactory.create(model) - return compression_algorithm.apply(model, graph, dataset=dataset) + + statistics_points = None + if advanced_parameters and advanced_parameters.statistics_path: + # If there is no such directory, then caches statistics + if not Path(advanced_parameters.statistics_path).exists(): + cache_weight_compression_statistics(model, graph, dataset, subset_size, advanced_parameters.statistics_path) + statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) + compression_algorithm.set_backend_entity(model) + _, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(graph) + register_statistics_for_algorithm( + statistics_aggregator, + model, + graph, + subset_size, + compression_algorithm, + matmul_input_to_output_nodes_map, + ) + statistics_aggregator.load_statistics_from_dir(advanced_parameters.statistics_path) + statistics_points = statistics_aggregator.statistic_points + + return compression_algorithm.apply(model, graph, statistics_points, dataset) diff --git a/nncf/openvino/statistics/aggregator.py b/nncf/openvino/statistics/aggregator.py index 86a27c71e08..0f0dbd37ecd 100644 --- a/nncf/openvino/statistics/aggregator.py +++ b/nncf/openvino/statistics/aggregator.py @@ -21,16 +21,21 @@ from nncf.common.tensor_statistics.aggregator import StatisticsAggregator from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic from nncf.openvino.graph.node_utils import get_ov_model_reduce_node_name from nncf.openvino.graph.node_utils import get_reducer_output_node_names from nncf.openvino.graph.transformations.commands import OVInplaceFnInsertionCommand from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand +from nncf.openvino.graph.transformations.commands import OVTargetPoint from nncf.tensor import Tensor class OVStatisticsAggregator(StatisticsAggregator): + BACKEND: BackendType = BackendType.OPENVINO + def collect_statistics(self, model: ov.Model, graph: NNCFGraph) -> None: self._name_to_node_mapping = {op.get_friendly_name(): op for op in model.get_ops()} super().collect_statistics(model, graph) @@ -125,3 +130,14 @@ def _translate_to_post_layer_operation(self, statistic_point: StatisticPoint): @staticmethod def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]: return {n: Tensor(v) for n, v in outputs.items()} + + def _get_statistics_key(self, statistics: TensorStatistic, target_point: OVTargetPoint) -> str: + """ + Returns key of statistics. + + :param statistics: Statistics value. + :param target_point: Statistics target point. + :return: Statistics key. + """ + target_point_id = f"{target_point.target_node_name}_{target_point.type}_{target_point.port_id}" + return f"{statistics.__class__.__name__}_{target_point_id}" diff --git a/nncf/quantization/advanced_parameters.py b/nncf/quantization/advanced_parameters.py index f8224e15e15..cad3f1ba969 100644 --- a/nncf/quantization/advanced_parameters.py +++ b/nncf/quantization/advanced_parameters.py @@ -361,12 +361,15 @@ class AdvancedCompressionParameters: """ Contains advanced parameters for compression algorithms. + :param statistics_path: Directory path to dump statistics. + :type statistics_path: str :param awq_params: Advanced parameters for AWQ algorithm. :type awq_params: AdvancedAWQParameters :param scale_estimation_params: Advanced parameters for scale estimation algorithm. :type scale_estimation_params: AdvancedScaleEstimationParameters """ + statistics_path: Optional[str] = None # Advanced AWQ algorithm parameters awq_params: AdvancedAWQParameters = field(default_factory=AdvancedAWQParameters) diff --git a/nncf/quantization/algorithms/weight_compression/activation_stats.py b/nncf/quantization/algorithms/weight_compression/activation_stats.py index d5513eca525..5d9d528e57b 100644 --- a/nncf/quantization/algorithms/weight_compression/activation_stats.py +++ b/nncf/quantization/algorithms/weight_compression/activation_stats.py @@ -13,7 +13,7 @@ from operator import mul from typing import Tuple -from nncf.common.tensor_statistics.statistics import WCTensorStatistic +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.tensor import Tensor from nncf.tensor import functions as fns @@ -34,7 +34,7 @@ def process_stats(stats: WCTensorStatistic, subset_size: int) -> Tuple[Tensor, T # prevent high memory and time consumption if X_full.shape[1] > subset_size: # activations were reduced across all but the last dimension - lens = [reduce(mul, shape[:-1], 1) for shape in stats.shapes] + lens = [reduce(mul, shape[:-1], 1) for shape in stats.shape_values] step = X_full.shape[1] // subset_size idxs = [i[0] for i in sorted(enumerate(lens), key=lambda x: -x[1])][::step] X = X_full[:, idxs] # [HiddenDim, ~SubsetSize] diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 4dcc4822c7c..81bb4406f0a 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -14,7 +14,7 @@ from collections import defaultdict from functools import partial from functools import reduce -from typing import Dict, Iterable, List, Optional, Tuple, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar import nncf from nncf import Dataset @@ -27,10 +27,10 @@ from nncf.common.scopes import should_consider_scope from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend from nncf.common.utils.helpers import create_table +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.parameters import BackupMode from nncf.parameters import CompressWeightsMode from nncf.parameters import SensitivityMetric @@ -50,6 +50,127 @@ TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") +INT8_MODES = [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM] +NON_INT8_MODES = [ + CompressWeightsMode.INT4_SYM, + CompressWeightsMode.INT4_ASYM, + CompressWeightsMode.NF4, + CompressWeightsMode.E2M1, +] + + +def get_weight_compression_configuration( + mode: CompressWeightsMode = CompressWeightsMode.INT8_ASYM, + dataset: Optional[Dataset] = None, + ratio: Optional[float] = None, + group_size: Optional[int] = None, + all_layers: Optional[bool] = None, + awq: Optional[bool] = None, + scale_estimation: Optional[bool] = None, + gptq: Optional[bool] = None, + lora_correction: Optional[bool] = None, + ignored_scope: Optional[IgnoredScope] = None, + sensitivity_metric: Optional[SensitivityMetric] = None, + backup_mode: Optional[BackupMode] = None, + advanced_parameters: Optional[AdvancedCompressionParameters] = None, +) -> Dict[str, Any]: + """ + Generates a configuration dictionary for weight compression based on the provided parameters. + """ + group_size = ( + -1 + if group_size is None and mode in INT8_MODES + else 128 if group_size is None and mode in NON_INT8_MODES else group_size + ) + + return { + "mode": mode, + "ratio": ratio or 1, + "group_size": group_size, + "all_layers": all_layers or False, + "awq": awq or False, + "scale_estimation": scale_estimation or False, + "gptq": gptq or False, + "lora_correction": lora_correction or False, + "ignored_scope": ignored_scope or IgnoredScope(), + "sensitivity_metric": ( + ( + SensitivityMetric.WEIGHT_QUANTIZATION_ERROR + if dataset is None + else SensitivityMetric.MAX_ACTIVATION_VARIANCE + ) + if sensitivity_metric is None + else sensitivity_metric + ), + "backup_mode": backup_mode or BackupMode.INT8_ASYM, + "advanced_parameters": advanced_parameters or AdvancedCompressionParameters(), + } + + +def check_user_compression_configuration( + mode: CompressWeightsMode, + subset_size: int, + dataset: Optional[Dataset], + ratio: Optional[float], + group_size: Optional[int], + all_layers: Optional[bool], + awq: Optional[bool], + scale_estimation: Optional[bool], + gptq: Optional[bool], + lora_correction: Optional[bool], + ignored_scope: Optional[IgnoredScope], + sensitivity_metric: Optional[SensitivityMetric], + backup_mode: Optional[BackupMode], + advanced_parameters: Optional[AdvancedCompressionParameters], +) -> None: + """ + Validates the user's weight compression configuration for correctness. + """ + if mode in INT8_MODES: + if (ratio and ratio != 1) or (group_size and group_size != -1): + raise nncf.ParameterNotSupportedError( + "INT8 modes require per-channel quantization of all layers in 8 bit. " + "Default values of `ratio` (1) and `group_size` (-1) cannot be overridden." + ) + + if advanced_parameters and advanced_parameters.statistics_path: + raise nncf.ParameterNotSupportedError( + "INT8 modes do not support the `statistics_path` option in `AdvancedCompressionParameters`." + ) + + unsupported_options = { + "all_layers": all_layers, + "sensitivity_metric": sensitivity_metric, + "dataset": dataset, + "awq": awq, + "scale_estimation": scale_estimation, + "gptq": gptq, + "lora_correction": lora_correction, + "backup_mode": backup_mode, + } + unsupported_for_int8 = [name for name, value in unsupported_options.items() if value is not None] + if unsupported_for_int8: + raise nncf.ParameterNotSupportedError( + f"INT8 modes do not support {', '.join(unsupported_for_int8)} option(s). Set them to None." + ) + + if ratio is not None and not (0 <= ratio <= 1): + raise nncf.ValidationError(f"The ratio should be between 0 and 1, but ratio={ratio} is specified.") + + if subset_size <= 0: + raise nncf.ValidationError(f"The subset_size value should be positive, but subset_size={subset_size} is given.") + + if ( + ratio + and dataset is None + and sensitivity_metric is not None + and sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR + ): + raise nncf.ValidationError( + f"Mixed precision selection with sensitivity metric={sensitivity_metric.value} \ + requires a dataset, but it's not provided." + ) + class WeightCompression(Algorithm): """ @@ -136,8 +257,7 @@ def __init__( primary_config = WeightCompressionConfig(mode=self._mode, group_size=self._group_size) criterion_cls = MIXED_PRECISION_CRITERIA.get(self._sensitivity_metric) self._mixed_precision_algo = criterion_cls(primary_config, self._ratio) - self._mixed_precision_statistics = None - + self._statistics_path = self._advanced_parameters.statistics_path if self._gptq: gptq_params = self._advanced_parameters.gptq_params self._gptq_algo = GPTQ( @@ -147,11 +267,16 @@ def __init__( scale_estimation=self._scale_estimation, ) + self._data_aware_mixed_precision = ( + self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0 + ) + self._data_aware_compression = self._awq or self._scale_estimation or self._lora_correction or self._gptq + @property def available_backends(self) -> List[BackendType]: return [BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] - def _set_backend_entity(self, model: TModel) -> None: + def set_backend_entity(self, model: TModel) -> None: """ Creates a helper class with a backed-specific logic of the algorithm. @@ -175,7 +300,7 @@ def _set_backend_entity(self, model: TModel) -> None: "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) ) - def _get_nodes_to_compress(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: + def get_nodes_to_compress(self, nncf_graph: NNCFGraph) -> List[NNCFNode]: """ Collects nodes in the model's graph corresponding to the layers for weight compression. @@ -243,6 +368,7 @@ def _set_weight_compression_config( ratio_defining_params: List[WeightCompressionParameters], model: TModel, graph: NNCFGraph, + statistics_points: StatisticPointsContainer, ) -> None: """ Sets the appropriate compression configuration for weights based on some criteria. @@ -251,15 +377,14 @@ def _set_weight_compression_config( backup precisions. :param model: The model. :param graph: The model graph associated with the model. + :param statistics_points: Statistics points. """ primary_config = WeightCompressionConfig(mode=self._mode, group_size=self._group_size) if self._ratio == 1: for weight_param in ratio_defining_params: weight_param.compression_config = primary_config else: - self._mixed_precision_algo.apply( - model, graph, self._mixed_precision_statistics, weight_params=ratio_defining_params - ) + self._mixed_precision_algo.apply(model, graph, statistics_points, weight_params=ratio_defining_params) @staticmethod def _proportion_str(num_weights_list: List[int], total_num_weights: int, total_num_params: int) -> str: @@ -360,16 +485,25 @@ def apply( statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, ) -> TModel: - self._set_backend_entity(model) - nodes_to_compress = self._get_nodes_to_compress(graph) + self.set_backend_entity(model) + + nodes_to_compress = self.get_nodes_to_compress(graph) statistics = None - data_aware_mixed_precision = ( - self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0 - ) - data_aware_compression = self._awq or self._scale_estimation or self._lora_correction or self._gptq - if data_aware_mixed_precision or data_aware_compression: - statistics = self._collect_statistics(dataset, nodes_to_compress, graph, model) + if self._data_aware_mixed_precision or self._data_aware_compression: + matmul_nodes_to_compress = [ + node for node in nodes_to_compress if node.metatype in self._backend_entity.matmul_metatypes + ] + matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map( + matmul_nodes_to_compress, graph + ) + if statistic_points is None: + statistic_points = self.get_statistic_points(model, graph, matmul_input_to_output_nodes_map.keys()) + statistic_points = self._collect_statistics(dataset, graph, model, statistic_points) + statistics = self._get_statistics_for_weights_compression( + matmul_input_to_output_nodes_map, statistic_points + ) + all_weight_params: List[WeightCompressionParameters] = [] weight_names = set() @@ -421,7 +555,7 @@ def apply( weight_names.add(weight_name) ratio_defining_params = self._get_ratio_defining_params(all_weight_params, is_last_layer_shared) - self._set_weight_compression_config(ratio_defining_params, model, graph) + self._set_weight_compression_config(ratio_defining_params, model, graph, statistic_points) ignored_scope_weight_statistics = self._get_ignored_scope_weight_statistics(model, graph) nncf_logger.info( self._get_bitwidth_distribution_str( @@ -549,58 +683,76 @@ def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) - port_id = activation_edge.output_port_id return activation_node, port_id - def _collect_statistics(self, dataset: Dataset, nodes: List[NNCFNode], graph: NNCFGraph, model: TModel): + def get_matmul_input_to_output_nodes_map( + self, matmul_nodes: List[NNCFNode], graph: NNCFGraph + ) -> Dict[Tuple[NNCFNode, int], List[NNCFNode]]: + """ + Maps activation nodes to their corresponding MatMul nodes in the graph. + + Each weighted MatMul node takes two inputs: an activation and a weight. + An activation node may serve as an input to multiple MatMul nodes. + This function returns a mapping where each key is a tuple consisting of an + activation node and its output port ID, and the value is a list of MatMul + nodes that use this activation as input. + + :param matmul_nodes: A list of MatMul nodes from the computation graph. + :param graph: An instance of NNCFGraph representing the computation graph. + :return: A dictionary mapping from a tuple of (activation node, port ID) + to a list of corresponding MatMul nodes that accept the activation as input. """ - Collects statistics required for data-aware algorithms and/or mixed precision assignment. + matmul_input_to_output_nodes_map = defaultdict(list) + for node in matmul_nodes: + if node.layer_attributes.input_attributes["transpose"]: # It works only for OV + raise nncf.UnsupportedModelError("Transposed input is not supported") + act_node, output_port_id = self._get_activation_node_and_port(node, graph) + matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node) + return matmul_input_to_output_nodes_map + + def get_compression_nodes_info( + self, graph: NNCFGraph + ) -> Tuple[List[NNCFNode], Dict[Tuple[NNCFNode, int], List[NNCFNode]]]: + """ + Retrieves the nodes to compress along with a mapping of activation nodes + to their corresponding MatMul nodes. + + This function first identifies all nodes that can be compressed from the + provided graph. It then filters these nodes to find those that are of + MatMul type and generates a mapping of activation nodes to their + corresponding MatMul nodes using the + `get_matmul_input_to_output_nodes_map` function. + + :param graph: An instance of NNCFGraph representing the computation graph. + :return: A tuple containing: + - Nodes for compression. + - A dictionary mapping from a tuple of (activation node, port ID) + to a list of MatMul nodes that accept the activation as input. + """ + nodes_to_compress = self.get_nodes_to_compress(graph) + matmul_nodes_to_compress = [ + node for node in nodes_to_compress if node.metatype in self._backend_entity.matmul_metatypes + ] + matmul_input_to_output_nodes_map = self.get_matmul_input_to_output_nodes_map(matmul_nodes_to_compress, graph) + return nodes_to_compress, matmul_input_to_output_nodes_map + + def _collect_statistics( + self, + dataset: Dataset, + graph: NNCFGraph, + model: TModel, + statistic_points: StatisticPointsContainer, + ): + """ + Creates statistics aggregator, registers all statistics specified for algorithm, and then collect them. :param dataset: Dataset to collect values. - :param nodes: List of nodes, whose inputs are collected. :param graph: Model graph. :param model: Model for statistics collection. + :param statistic_points: Statistics points. """ - statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) - - statistic_points = None - matmul_input_to_output_nodes_map = None - - data_aware_precision_assignment = ( - self._sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR and self._ratio != 1.0 - ) - data_aware_compression = self._awq or self._scale_estimation or self._lora_correction - if data_aware_compression or data_aware_precision_assignment: - # Collect statistics only for weighted MatMul nodes - matmul_metatypes = self._backend_entity.matmul_metatypes - matmul_nodes = filter(lambda node: node.metatype in matmul_metatypes, nodes) - - # Each weighted MatMul node has two input nodes: an activation and a weight. - # A single activation may be an input to multiple MatMul nodes. - # Below is a mapping from activation node and a port id to corresponding matmul nodes which accept this - # activation as an input. - matmul_input_to_output_nodes_map = defaultdict(list) - for node in matmul_nodes: - if node.layer_attributes.input_attributes["transpose"]: - raise nncf.UnsupportedModelError("Transposed input is not supported") - act_node, output_port_id = self._get_activation_node_and_port(node, graph) - matmul_input_to_output_nodes_map[(act_node, output_port_id)].append(node) - - if data_aware_precision_assignment: - self._mixed_precision_statistics = self._mixed_precision_algo.get_statistic_points( - model, graph, matmul_input_to_output_nodes_map.keys(), self._subset_size - ) - statistics_aggregator.register_statistic_points(self._mixed_precision_statistics) - if data_aware_compression: - statistic_points = self.get_statistic_points( - model, graph, matmul_input_to_output_nodes_map.keys(), self._subset_size - ) - statistics_aggregator.register_statistic_points(statistic_points) - + statistics_aggregator.register_statistic_points(statistic_points) statistics_aggregator.collect_statistics(model, graph) - - statistics = None - if statistic_points is not None: - statistics = self._get_statistics(matmul_input_to_output_nodes_map, statistic_points) - return statistics + return statistics_aggregator.statistic_points def get_statistic_points( self, @@ -618,31 +770,42 @@ def get_statistic_points( :param subset_size: Number of samples to collect. :return: Statistic points, for which StatisticsCollector should collect statistics. """ - statistic_container = StatisticPointsContainer() - for node, output_port_id in nodes_and_port_ids: - statistic_point = self._backend_entity.target_point( - TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id - ) - # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden - # size dimension. - n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) - stat_collector = self._backend_entity.mean_statistic_collector( - reduction_axes=tuple(range(n_dims - 1)), subset_size=subset_size - ) - statistic_container.add_statistic_point( - StatisticPoint( - target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key + # Statistics for data aware algorithms + if self._data_aware_compression: + for node, output_port_id in nodes_and_port_ids: + statistic_point = self._backend_entity.target_point( + TargetType.POST_LAYER_OPERATION, node.node_name, port_id=output_port_id ) + # Reduce activations across all but the last dimension. The last dimension is assumed to be the hidden + # size dimension. + n_dims = len(graph.get_output_edges_by_port_id(node, output_port_id)[0].tensor_shape) + stat_collector = self._backend_entity.mean_statistic_collector( + reduction_axes=tuple(range(n_dims - 1)), subset_size=subset_size + ) + statistic_container.add_statistic_point( + StatisticPoint( + target_point=statistic_point, tensor_collector=stat_collector, algorithm=self._algorithm_key + ) + ) + # Statistics for mixed precision algorithm + if self._data_aware_mixed_precision: + mixed_precision_statistics = self._mixed_precision_algo.get_statistic_points( + model, graph, nodes_and_port_ids, self._subset_size ) + for points in mixed_precision_statistics.values(): + for point in points: + statistic_container.add_statistic_point(point) return statistic_container - def _get_statistics( - self, matmul_input_to_output_nodes_map: Dict[Tuple[NNCFNode, int], List[NNCFNode]], statistic_points + def _get_statistics_for_weights_compression( + self, + matmul_input_to_output_nodes_map: Dict[Tuple[NNCFNode, int], List[NNCFNode]], + statistic_points: StatisticPointsContainer, ) -> Dict[str, WCTensorStatistic]: """ - Retrieve collected statistics. + Retrieve collected statistics only for WeightCompression algorithm and not for MixedPrecision. :param matmul_input_to_output_nodes_map: A mapping from activation node and a port id to corresponding matmul nodes which accept this activation as an input. @@ -672,11 +835,13 @@ def input_filter_func(point, port_id): act_node.node_name, partial(input_filter_func, port_id=output_port_id), self._algorithm_key ) ) - assert len(tensor_collectors) == 1 - stats = tensor_collectors[0].get_statistics() - - # Each activation node may have multiple MatMul nodes which it is an input to - for node in matmul_nodes: - statistics[node.node_name] = copy.deepcopy(stats) - + # Statistics could be empty in case when the statistics is registered for another algorithm, + # e.g. mixed precision. + if tensor_collectors: + assert len(tensor_collectors) == 1 + stats = tensor_collectors[0].get_statistics() + + # Each activation node may have multiple MatMul nodes which it is an input to + for node in matmul_nodes: + statistics[node.node_name] = copy.deepcopy(stats) return statistics diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index ee5237eb8e2..6de2a3659cf 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -23,9 +23,9 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging.track_progress import track from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.algorithm import Algorithm from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats diff --git a/nncf/quantization/algorithms/weight_compression/backend.py b/nncf/quantization/algorithms/weight_compression/backend.py index 6cc40e564b2..357542be6aa 100644 --- a/nncf/quantization/algorithms/weight_compression/backend.py +++ b/nncf/quantization/algorithms/weight_compression/backend.py @@ -13,7 +13,6 @@ from abc import abstractmethod from typing import Dict, Iterable, List, Optional, Tuple, TypeVar -from nncf import SensitivityMetric from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype @@ -23,6 +22,7 @@ from nncf.experimental.common.tensor_statistics.collectors import HAWQAggregator from nncf.experimental.common.tensor_statistics.collectors import NoopReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import HessianTensorStatistic from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters from nncf.tensor import Tensor from nncf.tensor import TensorDataType @@ -254,8 +254,8 @@ class MixedPrecisionAlgoBackend(ABC): def hawq_statistic_collector(subset_size: Optional[int] = None) -> TensorCollector: reducer = NoopReducer() aggregator = HAWQAggregator(num_samples=subset_size) - collector = TensorCollector() - collector.register_statistic_branch(SensitivityMetric.HESSIAN_INPUT_ACTIVATION.value, reducer, aggregator) + collector = TensorCollector(HessianTensorStatistic) + collector.register_statistic_branch(HessianTensorStatistic.HESSIAN_INPUT_ACTIVATION_STATS, reducer, aggregator) return collector @staticmethod diff --git a/nncf/quantization/algorithms/weight_compression/lora_correction.py b/nncf/quantization/algorithms/weight_compression/lora_correction.py index 0c9bb3409ba..83c0acbd228 100644 --- a/nncf/quantization/algorithms/weight_compression/lora_correction.py +++ b/nncf/quantization/algorithms/weight_compression/lora_correction.py @@ -16,9 +16,9 @@ import nncf from nncf.common.logging import nncf_logger -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.common.utils.debug import DEBUG_LOG_DIR from nncf.common.utils.debug import is_debug +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.parameters import CompressWeightsMode from nncf.quantization.advanced_parameters import AdvancedLoraCorrectionParameters from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats diff --git a/nncf/quantization/algorithms/weight_compression/mixed_precision.py b/nncf/quantization/algorithms/weight_compression/mixed_precision.py index 86ff7145a78..93c9c8d8b6c 100644 --- a/nncf/quantization/algorithms/weight_compression/mixed_precision.py +++ b/nncf/quantization/algorithms/weight_compression/mixed_precision.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC from abc import abstractmethod from typing import Iterable, List, Optional, Tuple, TypeVar @@ -205,7 +206,7 @@ def get_statistic_points( raise RuntimeError("No statistics collection intended for data-free mixed precision criterion") -class DataBasedCriterion(DataFreeCriterion): +class DataBasedCriterion(DataFreeCriterion, ABC): """ Data-based mixed precision criterion that takes into account outliers in the input statistics. Expecting statistics of the following shape: [hidden_dim] @@ -326,11 +327,12 @@ def input_filter_func(point): for tensor_collector in statistic_points.get_algo_statistics_for_node( act_node.node_name, input_filter_func, self._algorithm_key ): - statistics = tensor_collector.get_statistics()[stat_key] - if isinstance(statistics, Tensor): - stats.append(statistics) - else: - stats.extend(statistics) + statistics = tensor_collector.get_statistics() + for data in statistics.get_data().values(): + if isinstance(data, Tensor): + stats.append(data) + else: + stats.extend(data) return stats diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 3d17d1a6af4..ec4dfab4711 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -21,10 +21,13 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.utils import get_reduction_axes -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import MaxVarianceTensorStatistic +from nncf.experimental.common.tensor_statistics.statistics import MeanMagnitudeTensorStatistic +from nncf.experimental.common.tensor_statistics.statistics import MeanVarianceTensorStatistic +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.openvino.graph.metatypes import openvino_metatypes as om from nncf.openvino.graph.metatypes.groups import ATOMIC_ACTIVATIONS_OPERATIONS from nncf.openvino.graph.model_transformer import OVModelTransformer @@ -39,7 +42,6 @@ from nncf.openvino.statistics.collectors import OVMeanVarianceReducer from nncf.openvino.statistics.collectors import OVShapeReducer from nncf.parameters import CompressWeightsMode -from nncf.parameters import SensitivityMetric from nncf.quantization.algorithms.weight_compression.awq_patterns import get_awq_patterns from nncf.quantization.algorithms.weight_compression.backend import AWQAlgoBackend from nncf.quantization.algorithms.weight_compression.backend import MixedPrecisionAlgoBackend @@ -405,8 +407,8 @@ def mean_variance_statistic_collector( ) -> TensorCollector: reducer = OVMeanVarianceReducer(reduction_axes, inplace=True) aggregator = MeanAggregator(num_samples=subset_size) - collector = TensorCollector() - collector.register_statistic_branch(SensitivityMetric.MEAN_ACTIVATION_VARIANCE.value, reducer, aggregator) + collector = TensorCollector(MeanVarianceTensorStatistic) + collector.register_statistic_branch(MeanVarianceTensorStatistic.MEAN_VARIANCE_STAT, reducer, aggregator) return collector @staticmethod @@ -415,8 +417,8 @@ def max_variance_statistic_collector( ) -> TensorCollector: reducer = OVMaxVarianceReducer(reduction_axes, inplace=True) aggregator = MeanAggregator(num_samples=subset_size) - collector = TensorCollector() - collector.register_statistic_branch(SensitivityMetric.MAX_ACTIVATION_VARIANCE.value, reducer, aggregator) + collector = TensorCollector(MaxVarianceTensorStatistic) + collector.register_statistic_branch(MaxVarianceTensorStatistic.MAX_VARIANCE_STAT, reducer, aggregator) return collector @staticmethod @@ -425,6 +427,6 @@ def mean_abs_max_statistic_collector( ) -> TensorCollector: reducer = OVMeanAbsMaxReducer(reduction_axes, inplace=True) aggregator = MeanAggregator(num_samples=subset_size) - collector = TensorCollector() - collector.register_statistic_branch(SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE.value, reducer, aggregator) + collector = TensorCollector(MeanMagnitudeTensorStatistic) + collector.register_statistic_branch(MeanMagnitudeTensorStatistic.MEAN_MAGNITUDE_STAT, reducer, aggregator) return collector diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index 0596e94d432..aaa46a4e7c6 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -18,9 +18,9 @@ from nncf.common.graph.graph import NNCFNode from nncf.common.logging.track_progress import track from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.activation_stats import process_stats from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py b/nncf/quantization/algorithms/weight_compression/torch_backend.py index 52fae7531ec..136c38413ab 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -21,11 +21,11 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.experimental.common.tensor_statistics.collectors import MeanReducer from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.parameters import CompressWeightsMode from nncf.quantization.algorithms.weight_compression.backend import WeightCompressionAlgoBackend from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters diff --git a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py index 2816f82a6e2..c7c0a685244 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -22,11 +22,11 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.common.tensor_statistics.statistics import WCTensorStatistic from nncf.experimental.common.tensor_statistics.collectors import MeanReducer from nncf.experimental.common.tensor_statistics.collectors import NoopAggregator from nncf.experimental.common.tensor_statistics.collectors import ShapeReducer from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import WCTensorStatistic from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand from nncf.experimental.torch.fx.model_transformer import FXModelTransformer from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index fd0548c99a7..eb520bfcd1b 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -36,6 +36,8 @@ from nncf.quantization.algorithms.hyperparameter_tuner.algorithm import HyperparameterTuner from nncf.quantization.algorithms.hyperparameter_tuner.param_grid import get_quantization_param_grids from nncf.quantization.algorithms.post_training.pipeline import create_ptq_pipeline +from nncf.quantization.algorithms.weight_compression.algorithm import check_user_compression_configuration +from nncf.quantization.algorithms.weight_compression.algorithm import get_weight_compression_configuration from nncf.quantization.telemetry_extractors import CompressionStartedWithCompressWeightsApi from nncf.quantization.telemetry_extractors import CompressionStartedWithQuantizeApi from nncf.quantization.telemetry_extractors import CompressionStartedWithQuantizeWithAccuracyControlApi @@ -425,7 +427,7 @@ def compress_weights( dataset: Optional[Dataset] = None, sensitivity_metric: Optional[SensitivityMetric] = None, *, - subset_size: Optional[int] = 128, + subset_size: int = 128, awq: Optional[bool] = None, scale_estimation: Optional[bool] = None, gptq: Optional[bool] = None, @@ -528,6 +530,9 @@ def compress_weights( "Set None or SensitivityMetric.WEIGHT_QUANTIZATION_ERROR." ) + if advanced_parameters and advanced_parameters.statistics_path: + raise nncf.ParameterNotSupportedError("Torch does not support statistics caching.") + if is_wrapped_model(model): if not model.nncf.trace_parameters: raise nncf.ValidationError( @@ -576,6 +581,8 @@ def compress_weights( raise nncf.ParameterNotSupportedError( "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None" ) + if advanced_parameters and advanced_parameters.statistics_path: + raise nncf.ParameterNotSupportedError("TorchFX does not supports statistics caching.") compression_weights_impl = fx_compression_weights_impl if backend == BackendType.OPENVINO: @@ -595,91 +602,48 @@ def compress_weights( ) compression_weights_impl = ov_compress_weights_impl - - if mode in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: - if ratio is None: - ratio = 1 - if group_size is None: - group_size = -1 - if ratio != 1 or group_size != -1: - raise nncf.ParameterNotSupportedError( - "INT8 modes assume per-channel quantization of all layers in 8 bit. " - "Default values of `ratio` (1) and `group_size` (-1) parameters can not be overridden" - ) - - if backup_mode is not None: - raise nncf.ParameterNotSupportedError("INT8 modes do not support the `backup_mode` option") - - options = { - "all_layers": all_layers, - "sensitivity_metric": sensitivity_metric, - "dataset": dataset, - "awq": awq, - "scale_estimation": scale_estimation, - "gptq": gptq, - "lora_correction": lora_correction, - } - unsupported_for_int8 = [name for name, value in options.items() if value is not None] - if unsupported_for_int8: - raise nncf.ParameterNotSupportedError( - f"INT8 modes do not support {', '.join(unsupported_for_int8)} option(s). Set them to None." - ) - - if ratio is None: - ratio = 1 - if group_size is None: - group_size = 128 - if all_layers is None: - all_layers = False - if awq is None: - awq = False - if scale_estimation is None: - scale_estimation = False - if gptq is None: - gptq = False - if lora_correction is None: - lora_correction = False - if ignored_scope is None: - ignored_scope = IgnoredScope() - if sensitivity_metric is None: - sensitivity_metric = ( - SensitivityMetric.WEIGHT_QUANTIZATION_ERROR - if dataset is None - else SensitivityMetric.MAX_ACTIVATION_VARIANCE - ) - if backup_mode is None: - backup_mode = BackupMode.INT8_ASYM - if ratio != 1 and dataset is None and sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR: - raise nncf.ValidationError( - f"Mixed precision selection based on the given sensitivity metric={sensitivity_metric.value} requires " - "a dataset, but it's not provided." - ) - if ratio < 0 or ratio > 1: - raise nncf.ValidationError(f"The ratio should be between 0 and 1, but ratio={ratio} is specified.") - if subset_size is None or subset_size <= 0: - raise nncf.ValidationError(f"The subset_size value should be positive, but subset_size={subset_size} is given.") - - if compression_weights_impl is None: - raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") - - return compression_weights_impl( - model, - dataset, + check_user_compression_configuration( mode, + subset_size, + dataset, ratio, group_size, - ignored_scope, all_layers, + awq, + scale_estimation, + gptq, + lora_correction, + ignored_scope, sensitivity_metric, + backup_mode, + advanced_parameters, + ) + weight_compression_configuration = get_weight_compression_configuration( + mode, + dataset, + ratio, + group_size, + all_layers, awq, - subset_size, scale_estimation, gptq, lora_correction, + ignored_scope, + sensitivity_metric, backup_mode, advanced_parameters, ) + if compression_weights_impl is None: + raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") + + return compression_weights_impl( + model=model, + dataset=dataset, + subset_size=subset_size, + **weight_compression_configuration, + ) + def quantize_with_tune_hyperparams( model: TModel, diff --git a/nncf/quantization/statistics_caching.py b/nncf/quantization/statistics_caching.py new file mode 100644 index 00000000000..e806e3cc65d --- /dev/null +++ b/nncf/quantization/statistics_caching.py @@ -0,0 +1,129 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Tuple + +from nncf.api.compression import TModel +from nncf.common.factory import StatisticsAggregatorFactory +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.tensor_statistics.aggregator import StatisticsAggregator +from nncf.data import Dataset +from nncf.parameters import SensitivityMetric +from nncf.quantization.algorithms.weight_compression.algorithm import WeightCompression +from nncf.quantization.algorithms.weight_compression.algorithm import get_weight_compression_configuration +from nncf.quantization.algorithms.weight_compression.mixed_precision import MIXED_PRECISION_CRITERIA + + +def register_statistics_for_algorithm( + aggregator: StatisticsAggregator, + model: TModel, + graph: NNCFGraph, + subset_size: int, + compression_algo: WeightCompression, + matmul_input_to_output_nodes_map: Dict[Tuple[NNCFNode, int], List[NNCFNode]], +) -> None: + """ + Registers the statistics required for the given compression algorithm. + + :param aggregator: Aggregator to register statistics. + :param model: Model being analyzed. + :param graph: Model's computational graph. + :param subset_size: Size of dataset subset for statistics. + :param compression_algo: WeightCompression algorithm instance. + :param matmul_input_to_output_nodes_map: A dictionary mapping from a tuple of (activation node, port ID) + to a list of MatMul nodes that accept the activation as input. + """ + statistic_points = compression_algo.get_statistic_points( + model, graph, matmul_input_to_output_nodes_map.keys(), subset_size + ) + aggregator.register_statistic_points(statistic_points) + + +def _register_mixed_precision( + aggregator: StatisticsAggregator, + model: TModel, + graph: NNCFGraph, + matmul_input_to_output_nodes_map: Dict[Tuple[NNCFNode, int], List[NNCFNode]], + subset_size: int, +) -> None: + """ + Registers statistics for mixed precision compression algorithm. + + :param aggregator: Aggregator to register statistics. + :param model: Model being analyzed. + :param graph: Model's computational graph. + :param matmul_input_to_output_nodes_map: A dictionary mapping from a tuple of (activation node, port ID) + to a list of MatMul nodes that accept the activation as input. + :param subset_size: Size of dataset subset for statistics. + """ + sensitivities = [ + SensitivityMetric.HESSIAN_INPUT_ACTIVATION, + SensitivityMetric.MEAN_ACTIVATION_VARIANCE, + SensitivityMetric.MAX_ACTIVATION_VARIANCE, + SensitivityMetric.MEAN_ACTIVATION_MAGNITUDE, + ] + + for sensitivity in sensitivities: + criterion_cls = MIXED_PRECISION_CRITERIA.get(sensitivity) + mixed_prec_algo = criterion_cls(None, None) + statistic_points = mixed_prec_algo.get_statistic_points( + model, graph, matmul_input_to_output_nodes_map.keys(), subset_size + ) + aggregator.register_statistic_points(statistic_points) + + +def register_all_statistics( + aggregator: StatisticsAggregator, + model: TModel, + graph: NNCFGraph, + subset_size: int, + compression_algo: WeightCompression, + enable_mixed_precision: bool = True, +) -> None: + """ + Registers all required statistics for the model compression. + + :param aggregator: Aggregator to register statistics. + :param model: Model being analyzed. + :param graph: Model's computational graph. + :param subset_size: Size of dataset subset for statistics. + :param compression_algo: WeightCompression algorithm instance. + :param enable_mixed_precision: Whether to enable mixed precision statistics. + """ + _, matmul_input_to_output_nodes_map = compression_algo.get_compression_nodes_info(graph) + + register_statistics_for_algorithm( + aggregator, model, graph, subset_size, compression_algo, matmul_input_to_output_nodes_map + ) + + if enable_mixed_precision: + _register_mixed_precision(aggregator, model, graph, matmul_input_to_output_nodes_map, subset_size) + + +def cache_weight_compression_statistics( + model: TModel, graph: NNCFGraph, dataset: Dataset, subset_size: int, statistics_path: str +) -> None: + """ + Caches compression statistics for a given model and dataset. + + :param model: Model being analyzed. + :param graph: Model's computational graph. + :param dataset: Dataset to analyze model statistics. + :param subset_size: Size of dataset subset for statistics. + :param statistics_path: Path to save cached statistics. + """ + config = get_weight_compression_configuration(awq=True, scale_estimation=True, lora_correction=True) + compression_algo = WeightCompression(**config, subset_size=subset_size) + compression_algo.set_backend_entity(model) + aggregator = StatisticsAggregatorFactory.create(model, dataset) + register_all_statistics(aggregator, model, graph, subset_size, compression_algo) + aggregator.collect_statistics(model, graph) + aggregator.dump_statistics(statistics_path) diff --git a/nncf/tensor/tensor.py b/nncf/tensor/tensor.py index 52966be1ad1..9edbd4acb50 100644 --- a/nncf/tensor/tensor.py +++ b/nncf/tensor/tensor.py @@ -162,6 +162,13 @@ def __gt__(self, other: Union[Tensor, float]) -> Tensor: def __ge__(self, other: Union[Tensor, float]) -> Tensor: return Tensor(self.data >= unwrap_tensor_data(other)) + # Methods to support pickling and unpickling + def __getstate__(self): + return self._data + + def __setstate__(self, state): + self._data = state + # Tensor functions def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Tensor: diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index 2d0a1ba0a8b..c52c2eb4c33 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -19,13 +19,17 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer from nncf.common.tensor_statistics.aggregator import StatisticsAggregator +from nncf.common.utils.backend import BackendType +from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic from nncf.tensor import Tensor from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.tensor_statistics.algo import create_register_input_hook class PTStatisticsAggregator(StatisticsAggregator): + BACKEND: BackendType = BackendType.TORCH HOOKS_GROUP_NAME = "statistics_hooks" def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: @@ -72,3 +76,14 @@ def _get_merged_statistic_points( def _process_outputs(outputs: torch.Tensor) -> Dict[str, Tensor]: # PyTorch backend doesn't use outputs to register statistics return {} + + def _get_statistics_key(self, statistics: TensorStatistic, target_point: PTTargetPoint) -> str: + """ + Returns key of statistics. + + :param statistics: Statistics value. + :param target_point: Statistics target point. + :return: Statistics key. + """ + target_point_id = f"{target_point.target_node_name}_{target_point.type}_{target_point.input_port_id}" + return f"{statistics.__class__.__name__}_{target_point_id}" diff --git a/tests/common/experimental/test_statistic_collector.py b/tests/common/experimental/test_statistic_collector.py index 5a704e901a4..36186605032 100644 --- a/tests/common/experimental/test_statistic_collector.py +++ b/tests/common/experimental/test_statistic_collector.py @@ -30,6 +30,12 @@ from nncf.tensor import Tensor +class DummyStatContainer: + @staticmethod + def from_config(config): + return config + + class DummyTensorReducer(TensorReducerBase): def __init__(self, output_name: str, inplace: bool = False, inplace_mock=None): super().__init__(inplace=inplace) @@ -302,17 +308,6 @@ def test_register_unnamed_statistics(mocker): assert all(v[0] == inputs_) -def test_wrong_statistic_container_class(): - class BadStatContainer: - pass - - tensor_collector = TensorCollector(BadStatContainer) - tensor_collector.register_statistic_branch("A", DummyTensorReducer("A"), DummyTensorAggregator()) - tensor_collector.register_input_for_all_reducers(Tensor(np.array(1))) - with pytest.raises(nncf.InternalError): - tensor_collector.get_statistics() - - class TemplateTestStatisticCollector: @abstractmethod def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor: @@ -347,13 +342,12 @@ def test_empty_tensors_register(self, inplace, any_not_empty): stats = collector.get_statistics() assert len(stats) == 1 assert stats["A"] == self.get_nncf_tensor([100]) - return - - assert len(aggregator._container) == 0 - assert aggregator._collected_samples == 0 - stats = collector.get_statistics() - assert len(stats) == 1 - assert stats["A"] is None + else: + assert len(aggregator._container) == 0 + assert aggregator._collected_samples == 0 + stats = collector.get_statistics() + assert len(stats) == 1 + assert stats["A"] is None def test_min_max_stat_building(self): tensor_collector = TensorCollector(MinMaxTensorStatistic) @@ -440,3 +434,48 @@ def test_raw_max_stat_building(self): statistic = tensor_collector.get_statistics() assert isinstance(statistic, RawTensorStatistic) assert statistic.values == 1 + + def test_tensor_collector_cache_and_statistics(self): + # Initialize a single instance of TensorCollector + collector = TensorCollector(DummyStatContainer) + + # Test setting cache + cached_statistics = {"values": [1, 2]} + collector.set_cache(cached_statistics) + assert collector._cached_statistics == cached_statistics, "Cache should match the set value" + assert not collector.enabled, "Collector should be disabled after setting cache" + + # Test clearing cache + collector.clear_cache() + collector.enable() + assert collector._cached_statistics is None, "Cache should be cleared" + + # Test default behavior of get_statistics without cache + collector.register_statistic_branch("container_key", DummyTensorReducer("A"), DummyTensorAggregator()) + collector.register_input_for_all_reducers(Tensor(np.array(1))) + statistics = collector.get_statistics() + assert statistics == {"container_key": Tensor(np.array(1))}, "Statistics should reflect registered input" + + # Test get_statistics with cache + collector.set_cache({"container_key": Tensor(np.array(25))}) + statistics = collector.get_statistics() + assert statistics == {"container_key": Tensor(np.array(25))}, "Statistics should return cached value" + + # Attempt to register new input while cache is set + collector.register_input_for_all_reducers(Tensor(np.array(2))) + statistics = collector.get_statistics() + assert statistics == {"container_key": Tensor(np.array(25))}, "Statistics should still reflect cached value" + + # Clear cache and check behavior + collector.clear_cache() + collector.enable() + empty_stats = {"container_key": None} + statistics = collector.get_statistics() + assert statistics == empty_stats, "Statistics should be empty after clearing cache" + + # Register new input after clearing cache + collector.register_input_for_all_reducers(Tensor(np.array(8))) + statistics = collector.get_statistics() + assert statistics == { + "container_key": Tensor(np.array(8)) + }, "Statistics should reflect the new input after clearing cache" diff --git a/tests/common/test_statistics_caching.py b/tests/common/test_statistics_caching.py new file mode 100644 index 00000000000..0093c39545f --- /dev/null +++ b/tests/common/test_statistics_caching.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque + +import numpy as np +import pytest + +import nncf +import nncf.common.tensor_statistics.statistics_serializer as statistics_serializer +from nncf.tensor import Tensor +from nncf.tensor.functions import allclose + + +def _compare_dicts(dict1, dict2): + """ + Recursively compares two dictionaries. + Supports comparing numpy arrays and Tensor objects. + """ + if not isinstance(dict1, dict) or not isinstance(dict2, dict): + raise ValueError("Both inputs must be dictionaries") + + if dict1.keys() != dict2.keys(): + return False + + for key in dict1: + val1 = dict1[key] + val2 = dict2[key] + + if isinstance(val1, np.ndarray) and isinstance(val2, np.ndarray): + if not np.array_equal(val1, val2): + return False + elif isinstance(val1, Tensor) and isinstance(val2, Tensor): + if not allclose(val1, val2): + return False + # Recursively compare nested dictionaries + elif isinstance(val1, dict) and isinstance(val2, dict): + if not _compare_dicts(val1, val2): + return False + # Direct comparison for other types + else: + if val1 != val2: + return False + + return True + + +@pytest.fixture +def dummy_statistics(): + """ + Returns a dummy statistics dictionary for testing purposes. + """ + return { + "point_A": {"min": 1, "max": 2}, + "point_B": { + "min_tuple": (1, 2), + "max_dict": {"tensor_1": [10, 10], "tensor_2": deque([1, 2])}, + "tensor_numpy": Tensor(np.ones(shape=(10, 5, 3))), + }, + } + + +def test_dump_and_load_statistics(tmp_path, dummy_statistics): + """ + Tests that dumped statistics can be loaded and match the original. + """ + test_dir = "test_dir" + statistics_serializer.dump_to_dir(dummy_statistics, tmp_path / test_dir) + assert (tmp_path / test_dir).exists(), "Dumped file was not created" + + loaded_statistics, _ = statistics_serializer.load_from_dir(tmp_path / test_dir) + assert _compare_dicts(dummy_statistics, loaded_statistics), "Loaded statistics do not match the original" + + +def test_load_statistics_from_non_existent_dir(): + """ + Tests that attempting to load statistics from a non-existent directory raises an error. + """ + file_path = "non_existent_dir" + with pytest.raises(nncf.ValidationError) as exc_info: + statistics_serializer.load_from_dir(file_path) + assert "The provided directory path does not exist." in str(exc_info) diff --git a/tests/common/test_statistics_serializer.py b/tests/common/test_statistics_serializer.py new file mode 100644 index 00000000000..2ecfc7c0a57 --- /dev/null +++ b/tests/common/test_statistics_serializer.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from pathlib import Path + +import pytest + +import nncf +from nncf.common.tensor_statistics.statistics_serializer import dump_to_dir +from nncf.common.tensor_statistics.statistics_serializer import load_from_dir +from nncf.common.tensor_statistics.statistics_serializer import load_metadata +from nncf.common.tensor_statistics.statistics_serializer import sanitize_filename +from nncf.common.tensor_statistics.statistics_serializer import save_metadata + + +def test_sanitize_filename(): + filename = "layer/1_mean/activation" + sanitized = sanitize_filename(filename) + assert sanitized == "layer_1_mean_activation", "Filename was not sanitized correctly" + + +def test_load_metadata(tmp_path): + # Create a metadata file in the temp directory + metadata = {"mapping": {"key1": "value1"}, "metadata": {"model": "test"}} + metadata_file = tmp_path / "statistics_metadata.json" + with open(metadata_file, "w") as f: + json.dump(metadata, f) + + loaded_metadata = load_metadata(tmp_path) + assert loaded_metadata == metadata, "Metadata was not loaded correctly" + + +def test_save_metadata(tmp_path): + metadata = {"mapping": {"key1": "value1"}, "metadata": {"model": "test"}} + save_metadata(metadata, tmp_path) + + metadata_file = tmp_path / "statistics_metadata.json" + assert metadata_file.exists(), "Metadata file was not created" + + with open(metadata_file, "r") as f: + loaded_metadata = json.load(f) + assert loaded_metadata == metadata, "Metadata was not saved correctly" + + +def test_dump_and_load_statistics(tmp_path): + statistics = {"layer/1_mean/activation": [0.1, 0.2, 0.3], "layer/2_variance": [0.05, 0.06, 0.07]} + additional_metadata = {"model": "facebook/opt-125m", "compression": "8-bit"} + + dump_to_dir(statistics, tmp_path, additional_metadata) + + assert len(list(Path(tmp_path).iterdir())) > 0, "No files created during dumping" + + metadata_file = tmp_path / "statistics_metadata.json" + assert metadata_file.exists(), "Metadata file was not created" + + with open(metadata_file, "r") as f: + metadata = json.load(f) + assert "mapping" in metadata, "Mapping is missing in metadata" + assert metadata["metadata"]["model"] == "facebook/opt-125m" + + # Load the statistics and ensure it was loaded correctly + loaded_statistics, loaded_metadata = load_from_dir(tmp_path) + assert "layer/1_mean/activation" in loaded_statistics, "Statistics not loaded correctly" + assert loaded_statistics["layer/1_mean/activation"] == [0.1, 0.2, 0.3] + assert loaded_metadata["model"] == "facebook/opt-125m", "Metadata not loaded correctly" + + +def test_invalid_gzip_file(tmp_path): + # Create a corrupt gzip file in the directory + invalid_file = tmp_path / "invalid_file.gz" + with open(invalid_file, "w") as f: + f.write("This is not a valid gzip file") + + # Expect the load_from_dir to raise an error when trying to load the invalid file + with pytest.raises(nncf.InternalError, match="Error loading statistics"): + load_from_dir(tmp_path) diff --git a/tests/common/test_statistics_validator.py b/tests/common/test_statistics_validator.py new file mode 100644 index 00000000000..8f234f011e6 --- /dev/null +++ b/tests/common/test_statistics_validator.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +import nncf +from nncf.common.tensor_statistics.statistics_validator import validate_backend +from nncf.common.utils.backend import BackendType + + +@pytest.mark.parametrize("backend_value", [BackendType.TORCH, BackendType.ONNX]) +def test_validate_backend(backend_value): + # Test case where backend matches + data = {"backend": backend_value.value} + backend = backend_value + + validate_backend(data, backend) + + with pytest.raises(nncf.ValidationError) as exc_info: + # Test case where backend does not match + validate_backend({"backend": BackendType.ONNX.value}, BackendType.TORCH) + assert "Backend in loaded statistics ONNX does not match to an expected backend Torch." in str(exc_info) + + with pytest.raises(nncf.ValidationError) as exc_info: + # Test case where backend key is missing + validate_backend({}, BackendType.TORCH) + assert "The provided metadata has no information about backend." in str(exc_info) diff --git a/tests/cross_fw/test_templates/test_statistics_caching.py b/tests/cross_fw/test_templates/test_statistics_caching.py new file mode 100644 index 00000000000..1035dc8fcb4 --- /dev/null +++ b/tests/cross_fw/test_templates/test_statistics_caching.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from abc import abstractmethod +from pathlib import Path + +import numpy as np +import pytest + +import nncf +from nncf.common.graph.transformations.commands import TargetPoint +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.tensor_statistics.statistic_point import StatisticPoint +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer +from nncf.common.utils.backend import BackendType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.tensor import Tensor + + +class TemplateTestStatisticsCaching: + @property + @abstractmethod + def create_target_point(self, target_point_type: TargetType, name: str, port_id: int) -> TargetPoint: + """ + Creates a backend-specific TargetPoint. + + :param target_point_type: The type of target point (e.g., PRE_LAYER_OPERATION). + :param name: The name of the target point. + :param port_id: The port ID for the target point. + :return: A backend-specific TargetPoint. + """ + pass + + @abstractmethod + def get_statistics_aggregator(self): + """ + Returns a statistics aggregator. Must be implemented by subclasses. + + :return: Statistics aggregator instance specific to the backend. + """ + pass + + def _create_dummy_statistic_point(self) -> StatisticPoint: + """ + Creates a dummy statistic point for testing purposes. + + :return: A StatisticPoint object with dummy data. + """ + dummy_t_p = self.create_target_point(TargetType.PRE_LAYER_OPERATION, "dummy_name", 0) + dummy_tensor_collector = TensorCollector() + dummy_tensor_collector._cached_statistics = MinMaxTensorStatistic(Tensor(np.zeros((3))), Tensor(np.ones((3)))) + return StatisticPoint( + target_point=dummy_t_p, tensor_collector=dummy_tensor_collector, algorithm="dummy_algorithm" + ) + + def test_dump_and_load_statistics(self, tmp_path: Path): + """ + Tests the dumping and loading of statistics to and from a file. + + :param tmp_path: The temporary path provided by pytest. + """ + test_dir = "test_dir" + aggregator = self.get_statistics_aggregator() + statistics_points = StatisticPointsContainer() + + dummy_statistic_point = self._create_dummy_statistic_point() + statistics_points.add_statistic_point(dummy_statistic_point) + + aggregator.statistic_points = statistics_points + aggregator.dump_statistics(tmp_path / test_dir) + assert (tmp_path / test_dir).exists(), "Statistics file was not created" + + aggregator.load_statistics_from_dir(tmp_path / test_dir) + + def test_incorrect_backend_statistics_load(self, tmp_path: Path): + """ + Tests the dumping and loading of statistics to and from a file with non matched backends. + + :param tmp_path: The temporary path provided by pytest. + """ + test_file = "test" + aggregator = self.get_statistics_aggregator() + statistics_points = StatisticPointsContainer() + + dummy_statistic_point = self._create_dummy_statistic_point() + statistics_points.add_statistic_point(dummy_statistic_point) + + aggregator.statistic_points = statistics_points + aggregator.dump_statistics(tmp_path / test_file) + assert (tmp_path / test_file).exists(), "Statistics file was not created" + # spoil backend + aggregator.BACKEND = BackendType.TENSORFLOW + with pytest.raises(nncf.ValidationError): + aggregator.load_statistics_from_dir(tmp_path / test_file) diff --git a/tests/onnx/test_statistics_caching.py b/tests/onnx/test_statistics_caching.py new file mode 100644 index 00000000000..e1224e7b5f3 --- /dev/null +++ b/tests/onnx/test_statistics_caching.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nncf.common.graph.transformations.commands import TargetType +from nncf.onnx.graph.transformations.commands import ONNXTargetPoint +from nncf.onnx.statistics.aggregator import ONNXStatisticsAggregator +from tests.cross_fw.test_templates.test_statistics_caching import TemplateTestStatisticsCaching + + +class TestStatisticsCaching(TemplateTestStatisticsCaching): + def create_target_point(self, target_point_type: TargetType, name: str, port_id: int) -> ONNXTargetPoint: + return ONNXTargetPoint(target_type=target_point_type, target_node_name=name, port_id=port_id) + + def get_statistics_aggregator(self): + return ONNXStatisticsAggregator(None) diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index edc50652710..db72b267698 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -29,6 +29,7 @@ from nncf.openvino.graph.node_utils import get_const_value from nncf.parameters import BackupMode from nncf.quantization import compress_weights +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as CompressionParams from nncf.quantization.advanced_parameters import AdvancedGPTQParameters as GPTQParams from nncf.quantization.advanced_parameters import AdvancedLoraCorrectionParameters as LoraParams @@ -727,6 +728,7 @@ def test_raise_error_channel_size_is_not_divisible_by_group_size(): {"backup_mode": BackupMode.NONE}, {"backup_mode": BackupMode.INT8_ASYM}, {"backup_mode": BackupMode.INT8_SYM}, + {"advanced_parameters": AdvancedCompressionParameters(statistics_path="anything")}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): @@ -910,7 +912,7 @@ def test_default_subset_value(): assert default_value == 128 -@pytest.mark.parametrize("subset_size", (-1, 0, None)) +@pytest.mark.parametrize("subset_size", (-1, 0)) def test_invalid_subset_size(subset_size): model = IdentityMatmul().ov_model dataset = Dataset([ACTIVATION]) diff --git a/tests/openvino/native/quantization/test_weights_compression_statistics_caching.py b/tests/openvino/native/quantization/test_weights_compression_statistics_caching.py new file mode 100644 index 00000000000..538643075ee --- /dev/null +++ b/tests/openvino/native/quantization/test_weights_compression_statistics_caching.py @@ -0,0 +1,170 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import deepcopy +from functools import partial +from itertools import product +from typing import Tuple + +import datasets +import openvino as ov +from optimum.intel.openvino import OVModelForCausalLM +from transformers import AutoTokenizer + +import nncf +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters +from nncf.scopes import IgnoredScope + +MODEL_ID = "hf-internal-testing/tiny-random-OPTForCausalLM" +DEFAULT_RATIO = 0.4 +DEFAULT_GROUP_SIZE = 4 +DEFAULT_SENSITIVITY = nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION +DEFAULT_IGNORED_SCOPE = IgnoredScope() +DEFAULT_SUBSET_SIZE = 4 +DEFAULT_MODE = nncf.CompressWeightsMode.INT4_ASYM + + +def create_transform_fn(model: OVModelForCausalLM, tokenizer: AutoTokenizer): + def transform_fn(data, model=model, tokenizer=tokenizer): + tokenized_text = tokenizer(data["text"], return_tensors="np") + input_ids = tokenized_text["input_ids"] + inputs = {"input_ids": input_ids, "attention_mask": tokenized_text["attention_mask"]} + + batch_size = input_ids.shape[0] + if hasattr(model, "key_value_input_names"): + for input_name in model.key_value_input_names: + model_inputs = model.model.input(input_name) + shape = model_inputs.get_partial_shape() + shape[0] = batch_size + shape[2 if shape[2].is_dynamic else 1] = 0 + inputs[input_name] = ov.Tensor(model_inputs.get_element_type(), shape.get_shape()) + return inputs + + return transform_fn + + +def _setup_model_and_dataset(model_id: str) -> Tuple[OVModelForCausalLM, nncf.Dataset]: + dataset = datasets.load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = OVModelForCausalLM.from_pretrained(model_id, export=True, load_in_8bit=False, compile=False, stateful=False) + transform_fn = create_transform_fn(model, tokenizer) + quantization_dataset = nncf.Dataset(dataset, partial(transform_fn)) + return model, quantization_dataset + + +def _test_basic_configurations(model, quantization_dataset, tmp_path, subset_size, mode) -> int: + awq_options = [True, False] + group_size_options = [1, 4] + ratio_options = [0.4, 0.8] + sensitivity_metrics = [ + nncf.SensitivityMetric.HESSIAN_INPUT_ACTIVATION, + nncf.SensitivityMetric.MAX_ACTIVATION_VARIANCE, + ] + ignored_scopes = [IgnoredScope(), IgnoredScope(types=["MatMul"])] + + load_count = 0 + for awq, group_size, ratio, sensitivity, scope in product( + awq_options, group_size_options, ratio_options, sensitivity_metrics, ignored_scopes + ): + print(f"Testing: AWQ={awq}, Group={group_size}, Ratio={ratio}, Metric={sensitivity}, Scope={scope}") + nncf.compress_weights( + deepcopy(model.model), + mode=mode, + dataset=quantization_dataset, + ratio=ratio, + awq=awq, + gptq=False, + group_size=group_size, + scale_estimation=False, + subset_size=subset_size, + sensitivity_metric=sensitivity, + ignored_scope=scope, + lora_correction=False, + advanced_parameters=AdvancedCompressionParameters(statistics_path=tmp_path / "statistics"), + ) + load_count += 1 + return load_count + + +def _test_advanced_gptq_scale_estimation(model, quantization_dataset, tmp_path, subset_size, mode) -> int: + load_count = 0 + for gptq, scale_est in product([True, False], [True, False]): + print(f"Testing: AWQ=True, GPTQ={gptq}, Scale={scale_est}, LoRA=False") + nncf.compress_weights( + deepcopy(model.model), + mode=mode, + dataset=quantization_dataset, + ratio=DEFAULT_RATIO, + awq=True, + gptq=gptq, + group_size=DEFAULT_GROUP_SIZE, + scale_estimation=scale_est, + subset_size=subset_size, + sensitivity_metric=DEFAULT_SENSITIVITY, + ignored_scope=DEFAULT_IGNORED_SCOPE, + lora_correction=False, + advanced_parameters=AdvancedCompressionParameters(statistics_path=tmp_path / "statistics"), + ) + load_count += 1 + return load_count + + +def _test_advanced_lora_scale_estimation(model, quantization_dataset, tmp_path, subset_size, mode) -> int: + load_count = 0 + for scale_est, lora_corr in product([True, False], [True, False]): + print(f"Testing: AWQ=True, GPTQ=False, Scale={scale_est}, LoRA={lora_corr}") + nncf.compress_weights( + deepcopy(model.model), + mode=mode, + dataset=quantization_dataset, + ratio=DEFAULT_RATIO, + awq=True, + gptq=False, + group_size=DEFAULT_GROUP_SIZE, + scale_estimation=scale_est, + subset_size=subset_size, + sensitivity_metric=DEFAULT_SENSITIVITY, + ignored_scope=DEFAULT_IGNORED_SCOPE, + lora_correction=lora_corr, + advanced_parameters=AdvancedCompressionParameters(statistics_path=tmp_path / "statistics"), + ) + load_count += 1 + return load_count + + +def test_weight_compression_statistics_caching(tmp_path, mocker): + """ + Tests the weight compression process, focusing on the statistics caching mechanism. + Ensures that: + - Statistics are collected once. + - Statistics are loaded according to the number of configurations tested. + - Statistics are dumped once. + :param tmp_path: Temporary directory path for storing statistics. + :param mocker: Mocking utility to spy on function calls. + """ + from nncf.openvino.statistics.aggregator import OVStatisticsAggregator + + collect_spy = mocker.spy(OVStatisticsAggregator, "collect_statistics") + load_spy = mocker.spy(OVStatisticsAggregator, "load_statistics_from_dir") + dump_spy = mocker.spy(OVStatisticsAggregator, "dump_statistics") + + model_id = MODEL_ID + subset_size = DEFAULT_SUBSET_SIZE + mode = DEFAULT_MODE + model, quantization_dataset = _setup_model_and_dataset(model_id) + + load_count = 0 + load_count += _test_basic_configurations(model, quantization_dataset, tmp_path, subset_size, mode) + load_count += _test_advanced_gptq_scale_estimation(model, quantization_dataset, tmp_path, subset_size, mode) + load_count += _test_advanced_lora_scale_estimation(model, quantization_dataset, tmp_path, subset_size, mode) + + assert collect_spy.call_count == 1, "Statistics should be collected only once." + assert load_spy.call_count == load_count, f"Expected {load_count} load calls, found {load_spy.call_count}." + assert dump_spy.call_count == 1, "Statistics should be dumped only once." diff --git a/tests/openvino/native/test_statistics_caching.py b/tests/openvino/native/test_statistics_caching.py new file mode 100644 index 00000000000..15000d0bba7 --- /dev/null +++ b/tests/openvino/native/test_statistics_caching.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nncf.common.graph.transformations.commands import TargetType +from nncf.openvino.graph.transformations.commands import OVTargetPoint +from nncf.openvino.statistics.aggregator import OVStatisticsAggregator +from tests.cross_fw.test_templates.test_statistics_caching import TemplateTestStatisticsCaching + + +class TestStatisticsCaching(TemplateTestStatisticsCaching): + def create_target_point(self, target_point_type: TargetType, name: str, port_id: int) -> OVTargetPoint: + return OVTargetPoint(target_point_type, name, port_id) + + def get_statistics_aggregator(self): + return OVStatisticsAggregator(None) diff --git a/tests/openvino/requirements.txt b/tests/openvino/requirements.txt index 9a74f4d74d8..9f6a390a1d5 100644 --- a/tests/openvino/requirements.txt +++ b/tests/openvino/requirements.txt @@ -12,3 +12,7 @@ virtualenv addict>=2.4.0 timm==0.9.2 efficientnet_pytorch==0.7.1 +datasets==3.0.1 +transformers==4.45.2 +optimum-intel==1.20.0 +optimum==1.23.1 diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py index 835398bd57e..2caffb35edf 100644 --- a/tests/torch/fx/test_compress_weights.py +++ b/tests/torch/fx/test_compress_weights.py @@ -22,6 +22,7 @@ from nncf.data.dataset import Dataset from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node from nncf.quantization import compress_weights +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.torch.dynamic_graph.patch_pytorch import disable_patching from tests.torch.ptq.test_weights_compression import ALL_SENSITIVITY_METRICS from tests.torch.ptq.test_weights_compression import DATA_BASED_SENSITIVITY_METRICS @@ -169,7 +170,6 @@ def test_compressed_model_inference(mode): @pytest.mark.parametrize("mode", SUPPORTED_MODES) def test_compress_weights_model_size_conv(mode): - dtype = torch.int8 if mode == CompressWeightsMode.INT8_SYM else torch.uint8 model = ConvolutionModel() @@ -227,6 +227,7 @@ def test_compress_weights_functional_model(mode): {"backup_mode": BackupMode.NONE}, {"backup_mode": BackupMode.INT8_ASYM}, {"backup_mode": BackupMode.INT8_SYM}, + {"advanced_parameters": AdvancedCompressionParameters(statistics_path="anything")}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): @@ -266,6 +267,14 @@ def test_raise_error_with_not_int8(mode): compress_weights(exported_model, mode=mode) +def test_raise_error_for_statistics_caching(): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + exported_model = _capture_model(dummy_torch_model, dummy_input) + with pytest.raises(nncf.ParameterNotSupportedError): + compress_weights(exported_model, advanced_parameters=AdvancedCompressionParameters(statistics_path="anything")) + + def test_get_dtype_attribute_of_parameter(): model = DTypeModel() dummy_input = torch.randint(0, 10, [3, 3]) diff --git a/tests/torch/fx/test_statistics_caching.py b/tests/torch/fx/test_statistics_caching.py new file mode 100644 index 00000000000..1533769c69a --- /dev/null +++ b/tests/torch/fx/test_statistics_caching.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator +from nncf.torch.graph.transformations.commands import PTTargetPoint +from tests.cross_fw.test_templates.test_statistics_caching import TemplateTestStatisticsCaching + + +class TestStatisticsCaching(TemplateTestStatisticsCaching): + def create_target_point(self, target_point_type: TargetType, name: str, port_id: int) -> PTTargetPoint: + return PTTargetPoint(target_type=target_point_type, target_node_name=name, input_port_id=port_id) + + def get_statistics_aggregator(self): + return FXStatisticsAggregator(None) diff --git a/tests/torch/ptq/test_statistics_caching.py b/tests/torch/ptq/test_statistics_caching.py new file mode 100644 index 00000000000..6fd527ce15c --- /dev/null +++ b/tests/torch/ptq/test_statistics_caching.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nncf.common.graph.transformations.commands import TargetType +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.statistics.aggregator import PTStatisticsAggregator +from tests.cross_fw.test_templates.test_statistics_caching import TemplateTestStatisticsCaching + + +class TestStatisticsCaching(TemplateTestStatisticsCaching): + def create_target_point(self, target_point_type: TargetType, name: str, port_id: int) -> PTTargetPoint: + return PTTargetPoint(target_type=target_point_type, target_node_name=name, input_port_id=port_id) + + def get_statistics_aggregator(self): + return PTStatisticsAggregator(None) diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index 2e902e1af50..3c3a424e2af 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -18,6 +18,7 @@ from nncf import CompressWeightsMode from nncf import SensitivityMetric from nncf.quantization import compress_weights +from nncf.quantization.advanced_parameters import AdvancedCompressionParameters from nncf.torch import wrap_model from nncf.torch.quantization.layers import INT4AsymmetricWeightsDecompressor from nncf.torch.quantization.layers import INT4SymmetricWeightsDecompressor @@ -236,6 +237,7 @@ def forward(self, input): {"backup_mode": BackupMode.NONE}, {"backup_mode": BackupMode.INT8_ASYM}, {"backup_mode": BackupMode.INT8_SYM}, + {"advanced_parameters": AdvancedCompressionParameters(statistics_path="anything")}, ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): @@ -274,6 +276,14 @@ def test_raise_error_with_not_int8(mode): compress_weights(wrapped_model, mode=mode) +def test_raise_error_for_statistics_caching(): + dummy_torch_model = EmptyModel() + dummy_input = torch.Tensor() + wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) + with pytest.raises(nncf.ParameterNotSupportedError): + compress_weights(wrapped_model, advanced_parameters=AdvancedCompressionParameters(statistics_path="anything")) + + class DTypeModel(torch.nn.Module): def __init__(self): super().__init__()