diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 85d016fdfb1..f5dbc3236a2 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -19,6 +19,11 @@ from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor from nncf.common.tensor_statistics.collectors import NNCFTensor from nncf.common.tensor_statistics.collectors import ReductionAxes +from nncf.common.tensor_statistics.statistics import MeanTensorStatistic +from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic +from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic +from nncf.common.tensor_statistics.statistics import RawTensorStatistic from nncf.common.tensor_statistics.statistics import TensorStatistic from nncf.quantization.advanced_parameters import AggregatorType @@ -334,7 +339,7 @@ def get_statistics(self) -> Union[TensorStatistic, Dict[str, Any]]: if not self._stat_container: return kwargs - return self._stat_container(kwargs) + return self._build_statistic_container(self._stat_container, kwargs) def get_inplace_fn_info(self) -> List[Tuple[Any, int]]: """ @@ -392,6 +397,39 @@ def get_tensor_collector_inputs( target_inputs[reducer] = [outputs[name] for name in names] return target_inputs + @staticmethod + def _build_statistic_container(statistic_container_cls: TensorStatistic, kwargs: Dict[Any, Any]): + if issubclass(statistic_container_cls, MinMaxTensorStatistic): + return statistic_container_cls( + min_values=kwargs[MinMaxTensorStatistic.MIN_STAT], max_values=kwargs[MinMaxTensorStatistic.MAX_STAT] + ) + if issubclass(statistic_container_cls, MeanTensorStatistic): + return statistic_container_cls( + mean_values=kwargs[MeanTensorStatistic.MEAN_STAT], shape=kwargs[MeanTensorStatistic.SHAPE_STAT] + ) + 
if issubclass(statistic_container_cls, RawTensorStatistic): + return statistic_container_cls(values=kwargs[RawTensorStatistic.VALUES_STATS]) + if issubclass(statistic_container_cls, MedianMADTensorStatistic): + return statistic_container_cls( + median_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][ + MedianMADTensorStatistic.MEDIAN_VALUES_STAT + ], + mad_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][ + MedianMADTensorStatistic.MAD_VALUES_STAT + ], + ) + if issubclass(statistic_container_cls, PercentileTensorStatistic): + if PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY in kwargs: + percentile_vs_values_dict = kwargs[PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY] + else: + percentile_vs_values_dict = {} + for (_, percentile), value in kwargs.items(): + percentile_vs_values_dict[percentile] = value + return statistic_container_cls(percentile_vs_values_dict=percentile_vs_values_dict) + raise RuntimeError( + f"Statistic collector class {statistic_container_cls} is not supported by the TensorCollector class." 
+ ) + class MergedTensorCollector(TensorCollector): """ @@ -708,7 +746,10 @@ def _aggregate_impl(self) -> Dict[str, NNCFTensor]: ) if not self._keepdims: median_per_ch = self._tensor_processor.squeeze(median_per_ch, self._aggregation_axes) - return {"median_values": median_per_ch.tensor, "mad_values": mad_values.tensor} + return { + MedianMADTensorStatistic.MEDIAN_VALUES_STAT: median_per_ch.tensor, + MedianMADTensorStatistic.MAD_VALUES_STAT: mad_values.tensor, + } class PercentileAggregator(TensorAggregatorBase): diff --git a/nncf/onnx/statistics/collectors.py b/nncf/onnx/statistics/collectors.py index a67e9544039..7af2792f003 100644 --- a/nncf/onnx/statistics/collectors.py +++ b/nncf/onnx/statistics/collectors.py @@ -151,10 +151,8 @@ def _register_input(self, x: ONNXNNCFTensor): def _get_statistics(self) -> ONNXMinMaxTensorStatistic: return ONNXMinMaxTensorStatistic( - { - ONNXMinMaxTensorStatistic.MIN_STAT: self._min_values.tensor, - ONNXMinMaxTensorStatistic.MAX_STAT: self._max_values.tensor, - } + min_values=self._min_values.tensor, + max_values=self._max_values.tensor, ) @@ -168,10 +166,8 @@ def _register_input(self, x: ONNXNNCFTensor): def _get_statistics(self) -> ONNXMinMaxTensorStatistic: return ONNXMinMaxTensorStatistic( - { - ONNXMinMaxTensorStatistic.MIN_STAT: self._min_aggregate().tensor, - ONNXMinMaxTensorStatistic.MAX_STAT: self._max_aggregate().tensor, - } + min_values=self._min_aggregate().tensor, + max_values=self._max_aggregate().tensor, ) @@ -185,10 +181,8 @@ def _register_input(self, x: ONNXNNCFTensor): def _get_statistics(self) -> ONNXMeanTensorStatistic: return ONNXMeanTensorStatistic( - { - ONNXMeanTensorStatistic.MEAN_STAT: self._mean_aggregate().tensor, - ONNXMeanTensorStatistic.SHAPE_STAT: self._shape(), - } + mean_values=self._mean_aggregate().tensor, + shape=self._shape(), ) @@ -201,4 +195,4 @@ def _register_input(self, x: ONNXNNCFTensor): self._register_input_common(x) def _get_statistics(self) -> ONNXRawTensorStatistic: - return 
ONNXRawTensorStatistic({ONNXRawTensorStatistic.VALUES_STATS: self._all_values}) + return ONNXRawTensorStatistic(self._all_values) diff --git a/nncf/onnx/statistics/statistics.py b/nncf/onnx/statistics/statistics.py index 91d1169c392..f9d5119201f 100644 --- a/nncf/onnx/statistics/statistics.py +++ b/nncf/onnx/statistics/statistics.py @@ -17,27 +17,18 @@ class ONNXMinMaxTensorStatistic(MinMaxTensorStatistic): - def __init__(self, tensor_collector_output): - super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT]) - @staticmethod def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool: return bool(np.allclose(tensor1, tensor2, rtol=rtol)) class ONNXMeanTensorStatistic(MeanTensorStatistic): - def __init__(self, tensor_collector_output): - super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT]) - @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) class ONNXRawTensorStatistic(RawTensorStatistic): - def __init__(self, tensor_collector_output): - super().__init__(tensor_collector_output[self.VALUES_STATS]) - @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) diff --git a/nncf/openvino/statistics/statistics.py b/nncf/openvino/statistics/statistics.py index de2cf3e23e2..12a0c82af9c 100644 --- a/nncf/openvino/statistics/statistics.py +++ b/nncf/openvino/statistics/statistics.py @@ -17,27 +17,18 @@ class OVMinMaxTensorStatistic(MinMaxTensorStatistic): - def __init__(self, tensor_collector_output): - super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT]) - @staticmethod def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool: return bool(np.allclose(tensor1, tensor2, rtol=rtol)) class OVMeanTensorStatistic(MeanTensorStatistic): - def __init__(self, tensor_collector_output): - 
super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT]) - @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) class OVRawTensorStatistic(RawTensorStatistic): - def __init__(self, tensor_collector_output): - super().__init__(tensor_collector_output[self.VALUES_STATS]) - @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index fc7c0486c0a..47cf5695832 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -120,9 +120,7 @@ def unify_statistics(statistics: List[ONNXMinMaxTensorStatistic]) -> ONNXMinMaxT min_values.append(np.array(statistic.min_values).flatten()) max_values = np.max(max_values, axis=0) min_values = np.min(min_values, axis=0) - return ONNXMinMaxTensorStatistic( - {ONNXMinMaxTensorStatistic.MIN_STAT: min_values, ONNXMinMaxTensorStatistic.MAX_STAT: max_values} - ) + return ONNXMinMaxTensorStatistic(min_values=min_values, max_values=max_values) @staticmethod def _get_input_edges_mapping(nncf_graph: NNCFGraph): diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index f1d2063bcd7..89450446ea9 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -119,9 +119,7 @@ def unify_statistics(statistics: List[OVMinMaxTensorStatistic]) -> OVMinMaxTenso min_values.append(np.array(statistic.min_values).flatten()) max_values = np.max(max_values, axis=0) min_values = np.min(min_values, axis=0) - return OVMinMaxTensorStatistic( - {OVMinMaxTensorStatistic.MIN_STAT: min_values, OVMinMaxTensorStatistic.MAX_STAT: max_values} - ) + return 
OVMinMaxTensorStatistic(min_values=min_values, max_values=max_values) @staticmethod def _get_reduction_shape_and_use_abs_max( diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index 77f9d78163f..8b9d0f57d98 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -144,9 +144,7 @@ def unify_statistics(statistics: List[PTMinMaxTensorStatistic]) -> PTMinMaxTenso min_values.append(torch.tensor(statistic.min_values).flatten()) max_values = torch.max(torch.tensor(max_values)) min_values = torch.min(torch.tensor(min_values)) - return PTMinMaxTensorStatistic( - {PTMinMaxTensorStatistic.MIN_STAT: min_values, PTMinMaxTensorStatistic.MAX_STAT: max_values} - ) + return PTMinMaxTensorStatistic(min_values=min_values, max_values=max_values) @staticmethod def get_statistic_collector( diff --git a/nncf/torch/tensor.py b/nncf/torch/tensor.py index 6dd4e88b68a..986adb46aa4 100644 --- a/nncf/torch/tensor.py +++ b/nncf/torch/tensor.py @@ -30,3 +30,6 @@ def __init__(self, tensor: torch.tensor): @property def device(self) -> torch.device: return self._tensor.device + + def is_empty(self) -> bool: + return self.tensor.numel() == 0 diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index 37686759bb2..ce6f87eaf39 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -10,7 +10,7 @@ # limitations under the License. 
from functools import partial -from typing import Deque, List, Optional, Tuple, Union +from typing import Deque, List, Optional, Tuple, Type, Union import numpy as np import torch @@ -229,6 +229,44 @@ class PTMeanPerChanelReducer(PTReducerMixIn, MeanPerChReducer): pass +def _reshape_all(targets: Tuple[torch.Tensor, ...], target_shape: Tuple[int, ...]): + return map(lambda stat: torch.reshape(stat, target_shape), targets) + + +def _get_wrapped_min_max_tensor_statistic(target_shape: Tuple[int, ...]) -> Type[PTMinMaxTensorStatistic]: + """ + Returns PTMinMaxTensorStatistic type but all statistics are reshaped to target_shape. + + :param target_shape: Target shape of the tensor statistic + :return: PTMinMaxTensorStatistic type but all statistics are reshaped to target_shape. + """ + + class WrappedPTMinMaxTensorStatistic(PTMinMaxTensorStatistic): + def __init__(self, min_values, max_values): + min_values, max_values = _reshape_all((min_values, max_values), target_shape) + super().__init__(min_values, max_values) + + return WrappedPTMinMaxTensorStatistic + + +def _get_wrapped_percentile_tensor_statistic(target_shape: Tuple[int, ...]) -> Type[PTPercentileTensorStatistic]: + """ + Returns PTPercentileTensorStatistic type but all statistics are reshaped to target_shape. + + :param target_shape: Target shape of the tensor statistic + :return: PTPercentileTensorStatistic type but all statistics are reshaped to target_shape. + """ + + class WrappedPTPercentileTensorStatistic(PTPercentileTensorStatistic): + def __init__(self, percentile_vs_values_dict): + reshaped_percentiles = {} + for k, v in percentile_vs_values_dict.items(): + reshaped_percentiles[k] = torch.reshape(v, target_shape) + super().__init__(reshaped_percentiles) + + return WrappedPTPercentileTensorStatistic + + def get_min_max_statistic_collector( use_abs_max: bool, reduction_axes: Tuple[int, ...], @@ -246,7 +284,8 @@ def get_min_max_statistic_collector( :param num_samples: Maximum number of samples to collect. 
:return: Min max statistic collector. """ - tensor_collector = TensorCollector(partial(PTMinMaxTensorStatistic, target_shape=scale_shape)) + + tensor_collector = TensorCollector(_get_wrapped_min_max_tensor_statistic(target_shape=scale_shape)) aggregator_kwargs = { "tensor_processor": PTNNCFCollectorTensorProcessor, @@ -288,7 +327,7 @@ def get_mixed_min_max_statistic_collector( Aggregates all available collected statistics in case parameter is None. :return: Mixed min max statistic collector. """ - tensor_collector = TensorCollector(partial(PTMinMaxTensorStatistic, target_shape=scale_shape)) + tensor_collector = TensorCollector(_get_wrapped_min_max_tensor_statistic(target_shape=scale_shape)) min_reducer = PTMinReducer(reduction_axes) kwargs = { @@ -329,12 +368,17 @@ def get_median_mad_statistic_collector( :return: Median Absolute Deviation statistic collector. """ + + class WrappedPTMedianMADTensorStatistic(PTMedianMADTensorStatistic): + def __init__(self, median_values, mad_values): + median_values, mad_values = _reshape_all((median_values, mad_values), scale_shape) + super().__init__(median_values, mad_values) + return _get_collection_without_reduction( MedianAbsoluteDeviationAggregator, - PTMedianMADTensorStatistic, + WrappedPTMedianMADTensorStatistic, reduction_axes=reduction_axes, aggregation_axes=aggregation_axes, - scale_shape=scale_shape, num_samples=num_samples, window_size=window_size, ) @@ -362,10 +406,9 @@ def get_percentile_tensor_collector( """ return _get_collection_without_reduction( partial(PercentileAggregator, percentiles_to_collect=percentiles_to_collect), - PTPercentileTensorStatistic, + _get_wrapped_percentile_tensor_statistic(target_shape=scale_shape), reduction_axes=reduction_axes, aggregation_axes=aggregation_axes, - scale_shape=scale_shape, num_samples=num_samples, window_size=window_size, ) @@ -376,7 +419,6 @@ def _get_collection_without_reduction( statistic_cls: TensorAggregatorBase, reduction_axes: Tuple[int, ...], aggregation_axes: 
Tuple[int, ...], - scale_shape: Tuple[int, ...], num_samples: int, window_size: Optional[int] = None, ) -> TensorCollector: @@ -387,13 +429,12 @@ def _get_collection_without_reduction( :param aggregator_cls: Statistic class to build the tensor collector. :param reduction_axes: Axes to use in reduction functions. :param aggregation_axes: Axes to use in aggregation functions. - :param scale_shape: Target shape for collected statistics. :param num_samples: Maximum number of samples to collect. :param window_size: Number of samples from the end of the list of collected samples to aggregate. Aggregates all available collected statistics in case parameter is None. :return: Target statistic collector. """ - tensor_collector = TensorCollector(partial(statistic_cls, target_shape=scale_shape)) + tensor_collector = TensorCollector(statistic_cls) reducer = PTNoopReducer() aggregation_axes = list(set(list(aggregation_axes) + [dim + 1 for dim in reduction_axes])) aggregator = aggregator_cls( @@ -429,7 +470,7 @@ def get_mean_percentile_statistic_collector( Aggregates all available collected statistics in case parameter is None. :return: Mean percentile statistic collector. """ - tensor_collector = TensorCollector(partial(PTPercentileTensorStatistic, target_shape=scale_shape)) + tensor_collector = TensorCollector(_get_wrapped_percentile_tensor_statistic(target_shape=scale_shape)) quantiles_to_collect = np.true_divide(percentiles_to_collect, 100) reducer = PTQuantileReducer(reduction_axes=reduction_axes, quantile=quantiles_to_collect) for output_port_id, p in enumerate(percentiles_to_collect): diff --git a/nncf/torch/tensor_statistics/statistics.py b/nncf/torch/tensor_statistics/statistics.py index 3ad5aa65735..ba51df16ce9 100644 --- a/nncf/torch/tensor_statistics/statistics.py +++ b/nncf/torch/tensor_statistics/statistics.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import torch @@ -25,58 +25,24 @@ def _reshape_all(targets: Tuple[torch.Tensor, ...], target_shape: Tuple[int, ... class PTMinMaxTensorStatistic(MinMaxTensorStatistic): - def __init__(self, tensor_collector_output, target_shape: Optional[Tuple[int, ...]] = None): - min_values, max_values = tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT] - if target_shape: - min_values, max_values = _reshape_all((min_values, max_values), target_shape) - super().__init__(min_values=min_values, max_values=max_values) - @staticmethod def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool: return bool(torch.allclose(tensor1, tensor2, rtol=rtol)) class PTMedianMADTensorStatistic(MedianMADTensorStatistic): - def __init__(self, tensor_collector_output, target_shape: Optional[Tuple[int, ...]] = None): - median_values, mad_values = ( - tensor_collector_output[self.TENSOR_STATISTIC_OUTPUT_KEY][key] - for key in (self.MEDIAN_VALUES_STAT, self.MAD_VALUES_STAT) - ) - if target_shape: - median_values, mad_values = _reshape_all((median_values, mad_values), target_shape) - super().__init__(median_values=median_values, mad_values=mad_values) - @staticmethod def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool: return bool(torch.allclose(tensor1, tensor2, rtol=rtol)) class PTPercentileTensorStatistic(PercentileTensorStatistic): - def __init__(self, tensor_collector_output, target_shape: Optional[Tuple[int, ...]] = None): - if self.TENSOR_STATISTIC_OUTPUT_KEY in tensor_collector_output: - percentile_vs_values_dict = tensor_collector_output[self.TENSOR_STATISTIC_OUTPUT_KEY] - else: - percentile_vs_values_dict = {} - for (_, percentile), value in tensor_collector_output.items(): - percentile_vs_values_dict[percentile] = value - if target_shape: - percentile_vs_values_dict = { - k: torch.reshape(v, target_shape) for k, v in 
percentile_vs_values_dict.items() - } - super().__init__(percentile_vs_values_dict) - @staticmethod def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool: return bool(torch.allclose(tensor1, tensor2, rtol=rtol)) class PTMeanTensorStatistic(MeanTensorStatistic): - def __init__(self, tensor_collector_output, target_shape: Optional[Tuple[int, ...]] = None): - mean_values, shape = tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT] - if target_shape: - mean_values = _reshape_all((mean_values), target_shape) - super().__init__(mean_values=mean_values, shape=shape) - @staticmethod def tensor_eq(tensor1: torch.Tensor, tensor2: torch.Tensor, rtol=1e-6) -> bool: return bool(torch.allclose(tensor1, tensor2, rtol=rtol)) @@ -89,10 +55,8 @@ def pt_convert_stat_to_min_max_tensor_stat(statistic: TensorStatistic) -> PTMinM # Using three-sigma approach to estimate min and max # Constant factor depends on the distribution form - assuming normal and the factor is 1.4826 return PTMinMaxTensorStatistic( - { - PTMinMaxTensorStatistic.MIN_STAT: statistic.median_values - 3 * 1.4826230 * statistic.mad_values, - PTMinMaxTensorStatistic.MAX_STAT: statistic.median_values + 3 * 1.4826230 * statistic.mad_values, - } + min_values=statistic.median_values - 3 * 1.4826230 * statistic.mad_values, + max_values=statistic.median_values + 3 * 1.4826230 * statistic.mad_values, ) if isinstance(statistic, PTPercentileTensorStatistic): if len(statistic.percentile_vs_values_dict.keys()) < 2: @@ -100,9 +64,7 @@ def pt_convert_stat_to_min_max_tensor_stat(statistic: TensorStatistic) -> PTMinM min_pct = min(statistic.percentile_vs_values_dict.keys()) max_pct = max(statistic.percentile_vs_values_dict.keys()) return PTMinMaxTensorStatistic( - { - PTMinMaxTensorStatistic.MIN_STAT: statistic.percentile_vs_values_dict[min_pct], - PTMinMaxTensorStatistic.MAX_STAT: statistic.percentile_vs_values_dict[max_pct], - } + 
min_values=statistic.percentile_vs_values_dict[min_pct], + max_values=statistic.percentile_vs_values_dict[max_pct], ) raise ValueError("Unknown TensorStatistic to generate min-max stat from!") diff --git a/tests/common/experimental/test_statistic_collector.py b/tests/common/experimental/test_statistic_collector.py index f2ede5d5a72..ba99c65d328 100644 --- a/tests/common/experimental/test_statistic_collector.py +++ b/tests/common/experimental/test_statistic_collector.py @@ -10,12 +10,17 @@ # limitations under the License. from abc import abstractmethod -from typing import List, Optional +from typing import List, Optional, Type import numpy as np import pytest from nncf.common.tensor import NNCFTensor +from nncf.common.tensor_statistics.statistics import MeanTensorStatistic +from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic +from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic +from nncf.common.tensor_statistics.statistics import RawTensorStatistic from nncf.experimental.common.tensor_statistics.collectors import AggregatorBase from nncf.experimental.common.tensor_statistics.collectors import MergedTensorCollector from nncf.experimental.common.tensor_statistics.collectors import TensorCollector @@ -64,7 +69,7 @@ class DummyTensorReducerA(DummyTensorReducer): class DummyTensorAggregator(AggregatorBase): - def __init__(self, num_samples: Optional[int]): + def __init__(self, num_samples: Optional[int] = None): super().__init__(None, num_samples=num_samples) def _register_reduced_input_impl(self, x: TensorType): @@ -313,11 +318,47 @@ def test_register_unnamed_statistics(mocker): assert all(v[0] == inputs_) +def test_wrong_statistic_container_class(): + class BadStatContainer: + pass + + tensor_collector = TensorCollector(BadStatContainer) + tensor_collector.register_statistic_branch("A", DummyTensorReducer("A"), DummyTensorAggregator()) + 
tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + with pytest.raises(RuntimeError): + tensor_collector.get_statistics() + + class TemplateTestStatisticCollector: @abstractmethod def get_nncf_tensor_cls(self): pass + @abstractmethod + @pytest.fixture + def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]: + pass + + @abstractmethod + @pytest.fixture + def mean_statistic_cls(self) -> Type[MeanTensorStatistic]: + pass + + @abstractmethod + @pytest.fixture + def median_mad_statistic_cls(self) -> Type[MedianMADTensorStatistic]: + pass + + @abstractmethod + @pytest.fixture + def percentile_statistic_cls(self) -> Type[PercentileTensorStatistic]: + pass + + @abstractmethod + @pytest.fixture + def raw_statistic_cls(self) -> Type[RawTensorStatistic]: + pass + @pytest.mark.parametrize("inplace", [False, True]) @pytest.mark.parametrize("any_not_empty", [False, True]) def test_empty_tensors_register(self, inplace, any_not_empty): @@ -354,3 +395,89 @@ def test_empty_tensors_register(self, inplace, any_not_empty): stats = collector.get_statistics() assert len(stats) == 1 assert stats["A"] is None + + def test_min_max_stat_building(self, min_max_statistic_cls: MinMaxTensorStatistic): + tensor_collector = TensorCollector(min_max_statistic_cls) + tensor_collector.register_statistic_branch( + min_max_statistic_cls.MIN_STAT, DummyTensorReducer("A"), DummyTensorAggregator() + ) + tensor_collector.register_statistic_branch( + min_max_statistic_cls.MAX_STAT, DummyTensorReducer("B"), DummyTensorAggregator() + ) + tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + statistic = tensor_collector.get_statistics() + assert isinstance(statistic, MinMaxTensorStatistic) + assert statistic.min_values == statistic.max_values == NumpyNNCFTensor(1) + + def test_mean_max_stat_building(self, mean_statistic_cls: MeanTensorStatistic): + tensor_collector = TensorCollector(mean_statistic_cls) + tensor_collector.register_statistic_branch( + 
mean_statistic_cls.MEAN_STAT, DummyTensorReducer("A"), DummyTensorAggregator() + ) + tensor_collector.register_statistic_branch( + mean_statistic_cls.SHAPE_STAT, DummyTensorReducer("B"), DummyTensorAggregator() + ) + tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + statistic = tensor_collector.get_statistics() + assert isinstance(statistic, MeanTensorStatistic) + assert statistic.mean_values == statistic.shape == NumpyNNCFTensor(1) + + def test_median_mad_stat_building(self, median_mad_statistic_cls: MedianMADTensorStatistic): + class DummyMADPercentileAggregator(DummyTensorAggregator): + def _aggregate_impl(self): + return { + MedianMADTensorStatistic.MEDIAN_VALUES_STAT: self._container[0], + MedianMADTensorStatistic.MAD_VALUES_STAT: self._container[0], + } + + tensor_collector = TensorCollector(median_mad_statistic_cls) + tensor_collector.register_statistic_branch( + median_mad_statistic_cls.TENSOR_STATISTIC_OUTPUT_KEY, + DummyTensorReducer("A"), + DummyMADPercentileAggregator(), + ) + tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + statistic = tensor_collector.get_statistics() + assert isinstance(statistic, MedianMADTensorStatistic) + assert statistic.median_values == statistic.mad_values == NumpyNNCFTensor(1) + + def test_percentile_max_stat_building(self, percentile_statistic_cls: PercentileTensorStatistic): + class DummyPercentileTensorAggregator(DummyTensorAggregator): + def _aggregate_impl(self): + return {0.5: self._container[0]} + + tensor_collector = TensorCollector(percentile_statistic_cls) + tensor_collector.register_statistic_branch( + percentile_statistic_cls.TENSOR_STATISTIC_OUTPUT_KEY, + DummyTensorReducer("A"), + DummyPercentileTensorAggregator(), + ) + tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + statistic = tensor_collector.get_statistics() + assert isinstance(statistic, PercentileTensorStatistic) + assert statistic.percentile_vs_values_dict[0.5] == NumpyNNCFTensor(1) + + 
tensor_collector = TensorCollector(percentile_statistic_cls) + qs = [0.3, 0.5, 0.7] + for q in qs: + tensor_collector.register_statistic_branch( + (PercentileTensorStatistic.PERCENTILE_VS_VALUE_DICT, q), + DummyTensorReducer(f"A{q}"), + DummyTensorAggregator(), + ) + tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + statistic = tensor_collector.get_statistics() + assert isinstance(statistic, PercentileTensorStatistic) + assert len(statistic.percentile_vs_values_dict) == len(qs) + for q in qs: + assert statistic.percentile_vs_values_dict[q] == NumpyNNCFTensor(1) + + def test_raw_max_stat_building(self, raw_statistic_cls: RawTensorStatistic): + tensor_collector = TensorCollector(raw_statistic_cls) + tensor_collector.register_statistic_branch( + raw_statistic_cls.VALUES_STATS, DummyTensorReducer("A"), DummyTensorAggregator() + ) + tensor_collector.register_input_for_all_reducers(NumpyNNCFTensor(1)) + statistic = tensor_collector.get_statistics() + assert isinstance(statistic, RawTensorStatistic) + assert statistic.values == NNCFTensor(1) diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index e80ef50644c..1d3464882fe 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -32,9 +32,7 @@ def mock_collect_statistics(mocker): - get_statistics_value = ONNXMinMaxTensorStatistic( - {ONNXMinMaxTensorStatistic.MIN_STAT: -1, ONNXMinMaxTensorStatistic.MAX_STAT: 1} - ) + get_statistics_value = ONNXMinMaxTensorStatistic(min_values=-1, max_values=1) _ = mocker.patch( "nncf.quantization.fake_quantize.calculate_quantizer_parameters", return_value=FakeQuantizeParameters(np.array(0), np.array(0), np.array(0), np.array(0), 256), diff --git a/tests/openvino/native/test_statistic_collector.py b/tests/openvino/native/test_statistic_collector.py index 2d52c0af4cd..32123644942 100644 --- a/tests/openvino/native/test_statistic_collector.py +++ b/tests/openvino/native/test_statistic_collector.py @@ -9,6 
+9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Type + +import pytest + +from nncf.common.tensor_statistics.statistics import MeanTensorStatistic +from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic +from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic +from nncf.common.tensor_statistics.statistics import RawTensorStatistic +from nncf.openvino.statistics.statistics import OVMeanTensorStatistic +from nncf.openvino.statistics.statistics import OVMinMaxTensorStatistic +from nncf.openvino.statistics.statistics import OVRawTensorStatistic from nncf.openvino.tensor import OVNNCFTensor from tests.common.experimental.test_statistic_collector import TemplateTestStatisticCollector @@ -16,3 +28,31 @@ class TestOVStatisticCollector(TemplateTestStatisticCollector): def get_nncf_tensor_cls(self): return OVNNCFTensor + + @pytest.fixture + def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]: + return OVMinMaxTensorStatistic + + @pytest.fixture + def mean_statistic_cls(self) -> Type[MeanTensorStatistic]: + return OVMeanTensorStatistic + + @pytest.fixture + def median_mad_statistic_cls(self) -> Type[MedianMADTensorStatistic]: + raise NotImplementedError() + + @pytest.mark.skip() + def test_median_mad_stat_building(self, median_mad_statistic_cls: MedianMADTensorStatistic): + pass + + @pytest.fixture + def percentile_statistic_cls(self) -> Type[PercentileTensorStatistic]: + raise NotImplementedError() + + @pytest.mark.skip + def test_percentile_max_stat_building(self, percentile_statistic_cls: PercentileTensorStatistic): + pass + + @pytest.fixture + def raw_statistic_cls(self) -> Type[RawTensorStatistic]: + return OVRawTensorStatistic diff --git a/tests/torch/ptq/helpers.py b/tests/torch/ptq/helpers.py index 8005ae84d52..5e7b672aae1 100644 --- 
a/tests/torch/ptq/helpers.py +++ b/tests/torch/ptq/helpers.py @@ -95,7 +95,5 @@ def mock_collect_statistics(mocker): min_, max_ = torch.tensor(min_), torch.tensor(max_) _ = mocker.patch( "nncf.experimental.common.tensor_statistics.collectors.TensorCollector.get_statistics", - return_value=PTMinMaxTensorStatistic( - {PTMinMaxTensorStatistic.MIN_STAT: min_, PTMinMaxTensorStatistic.MAX_STAT: max_} - ), + return_value=PTMinMaxTensorStatistic(min_values=min_, max_values=max_), ) diff --git a/tests/torch/ptq/test_calculation_quantizer_params.py b/tests/torch/ptq/test_calculation_quantizer_params.py index b75cb3f265c..4f2553de432 100644 --- a/tests/torch/ptq/test_calculation_quantizer_params.py +++ b/tests/torch/ptq/test_calculation_quantizer_params.py @@ -271,10 +271,8 @@ def calculate_statistics(data, mode, qgroup, half_range=False): max_values = np.amax(data, axes) statistics = PTMinMaxTensorStatistic( - { - PTMinMaxTensorStatistic.MIN_STAT: torch.from_numpy(np.array(min_values)), - PTMinMaxTensorStatistic.MAX_STAT: torch.from_numpy(np.array(max_values)), - } + min_values=torch.from_numpy(np.array(min_values)), + max_values=torch.from_numpy(np.array(max_values)), ) signedness_to_force = True if qgroup == QuantizerGroup.WEIGHTS else None qconfig = QuantizerConfig(num_bits=8, mode=mode, per_channel=per_ch, signedness_to_force=signedness_to_force) diff --git a/tests/torch/ptq/test_statistic_collector.py b/tests/torch/ptq/test_statistic_collector.py new file mode 100644 index 00000000000..0ab1ef2bb55 --- /dev/null +++ b/tests/torch/ptq/test_statistic_collector.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Type

import pytest

from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic
from nncf.common.tensor_statistics.statistics import RawTensorStatistic
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.statistics import PTMeanTensorStatistic
from nncf.torch.tensor_statistics.statistics import PTMedianMADTensorStatistic
from nncf.torch.tensor_statistics.statistics import PTMinMaxTensorStatistic
from nncf.torch.tensor_statistics.statistics import PTPercentileTensorStatistic
from tests.common.experimental.test_statistic_collector import TemplateTestStatisticCollector


# NOTE(review): the class name says "OV", but every class returned below is a
# PyTorch ("PT") one. The name is kept so existing test node IDs do not change;
# consider renaming to TestPTStatisticCollector.
class TestOVStatisticCollector(TemplateTestStatisticCollector):
    """Statistic-collector test suite parameterized with PyTorch classes.

    The fixtures below hand PyTorch statistic container classes to the
    inherited tests. The raw-statistic fixture raises ``NotImplementedError``,
    and the test that requests it is skipped.
    """

    def get_nncf_tensor_cls(self):
        """Return the PyTorch NNCF tensor class."""
        return PTNNCFTensor

    @pytest.fixture
    def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]:
        """Provide the PyTorch min/max statistic container class."""
        return PTMinMaxTensorStatistic

    @pytest.fixture
    def mean_statistic_cls(self) -> Type[MeanTensorStatistic]:
        """Provide the PyTorch mean statistic container class."""
        return PTMeanTensorStatistic

    @pytest.fixture
    def median_mad_statistic_cls(self) -> Type[MedianMADTensorStatistic]:
        """Provide the PyTorch median/MAD statistic container class."""
        return PTMedianMADTensorStatistic

    @pytest.fixture
    def percentile_statistic_cls(self) -> Type[PercentileTensorStatistic]:
        """Provide the PyTorch percentile statistic container class."""
        return PTPercentileTensorStatistic

    @pytest.fixture
    def raw_statistic_cls(self) -> Type[RawTensorStatistic]:
        """Not provided by this suite; the dependent test is skipped."""
        raise NotImplementedError()

    # NOTE(review): presumably shadows a same-named test inherited from the
    # template so that it is not run for this backend -- confirm against
    # TemplateTestStatisticCollector.
    @pytest.mark.skip(reason="raw_statistic_cls fixture is not implemented in this suite")
    def test_raw_max_stat_building(self, raw_statistic_cls: Type[RawTensorStatistic]):
        pass
max_values=torch.tensor([[4.25, 5.4, 6.05]]), ), }, ), @@ -104,19 +96,15 @@ class TestCollectedStatistics: ), { ((1,), (0, 1)): PTMinMaxTensorStatistic( - {"min_values": torch.tensor([-4.0]), "max_values": torch.tensor([6.05])} + min_values=torch.tensor([-4.0]), max_values=torch.tensor([6.05]) ), ((3, 1), (1,)): PTMinMaxTensorStatistic( - { - "min_values": torch.tensor([[1.0], [-4.0], [4.0]]), - "max_values": torch.tensor([[3.75], [3.5], [6.05]]), - } + min_values=torch.tensor([[1.0], [-4.0], [4.0]]), + max_values=torch.tensor([[3.75], [3.5], [6.05]]), ), ((1, 3), (0,)): PTMinMaxTensorStatistic( - { - "min_values": torch.tensor([[-1.3, -4.0, -3.5]]), - "max_values": torch.tensor([[4.25, 5.4, 6.05]]), - } + min_values=torch.tensor([[-1.3, -4.0, -3.5]]), + max_values=torch.tensor([[4.25, 5.4, 6.05]]), ), }, ), @@ -149,28 +137,16 @@ def test_collected_statistics_with_shape_convert( # PTMedianMADStatisticCollector, { (1,): PTMedianMADTensorStatistic( - { - "tensor_statistic_output": { - "median_values": torch.tensor([2.8]), - "mad_values": torch.tensor([2.6]), - } - } + median_values=torch.tensor([2.8]), + mad_values=torch.tensor([2.6]), ), (3, 1): PTMedianMADTensorStatistic( - { - "tensor_statistic_output": { - "median_values": torch.tensor([[2.8], [-2.5], [5.4]]), - "mad_values": torch.tensor([[0.85], [1.1], [0.65]]), - } - } + median_values=torch.tensor([[2.8], [-2.5], [5.4]]), + mad_values=torch.tensor([[0.85], [1.1], [0.65]]), ), (1, 3): PTMedianMADTensorStatistic( - { - "tensor_statistic_output": { - "median_values": torch.tensor([[2.5, 2.3, 3.35]]), - "mad_values": torch.tensor([[1.9, 3.1, 2.7]]), - } - } + median_values=torch.tensor([[2.5, 2.3, 3.35]]), + mad_values=torch.tensor([[1.9, 3.1, 2.7]]), ), # Not supported for now: # (3, 3): PTMedianMADTensorStatistic( @@ -190,13 +166,9 @@ def test_collected_statistics_with_shape_convert( ( partial(get_percentile_tensor_collector, percentiles_to_collect=[10.0]), { - (1,): 
PTPercentileTensorStatistic({"tensor_statistic_output": {10.0: torch.tensor([-3.15])}}), - (3, 1): PTPercentileTensorStatistic( - {"tensor_statistic_output": {10.0: torch.tensor([[1.5], [-3.75], [4.15]])}} - ), - (1, 3): PTPercentileTensorStatistic( - {"tensor_statistic_output": {10.0: torch.tensor([[-1.15, -3, -3.25]])}} - ), + (1,): PTPercentileTensorStatistic({10.0: torch.tensor([-3.15])}), + (3, 1): PTPercentileTensorStatistic({10.0: torch.tensor([[1.5], [-3.75], [4.15]])}), + (1, 3): PTPercentileTensorStatistic({10.0: torch.tensor([[-1.15, -3, -3.25]])}), # Not supported for now: # (3, 3): PTPercentileTensorStatistic( # { @@ -212,13 +184,9 @@ def test_collected_statistics_with_shape_convert( ( partial(get_mean_percentile_statistic_collector, percentiles_to_collect=[10.0]), { - (1,): PTPercentileTensorStatistic({"tensor_statistic_output": {10.0: torch.tensor([-2.9])}}), - (3, 1): PTPercentileTensorStatistic( - {"tensor_statistic_output": {10.0: torch.tensor([[2.0100], [-3.3500], [4.4000]])}} - ), - (1, 3): PTPercentileTensorStatistic( - {"tensor_statistic_output": {10.0: torch.tensor([[-0.3900, -1.9400, -1.9300]])}} - ), + (1,): PTPercentileTensorStatistic({10.0: torch.tensor([-2.9])}), + (3, 1): PTPercentileTensorStatistic({10.0: torch.tensor([[2.0100], [-3.3500], [4.4000]])}), + (1, 3): PTPercentileTensorStatistic({10.0: torch.tensor([[-0.3900, -1.9400, -1.9300]])}), # Not supported for now: # (3, 3): PTPercentileTensorStatistic( # {