[WeightCompression] Statistics caching #3017

Merged · 101 commits · Oct 31, 2024
The diff below shows changes from 86 of the 101 commits.

Commits:
de0c068
draft
kshpv Oct 15, 2024
8f8b4cf
add gzip
kshpv Oct 15, 2024
ac3b28d
add description
kshpv Oct 15, 2024
1f026eb
update WCstatistics
kshpv Oct 15, 2024
ebe17c3
fix init
kshpv Oct 15, 2024
9d69dac
Merge remote-tracking branch 'remote/develop' into statistics_cahing
kshpv Oct 15, 2024
8a358b1
rollback build_statistic_container
kshpv Oct 15, 2024
3ffa566
typo; fix mypy
kshpv Oct 15, 2024
24fc759
remain only compressed mode
kshpv Oct 15, 2024
25fa2ee
add docstrings
kshpv Oct 15, 2024
afd3ca9
minor
kshpv Oct 15, 2024
ecb82a8
rm WeightQuantizationErrorTensorStatistic
kshpv Oct 16, 2024
19646c1
add tests
kshpv Oct 16, 2024
86dc4da
improve code
kshpv Oct 16, 2024
d3903a4
add __eq__ for statistics
kshpv Oct 16, 2024
1e86ddd
polishing
kshpv Oct 16, 2024
43161e9
dump objects
kshpv Oct 16, 2024
ac6bef9
_get_statistics_key abstarctmethod
kshpv Oct 16, 2024
ca7317d
build_statistic_container -> from_kwargs
kshpv Oct 16, 2024
8263c9d
add state methods for pickle of Tensor
kshpv Oct 16, 2024
8253a9b
rm redundant test case
kshpv Oct 16, 2024
3566914
from_kwargs to WCTensorStatistic
kshpv Oct 16, 2024
5b739a6
polishing
kshpv Oct 16, 2024
504102f
fix fx
kshpv Oct 16, 2024
7c38d87
draft to a new logic
kshpv Oct 17, 2024
63c776b
extend logic by collecting all stats at first
kshpv Oct 17, 2024
9601066
Merge remote-tracking branch 'remote/develop' into statistics_cahing
kshpv Oct 17, 2024
08f7374
rollback
kshpv Oct 17, 2024
b80ea31
add logging to aggregator
kshpv Oct 17, 2024
982c66c
add cache_statistics.py
kshpv Oct 17, 2024
faa5c9c
typos
kshpv Oct 17, 2024
9feefd2
typo
kshpv Oct 17, 2024
f295ae1
rollback changes in algorithm_key
kshpv Oct 18, 2024
17752eb
polishing
kshpv Oct 18, 2024
7982bf4
comments
kshpv Oct 18, 2024
2e0e7bc
add disable() in TensorCollector
kshpv Oct 18, 2024
fe8e1f2
fix name
kshpv Oct 18, 2024
2230490
fix bug
kshpv Oct 18, 2024
5c20916
True to all subalgos
kshpv Oct 21, 2024
1ff1274
disable gptq
kshpv Oct 21, 2024
f8eefd6
draft configuration test
kshpv Oct 21, 2024
f55c345
add opt_125m test
kshpv Oct 22, 2024
9701881
polishing
kshpv Oct 22, 2024
b2efda0
add reqs
kshpv Oct 22, 2024
aeea789
rm test for torch
kshpv Oct 22, 2024
35a64f5
upgrade reqs
kshpv Oct 22, 2024
a18e18b
rollback GPTQ support
kshpv Oct 22, 2024
e576ca5
add errors for torch, torch_fx
kshpv Oct 22, 2024
74fc99e
add get_weight_compression_configuration, check_weight_compression_co…
kshpv Oct 22, 2024
4fd2afc
polishing
kshpv Oct 22, 2024
83012ec
docstrings
kshpv Oct 22, 2024
b7818b7
comments
kshpv Oct 22, 2024
3563476
fix bugs
kshpv Oct 22, 2024
5a75175
typehint
kshpv Oct 22, 2024
39fcf12
rename
kshpv Oct 22, 2024
7122ff1
add ignored scope as a param
kshpv Oct 22, 2024
5c6180c
typo
kshpv Oct 22, 2024
1912b4e
disable gptq
kshpv Oct 23, 2024
fae1b20
add backend name in statistics_key
kshpv Oct 23, 2024
a2ea028
add backend check
kshpv Oct 23, 2024
e1ba427
docstrings
kshpv Oct 23, 2024
684a6d2
mypy
kshpv Oct 23, 2024
54d0204
typos
kshpv Oct 23, 2024
f3a3414
statistics_serializer + statistics_validator
kshpv Oct 23, 2024
82897f3
upd test
kshpv Oct 23, 2024
3949ee7
some test fixes
kshpv Oct 23, 2024
20cc58f
statistics_file_path -> statistics_dir_path
kshpv Oct 23, 2024
b8a4612
mypy
kshpv Oct 23, 2024
db1ed56
comments
kshpv Oct 23, 2024
cc88aeb
from_kwargs -> from_config + minor
kshpv Oct 23, 2024
342b766
Merge remote-tracking branch 'remote/develop' into statistics_cahing
kshpv Oct 24, 2024
21ffb31
add test on tensor collector
kshpv Oct 24, 2024
9c5b999
fix test
kshpv Oct 24, 2024
769bb28
add cache methods for TensorCollector
kshpv Oct 25, 2024
f80be5b
upd errors
kshpv Oct 25, 2024
c6b2d81
mypy + fixes
kshpv Oct 25, 2024
1459f0e
fixes
kshpv Oct 25, 2024
a27df2c
error
kshpv Oct 25, 2024
ac95062
more tests
kshpv Oct 25, 2024
f07d879
rm register_tensor
kshpv Oct 25, 2024
793bf5b
comments
kshpv Oct 25, 2024
1b24408
add subset_size to metadata
kshpv Oct 25, 2024
7ba24dd
statistics_dir_path -> statistics_path
kshpv Oct 25, 2024
7192c92
make cache_weight_compression_statistics common function
kshpv Oct 25, 2024
5c0155d
import issue
kshpv Oct 25, 2024
6e6587a
polishing
kshpv Oct 25, 2024
931febd
add onnx and torch_fx tests on aggregator
kshpv Oct 25, 2024
32462f8
add get_matmul_nodes()
kshpv Oct 28, 2024
05e7643
add lora to test scope
kshpv Oct 28, 2024
bc73ac8
test refactor
kshpv Oct 28, 2024
d27e41f
minor
kshpv Oct 28, 2024
c94ee95
comments
kshpv Oct 28, 2024
4080132
rm gptq; optimize NNCFGraph creation;
kshpv Oct 29, 2024
8799874
introduce get_compression_nodes_info for WC
kshpv Oct 29, 2024
787b365
comments
kshpv Oct 29, 2024
dd926b5
rollback changes in apply (torch has no support layer attributes)
kshpv Oct 29, 2024
fd6be94
available_backends -> get_available_backends
kshpv Oct 30, 2024
35157bf
Merge remote-tracking branch 'remote/develop' into statistics_cahing
kshpv Oct 30, 2024
26bab53
typo
kshpv Oct 30, 2024
6ba8ac3
rollback backend check
kshpv Oct 30, 2024
1cf2055
comments
kshpv Oct 31, 2024
@@ -24,6 +24,7 @@

import nncf
from nncf.common.logging import nncf_logger
+from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

DataItem = TypeVar("DataItem")
ModelInput = TypeVar("ModelInput")
@@ -63,6 +64,7 @@ def compress_model(
group_size=group_size,
awq=awq,
sensitivity_metric=nncf.parameters.SensitivityMetric.MAX_ACTIVATION_VARIANCE,
+advanced_parameters=AdvancedCompressionParameters(statistics_path="statistics"),
)
return optimized_ov_model

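For context, a minimal end-to-end sketch of how a user would turn the cache on. Only the `advanced_parameters` argument comes from the diff above; the model and dataset setup is assumed, not part of this PR:

```python
import nncf
from nncf.quantization.advanced_parameters import AdvancedCompressionParameters

# Assumed setup: `model` is an OpenVINO model, `calibration_dataset` is an nncf.Dataset.
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_ASYM,
    ratio=0.8,
    dataset=calibration_dataset,
    # First run: statistics are collected and dumped to this directory.
    # Later runs: statistics are loaded from it instead of being re-collected.
    advanced_parameters=AdvancedCompressionParameters(statistics_path="statistics"),
)
```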
68 changes: 68 additions & 0 deletions nncf/common/tensor_statistics/aggregator.py
@@ -8,22 +8,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, Optional, TypeVar

import nncf
import nncf.common.tensor_statistics.statistics_serializer as statistics_serializer
import nncf.common.tensor_statistics.statistics_validator as statistics_validator
from nncf.common import factory
from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.logging import nncf_logger
from nncf.common.logging.track_progress import track
from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.data.dataset import DataItem
from nncf.data.dataset import Dataset
from nncf.data.dataset import ModelInput
from nncf.experimental.common.tensor_statistics.statistics import TensorStatistic

TensorType = TypeVar("TensorType")
TModel = TypeVar("TModel")
@@ -38,6 +44,8 @@ class StatisticsAggregator(ABC):
Base class for statistics collection.
"""

BACKEND: BackendType

def __init__(self, dataset: Dataset[DataItem, ModelInput]):
self.dataset = dataset
self.stat_subset_size = None
@@ -88,6 +96,56 @@ def collect_statistics(self, model: TModel, graph: NNCFGraph) -> None:
f"smaller than the requested subset size {self.stat_subset_size}."
)

def load_statistics_from_dir(self, dir_path: str) -> None:
"""
Loads statistics from a directory and populates the statistic points with the loaded data.

:param dir_path: The name of the directory from which to load the statistics.
"""
loaded_data, metadata = statistics_serializer.load_from_dir(dir_path)
statistics_validator.validate_backend(metadata, self.BACKEND)
self._load_statistics(loaded_data)
nncf_logger.info(f"Statistics were successfully loaded from a directory {dir_path}.")

def _load_statistics(self, data: Dict[str, Any]) -> None:
"""
Loads statistics into the registered statistic points from the given data.

:param data: A dictionary containing the statistics loaded from a file.
"""
for _, statistic_point, tensor_collector in self.statistic_points.get_tensor_collectors():
statistics = tensor_collector.get_statistics()
statistics_key = self._get_statistics_key(statistics, statistic_point.target_point)
if statistics_key not in data:
raise nncf.ValidationError(f"Not found statistics for {statistics_key}")
statistics_container = tensor_collector.create_statistics_container(data[statistics_key])
tensor_collector.set_cache(statistics_container)

def dump_statistics(self, dir_path: str) -> None:
"""
Dumps the current statistics to a directory in a compressed format.

:param dir_path: The path of the directory where the statistics will be saved.
"""
data_to_dump = self._prepare_statistics()
metadata = {"backend": self.BACKEND.value, "subset_size": self.stat_subset_size}
statistics_serializer.dump_to_dir(data_to_dump, dir_path, metadata)
nncf_logger.info(f"Statistics were successfully saved to a directory {dir_path}.")

def _prepare_statistics(self) -> Dict[str, Any]:
"""
Prepares the statistics data for dumping into a directory.

:return: A dictionary containing the statistics data to be dumped.
"""
data_to_dump = {}
for _, statistic_point, tensor_collector in self.statistic_points.get_tensor_collectors():
statistics = tensor_collector.get_statistics()
statistics_key = self._get_statistics_key(statistics, statistic_point.target_point)
data = statistics.get_data()
data_to_dump[statistics_key] = data
return data_to_dump

def register_statistic_points(self, statistic_points: StatisticPointsContainer) -> None:
"""
Register statistic points for statistics collection and recalculates the maximum number samples
@@ -154,3 +212,13 @@ def _process_outputs(outputs: Any) -> Dict[str, NNCFTensor]:
:param outputs: raw model outputs
:return: processed model outputs in Dict[str, Tensor] format
"""

@abstractmethod
def _get_statistics_key(self, statistics: TensorStatistic, target_point: TargetPoint) -> str:
"""
Returns key of statistics.

:param statistics: Statistics value.
:param target_point: Statistics target point.
:return: Statistics key.
"""
4 changes: 2 additions & 2 deletions nncf/common/tensor_statistics/collectors.py
@@ -59,14 +59,14 @@ def register_input(self, x: TensorType) -> TensorType:
def _register_input(self, x: TensorType) -> None:
pass

-def get_statistics(self) -> None:
+def get_statistics(self) -> Any:
"""Returns collected statistics, if present."""
if self._collected_samples == 0:
raise StatisticsNotCollectedError()
return self._get_statistics()

@abstractmethod
-def _get_statistics(self) -> None:
+def _get_statistics(self) -> Any:
pass

def enable(self) -> None:
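The corrected annotation reflects how callers already use the collector: `get_statistics()` returns a statistics object and raises if nothing was collected. A small sketch of that contract, where `collector` stands in for any concrete subclass and the error's import path is an assumption based on where it is raised in this file:

```python
from nncf.common.tensor_statistics.collectors import StatisticsNotCollectedError

try:
    statistics = collector.get_statistics()  # now annotated as returning Any
except StatisticsNotCollectedError:
    # Raised when no samples were registered before the call.
    statistics = None
```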
14 changes: 4 additions & 10 deletions nncf/common/tensor_statistics/statistic_point.py
@@ -13,8 +13,7 @@
from typing import Any, Callable, Generator, Optional, Tuple, cast

from nncf.common.graph.transformations.commands import TargetPoint
-from nncf.common.tensor import NNCFTensor
-from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector


class StatisticPoint:
@@ -25,7 +24,7 @@ class StatisticPoint:
algorithm implies on what algorithm nedeed this statistics.
"""

-def __init__(self, target_point: TargetPoint, tensor_collector: TensorStatisticCollectorBase, algorithm: str):
+def __init__(self, target_point: TargetPoint, tensor_collector: TensorCollector, algorithm: str):
self.target_point = target_point
self.algorithm_to_tensor_collectors = {algorithm: [tensor_collector]}

@@ -36,11 +35,6 @@ def __eq__(self, other: Any) -> bool:
and self.algorithm_to_tensor_collectors == other.self.algorithm_to_tensor_collectors,
)

-def register_tensor(self, x: NNCFTensor) -> None:
-for tensor_collectors in self.algorithm_to_tensor_collectors.values():
-for tensor_collector in tensor_collectors:
-tensor_collector.register_input(x)


class StatisticPointsContainer(UserDict): # type: ignore
"""
@@ -88,7 +82,7 @@ def iter_through_statistic_points_in_target_node(

def get_tensor_collectors(
self, filter_fn: Optional[Callable[[StatisticPoint], bool]] = None
-) -> Generator[Tuple[str, StatisticPoint, TensorStatisticCollectorBase], None, None]:
+) -> Generator[Tuple[str, StatisticPoint, TensorCollector], None, None]:
"""
Returns iterable through all tensor collectors.

@@ -115,7 +109,7 @@ def get_algo_statistics_for_node(
target_node_name: str,
filter_fn: Callable[[StatisticPoint], bool],
algorithm: str,
-) -> Generator[TensorStatisticCollectorBase, None, None]:
+) -> Generator[TensorCollector, None, None]:
"""
Returns iterable through all statistic collectors in node with target_node_name.

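With the switch to the experimental `TensorCollector`, building a statistic point looks as follows. A minimal sketch where `target_point` and `collector` are placeholders for backend-specific objects, and the container call assumes the pre-existing `add_statistic_point` API:

```python
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer

# `target_point` and `collector` come from backend-specific code and are
# placeholders here; `collector` must now be a TensorCollector.
point = StatisticPoint(
    target_point=target_point,
    tensor_collector=collector,
    algorithm="WeightCompression",
)
container = StatisticPointsContainer()
container.add_statistic_point(point)
```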
25 changes: 1 addition & 24 deletions nncf/common/tensor_statistics/statistics.py
@@ -12,7 +12,7 @@
from abc import ABC
from abc import abstractmethod
from collections import Counter
-from typing import Any, Dict, List, Tuple, TypeVar, cast
+from typing import Any, Dict, TypeVar, cast

from nncf.tensor import Tensor
from nncf.tensor import functions as fns
@@ -120,26 +120,3 @@ def __eq__(self, other: Any) -> bool:
if not isinstance(other, RawTensorStatistic):
return False
return self.tensor_eq(self.values, other.values)
-
-
-class WCTensorStatistic(TensorStatistic):
-MEAN_STAT = "mean_values"
-SHAPE_STAT = "shape_values"
-
-def __init__(self, mean_values: List[Tensor], shapes: List[Tuple[int, ...]]):
-"""
-:param mean_values: List of N tensors of shape [HiddenDim] obtained by reducing activations along batch and
-sequence length dimensions.
-:param shapes: List of N tuples containing original shapes of activations before reduction.
-"""
-self.mean_values = mean_values
-self.shapes = shapes
-
-def __eq__(self, other: Any) -> bool:
-shapes_equal = all(self.shapes[i] == other.shapes[i] for i in range(len(self.mean_values)))
-if not shapes_equal:
-return False
-mean_values_equal = all(
-self.tensor_eq(self.mean_values[i], other.mean_values[i]) for i in range(len(self.mean_values))
-)
-return mean_values_equal
115 changes: 115 additions & 0 deletions nncf/common/tensor_statistics/statistics_serializer.py
@@ -0,0 +1,115 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gzip
import json
import pickle
import re
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, cast

import nncf

METADATA_FILE = "statistics_metadata.json"


def sanitize_filename(filename: str) -> str:
"""
Replaces any forbidden characters with an underscore.
"""
return re.sub(r'[\/:*?"<>|]', "_", filename)


def load_metadata(dir_path: Path) -> Dict[str, Any]:
"""
Loads the metadata, including the mapping and any other metadata information from the metadata file.
:param dir_path: The directory where the metadata file is stored.
:return: A dictionary containing the mapping and metadata.
"""
metadata_file = dir_path / METADATA_FILE
if metadata_file.exists():
with open(metadata_file, "r") as f:
return cast(Dict[str, Any], json.load(f))
return {"mapping": {}, "metadata": {}}


def save_metadata(metadata: Dict[str, Any], dir_path: Path) -> None:
"""
Saves the mapping and metadata to the metadata file.
:param metadata: The dictionary containing both the mapping and other metadata.
:param dir_path: The directory where the metadata file will be stored.
"""
metadata_file = dir_path / METADATA_FILE
with open(metadata_file, "w") as f:
json.dump(metadata, f, indent=4)


def load_from_dir(dir_path: str) -> Tuple[Dict[str, Any], Dict[str, str]]:
"""
Loads statistics from gzip-compressed files in the given directory.
:param dir_path: The path to the directory from which to load the statistics.
:return: 1) A dictionary with the original statistic names as keys and the loaded statistics as values.
2) Metadata dictionary.
"""
statistics = {}
path = Path(dir_path)
if not path.exists():
raise nncf.ValidationError("The provided directory path does not exist.")
metadata = load_metadata(path)
mapping = metadata.get("mapping", {})

for statistics_file in path.iterdir():
if statistics_file.name == METADATA_FILE:
continue # Skip the metadata file

try:
with gzip.open(statistics_file, "rb") as f:
sanitized_name = statistics_file.name
original_name = mapping.get(sanitized_name, sanitized_name)
statistics[original_name] = pickle.load(f)
except (pickle.UnpicklingError, IOError) as e:
raise nncf.InternalError(f"Error loading statistics from {statistics_file.name}: {e}")
return statistics, metadata.get("metadata", {})


def dump_to_dir(
statistics: Dict[str, Any], dir_path: str, additional_metadata: Optional[Dict[str, Any]] = None
) -> None:
"""
Dumps statistics to gzip-compressed files in the specified directory, while maintaining a mapping file.
:param statistics: A dictionary with statistic names as keys and the statistic data as values.
:param dir_path: The path to the directory where the statistics will be dumped.
:param additional_metadata: A dictionary containing any additional metadata to be saved with the mapping.
"""
path = Path(dir_path)
path.mkdir(parents=True, exist_ok=True)

metadata, mapping = {}, {}

for original_name, statistics_value in statistics.items():
sanitized_name = sanitize_filename(original_name)
file_path = path / sanitized_name

# Update the mapping
mapping[sanitized_name] = original_name

try:
with gzip.open(file_path, "wb") as f:
pickle.dump(statistics_value, f)
except (IOError, pickle.PicklingError) as e:
raise nncf.InternalError(f"Failed to write data to file {file_path}: {e}")

# Add additional metadata if provided
if additional_metadata:
metadata["metadata"] = additional_metadata

# Update the mapping in the metadata file
metadata["mapping"] = mapping
save_metadata(metadata, path)
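A round-trip sketch for the serializer; the statistic key and payload are illustrative, chosen to show how a key with filesystem-unsafe characters survives via the mapping in the metadata file:

```python
import nncf.common.tensor_statistics.statistics_serializer as statistics_serializer

# "/" is forbidden in filenames, so the key is sanitized on disk and
# mapped back to the original name through statistics_metadata.json.
statistics = {"MatMul_1/output_0": [0.1, 0.2, 0.3]}
statistics_serializer.dump_to_dir(statistics, "statistics", additional_metadata={"backend": "OpenVINO"})

loaded, metadata = statistics_serializer.load_from_dir("statistics")
assert set(loaded) == set(statistics)
assert metadata["backend"] == "OpenVINO"
```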
30 changes: 30 additions & 0 deletions nncf/common/tensor_statistics/statistics_validator.py
@@ -0,0 +1,30 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict

import nncf
from nncf.common.utils.backend import BackendType


def validate_backend(data: Dict[str, Any], backend: BackendType) -> None:
"""
Checks whether backend in loaded data is equal to a provided backend.

:param data: Loaded statistics.
:param backend: Provided backend.
"""
if "backend" not in data:
raise nncf.ValidationError("The provided metadata has no information about backend.")
data_backend = data["backend"]
if data_backend != backend.value:
raise nncf.ValidationError(
f"Backend in loaded statistics {data_backend} does not match to an expected backend {backend.value}."
)
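In isolation, the check looks like this; the metadata dictionary mirrors what `dump_statistics` writes above, and the enum members come from the existing `BackendType`:

```python
import nncf
from nncf.common.tensor_statistics import statistics_validator
from nncf.common.utils.backend import BackendType

metadata = {"backend": BackendType.OPENVINO.value, "subset_size": 128}

statistics_validator.validate_backend(metadata, BackendType.OPENVINO)  # passes silently

try:
    statistics_validator.validate_backend(metadata, BackendType.TORCH)
except nncf.ValidationError as e:
    print(e)  # mismatch is reported before any statistics are loaded
```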