[WeightCompression] Statistics caching (#3017)
### Changes

Add statistics saving and loading for the `WeightCompression` algorithm:

1. Statistics are cached for all configurations, such as
`awq=True, scale_estimation=True`, with all types of sensitivities.
2. The statistics are then dumped into a directory that can be reused for
any `weights_compression()` configuration.

The tinyllama example was updated to demonstrate this functionality.
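
Below is a minimal usage sketch of the new caching option, assuming an OpenVINO model `ov_model` and a calibration `nncf.Dataset` named `calibration_dataset`; only `statistics_path` is new in this change, the remaining parameter values are illustrative:

```python
import nncf
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

# Statistics collected on the first run are dumped to "statistics/";
# subsequent runs with a different compression configuration reuse them.
compressed_model = nncf.compress_weights(
    ov_model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=0.8,
    group_size=64,
    awq=True,
    dataset=calibration_dataset,
    sensitivity_metric=nncf.parameters.SensitivityMetric.MAX_ACTIVATION_VARIANCE,
    advanced_parameters=AdvancedCompressionParameters(statistics_path="statistics"),
)
```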

More changes:
1. Align all statistics used in `WeightCompression` with
`TensorStatistic` from
`nncf/experimental/common/tensor_statistics/statistics.py`.
2. Extend `StatisticsAggregator` with the logic for loading and saving
statistics.
3. Statistics are dumped using pickle and gzip; serialization
methods were added for `Tensor`.
4. Introduce `statistics_serializer` and `statistics_validator` to
handle loading and dumping statistics to files.


Statistics sizes

| Model | Subset size | Statistics directory size | Statistics collection time |
| -- | -- | -- | -- |
| tinyllama | 128 | 100 MB | 61 sec |
| Phi-3-mini-4k-instruct | 128 | 258 MB | 51 sec |
| ruDialoGPT-medium | 128 | 80 MB | 9 sec |
| llama-3.1-8b | 128 | 393 MB | 95 sec |


### Reason for changes

Speed up the search for a suitable compression configuration.

### Related tickets

153129

### Tests

Test coverage was extended with tests for `statistics_serializer`,
`statistics_validator`, `StatisticsAggregator`, and the
`WeightCompression` algorithm with the proposed functionality.
kshpv authored Oct 31, 2024
1 parent 51a7fb6 commit 4ddb5da
Showing 42 changed files with 1,618 additions and 297 deletions.
@@ -24,6 +24,7 @@

import nncf
from nncf.common.logging import nncf_logger
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

DataItem = TypeVar("DataItem")
ModelInput = TypeVar("ModelInput")
@@ -63,6 +64,7 @@ def compress_model(
group_size=group_size,
awq=awq,
sensitivity_metric=nncf.parameters.SensitivityMetric.MAX_ACTIVATION_VARIANCE,
advanced_parameters=AdvancedCompressionParameters(statistics_path="statistics"),
)
return optimized_ov_model

68 changes: 68 additions & 0 deletions nncf/common/tensor_statistics/aggregator.py
@@ -8,22 +8,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, Optional, TypeVar

import nncf
import nncf.common.tensor_statistics.statistics_serializer as statistics_serializer
import nncf.common.tensor_statistics.statistics_validator as statistics_validator
from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.data.dataset import DataItem
from nncf.data.dataset import Dataset
from nncf.data.dataset import ModelInput
from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
@@ -38,6 +44,8 @@ class StatisticsAggregator(ABC):
Base class for statistics collection.
"""

BACKEND: BackendType

def __init__(self, dataset: Dataset[DataItem, ModelInput]):
self.dataset = dataset
self.stat_subset_size = None
@@ -88,6 +96,56 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
f"smaller than the requested subset size {self.stat_subset_size}."
)

def load_statistics_from_dir(self, dir_path: str) -> None:
"""
Loads statistics from a directory and populates the statistic points with the loaded data.
:param dir_path: The name of the directory from which to load the statistics.
"""
loaded_data, metadata = statistics_serializer.load_from_dir(dir_path)
statistics_validator.validate_backend(metadata, self.BACKEND)
self._load_statistics(loaded_data)
nncf_logger.info(f"Statistics were successfully loaded from a directory {dir_path}.")

def _load_statistics(self, data: Dict[str, Any]) -> None:
"""
Loads statistics into the registered statistic points from the given data.
:param data: A dictionary containing the statistics loaded from a file.
"""
for _, statistic_point, tensor_collector in self.statistic_points.get_tensor_collectors():
statistics = tensor_collector.get_statistics()
statistics_key = self._get_statistics_key(statistics, statistic_point.target_point)
if statistics_key not in data:
raise nncf.ValidationError(f"Not found statistics for {statistics_key}")
statistics_container = tensor_collector.create_statistics_container(data[statistics_key])
tensor_collector.set_cache(statistics_container)

def dump_statistics(self, dir_path: str) -> None:
"""
Dumps the current statistics to a directory in a compressed format.
:param dir_path: The path of the directory where the statistics will be saved.
"""
data_to_dump = self._prepare_statistics()
metadata = {"backend": self.BACKEND.value, "subset_size": self.stat_subset_size}
statistics_serializer.dump_to_dir(data_to_dump, dir_path, metadata)
nncf_logger.info(f"Statistics were successfully saved to a directory {dir_path}.")

def _prepare_statistics(self) -> Dict[str, Any]:
"""
Prepares the statistics data for dumping into a directory.
:return: A dictionary containing the statistics data to be dumped.
"""
data_to_dump = {}
for _, statistic_point, tensor_collector in self.statistic_points.get_tensor_collectors():
statistics = tensor_collector.get_statistics()
statistics_key = self._get_statistics_key(statistics, statistic_point.target_point)
data = statistics.get_data()
data_to_dump[statistics_key] = data
return data_to_dump

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
Register statistic points for statistics collection and recalculates the maximum number samples
@@ -154,3 +212,13 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
:param outputs: raw model outputs
:return: processed model outputs in Dict[str, Tensor] format
"""

@abstractmethod
def _get_statistics_key(self, statistics: TensorStatistic, target_point: TargetPoint) -> str:
"""
Returns key of statistics.
:param statistics: Statistics value.
:param target_point: Statistics target point.
:return: Statistics key.
"""
4 changes: 2 additions & 2 deletions nncf/common/tensor_statistics/collectors.py
@@ -59,14 +59,14 @@ def register_input(self, x: TensorType) -> TensorType:
def _register_input(self, x: TensorType) -> None:
pass

def get_statistics(self) -> None:
def get_statistics(self) -> Any:
"""Returns collected statistics, if present."""
if self._collected_samples == 0:
raise StatisticsNotCollectedError()
return self._get_statistics()

@abstractmethod
def _get_statistics(self) -> None:
def _get_statistics(self) -> Any:
pass

def enable(self) -> None:
14 changes: 4 additions & 10 deletions nncf/common/tensor_statistics/statistic_point.py
@@ -13,8 +13,7 @@
from typing import Any, Callable, Generator, Optional, Tuple, cast

from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector


class StatisticPoint:
@@ -25,7 +24,7 @@ class StatisticPoint:
algorithm implies which algorithm needs this statistic.
"""

def __init__(self, target_point: TargetPoint, tensor_collector: TensorStatisticCollectorBase, algorithm: str):
def __init__(self, target_point: TargetPoint, tensor_collector: TensorCollector, algorithm: str):
self.target_point = target_point
self.algorithm_to_tensor_collectors = {algorithm: [tensor_collector]}

@@ -36,11 +35,6 @@ def __eq__(self, other: Any) -> bool:
and self.algorithm_to_tensor_collectors == other.self.algorithm_to_tensor_collectors,
)

def register_tensor(self, x: NNCFTensor) -> None:
for tensor_collectors in self.algorithm_to_tensor_collectors.values():
for tensor_collector in tensor_collectors:
tensor_collector.register_input(x)


class StatisticPointsContainer(UserDict): # type: ignore
"""
@@ -88,7 +82,7 @@ def iter_through_statistic_points_in_target_node(

def get_tensor_collectors(
self, filter_fn: Optional[Callable[[StatisticPoint], bool]] = None
) -> Generator[Tuple[str, StatisticPoint, TensorStatisticCollectorBase], None, None]:
) -> Generator[Tuple[str, StatisticPoint, TensorCollector], None, None]:
"""
Returns iterable through all tensor collectors.
@@ -115,7 +109,7 @@ def get_algo_statistics_for_node(
target_node_name: str,
filter_fn: Callable[[StatisticPoint], bool],
algorithm: str,
) -> Generator[TensorStatisticCollectorBase, None, None]:
) -> Generator[TensorCollector, None, None]:
"""
Returns iterable through all statistic collectors in node with target_node_name.
25 changes: 1 addition & 24 deletions nncf/common/tensor_statistics/statistics.py
@@ -12,7 +12,7 @@
from abc import ABC
from abc import abstractmethod
from collections import Counter
from typing import Any, Dict, List, Tuple, TypeVar, cast
from typing import Any, Dict, TypeVar, cast

from nncf.tensor import Tensor
from nncf.tensor import functions as fns
@@ -120,26 +120,3 @@ def __eq__(self, other: Any) -> bool:
if not isinstance(other, RawTensorStatistic):
return False
return self.tensor_eq(self.values, other.values)


class WCTensorStatistic(TensorStatistic):
MEAN_STAT = "mean_values"
SHAPE_STAT = "shape_values"

def __init__(self, mean_values: List[Tensor], shapes: List[Tuple[int, ...]]):
"""
:param mean_values: List of N tensors of shape [HiddenDim] obtained by reducing activations along batch and
sequence length dimensions.
:param shapes: List of N tuples containing original shapes of activations before reduction.
"""
self.mean_values = mean_values
self.shapes = shapes

def __eq__(self, other: Any) -> bool:
shapes_equal = all(self.shapes[i] == other.shapes[i] for i in range(len(self.mean_values)))
if not shapes_equal:
return False
mean_values_equal = all(
self.tensor_eq(self.mean_values[i], other.mean_values[i]) for i in range(len(self.mean_values))
)
return mean_values_equal
115 changes: 115 additions & 0 deletions nncf/common/tensor_statistics/statistics_serializer.py
@@ -0,0 +1,115 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import json
import pickle
import re
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, cast

import nncf

METADATA_FILE = "statistics_metadata.json"


def sanitize_filename(filename: str) -> str:
"""
Replaces any forbidden characters with an underscore.
"""
return re.sub(r'[\/:*?"<>|]', "_", filename)


def load_metadata(dir_path: Path) -> Dict[str, Any]:
"""
Loads the metadata, including the mapping and any other metadata information from the metadata file.
:param dir_path: The directory where the metadata file is stored.
:return: A dictionary containing the mapping and metadata.
"""
metadata_file = dir_path / METADATA_FILE
if metadata_file.exists():
with open(metadata_file, "r") as f:
return cast(Dict[str, Any], json.load(f))
return {"mapping": {}, "metadata": {}}


def save_metadata(metadata: Dict[str, Any], dir_path: Path) -> None:
"""
Saves the mapping and metadata to the metadata file.
:param metadata: The dictionary containing both the mapping and other metadata.
:param dir_path: The directory where the metadata file will be stored.
"""
metadata_file = dir_path / METADATA_FILE
with open(metadata_file, "w") as f:
json.dump(metadata, f, indent=4)


def load_from_dir(dir_path: str) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
Loads statistics from gzip-compressed files in the given directory.
:param dir_path: The path to the directory from which to load the statistics.
:return: 1) A dictionary with the original statistic names as keys and the loaded statistics as values.
2) Metadata dictionary.
"""
statistics = {}
path = Path(dir_path)
if not path.exists():
raise nncf.ValidationError("The provided directory path does not exist.")
metadata = load_metadata(path)
mapping = metadata.get("mapping", {})

for statistics_file in path.iterdir():
if statistics_file.name == METADATA_FILE:
continue # Skip the metadata file

try:
with gzip.open(statistics_file, "rb") as f:
sanitized_name = statistics_file.name
original_name = mapping.get(sanitized_name, sanitized_name)
statistics[original_name] = pickle.load(f)
except (pickle.UnpicklingError, IOError) as e:
raise nncf.InternalError(f"Error loading statistics from {statistics_file.name}: {e}")
return statistics, metadata.get("metadata", {})


def dump_to_dir(
statistics: Dict[str, Any], dir_path: str, additional_metadata: Optional[Dict[str, Any]] = None
) -> None:
"""
Dumps statistics to gzip-compressed files in the specified directory, while maintaining a mapping file.
:param statistics: A dictionary with statistic names as keys and the statistic data as values.
:param dir_path: The path to the directory where the statistics will be dumped.
:param additional_metadata: A dictionary containing any additional metadata to be saved with the mapping.
"""
path = Path(dir_path)
path.mkdir(parents=True, exist_ok=True)

metadata, mapping = {}, {}

for original_name, statistics_value in statistics.items():
sanitized_name = sanitize_filename(original_name)
file_path = path / sanitized_name

# Update the mapping
mapping[sanitized_name] = original_name

try:
with gzip.open(file_path, "wb") as f:
pickle.dump(statistics_value, f)
except (IOError, pickle.PicklingError) as e:
raise nncf.InternalError(f"Failed to write data to file {file_path}: {e}")

# Add additional metadata if provided
if additional_metadata:
metadata["metadata"] = additional_metadata

# Update the mapping in the metadata file
metadata["mapping"] = mapping
save_metadata(metadata, path)
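
A small round-trip sketch of the serializer module above; the statistic keys and payloads are hypothetical, while real payloads are the dictionaries returned by `TensorStatistic.get_data()`:

```python
import nncf.common.tensor_statistics.statistics_serializer as statistics_serializer

# Hypothetical statistic keys and payloads.
stats = {
    "mean_stat:node_a/output_0": {"mean_values": [0.1, 0.2, 0.3]},
    "raw_stat:node_b/output_0": {"values": [1.0, 2.0]},
}

# Writes one gzip-compressed pickle per statistic plus statistics_metadata.json,
# which stores the sanitized-filename -> original-name mapping and extra metadata.
statistics_serializer.dump_to_dir(stats, "statistics", additional_metadata={"backend": "OpenVINO"})

loaded, metadata = statistics_serializer.load_from_dir("statistics")
assert loaded.keys() == stats.keys()
assert metadata == {"backend": "OpenVINO"}
```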
30 changes: 30 additions & 0 deletions nncf/common/tensor_statistics/statistics_validator.py
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict

import nncf
from nncf.common.utils.backend import BackendType


def validate_backend(data: Dict[str, Any], backend: BackendType) -> None:
"""
Checks whether backend in loaded data is equal to a provided backend.
:param data: Loaded statistics.
:param backend: Provided backend.
"""
if "backend" not in data:
raise nncf.ValidationError("The provided metadata has no information about backend.")
data_backend = data["backend"]
if data_backend != backend.value:
raise nncf.ValidationError(
f"Backend in loaded statistics {data_backend} does not match to an expected backend {backend.value}."
)
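
A brief sketch of how the validator above is exercised; the metadata dictionary mirrors what `dump_statistics` stores, and the mismatching-backend case is illustrative:

```python
import nncf
from nncf.common.tensor_statistics.statistics_validator import validate_backend
from nncf.common.utils.backend import BackendType

metadata = {"backend": BackendType.OPENVINO.value, "subset_size": 128}

# Matching backend: no exception is raised.
validate_backend(metadata, BackendType.OPENVINO)

# Mismatching backend: nncf.ValidationError is raised.
try:
    validate_backend(metadata, BackendType.TORCH)
except nncf.ValidationError as err:
    print(err)
```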