From f97299f34ea1245c14539ec098d8c76a5dd2f62f Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 8 Sep 2023 18:07:50 +0200 Subject: [PATCH] Fix onnx --- .../tensor_statistics/statistic_point.py | 2 +- nncf/onnx/statistics/collectors.py | 23 +++++++++++++++---- nncf/onnx/statistics/statistics.py | 9 ++++++++ .../algorithms/min_max/onnx_backend.py | 4 +++- .../tensor_statistics/collectors.py | 1 + tests/onnx/quantization/common.py | 4 +++- .../quantization/test_quantizer_config.py | 9 ++++++++ .../quantization/test_quantizer_config.py | 6 +++++ .../test_templates/test_quantizer_config.py | 10 +++++--- tests/torch/ptq/test_quantizer_config.py | 6 +++++ 10 files changed, 64 insertions(+), 10 deletions(-) diff --git a/nncf/common/tensor_statistics/statistic_point.py b/nncf/common/tensor_statistics/statistic_point.py index 735a56e6fef..705d5af552f 100644 --- a/nncf/common/tensor_statistics/statistic_point.py +++ b/nncf/common/tensor_statistics/statistic_point.py @@ -38,7 +38,7 @@ def __eq__(self, other): def register_tensor(self, x: TensorType): for tensor_collectors in self.algorithm_to_tensor_collectors.values(): for tensor_collector in tensor_collectors: - tensor_collector.register_unnamed_inputs(x) + tensor_collector.register_inputs(x) class StatisticPointsContainer(UserDict): diff --git a/nncf/onnx/statistics/collectors.py b/nncf/onnx/statistics/collectors.py index e43818f2889..e5d3d2878ed 100644 --- a/nncf/onnx/statistics/collectors.py +++ b/nncf/onnx/statistics/collectors.py @@ -156,7 +156,12 @@ def _register_input(self, x: ONNXNNCFTensor): self._register_input_common(x) def _get_statistics(self) -> ONNXMinMaxTensorStatistic: - return ONNXMinMaxTensorStatistic(self._min_values.tensor, self._max_values.tensor) + return ONNXMinMaxTensorStatistic( + { + ONNXMinMaxTensorStatistic.MIN_STAT: self._min_values.tensor, + ONNXMinMaxTensorStatistic.MAX_STAT: self._max_values.tensor, + } + ) class ONNXMeanMinMaxStatisticCollector(MeanMinMaxStatisticCollector): @@ -168,7 +173,12 @@ def _register_input(self, x: ONNXNNCFTensor): self._register_input_common(x) def _get_statistics(self) -> ONNXMinMaxTensorStatistic: - return ONNXMinMaxTensorStatistic(self._min_aggregate().tensor, self._max_aggregate().tensor) + return ONNXMinMaxTensorStatistic( + { + ONNXMinMaxTensorStatistic.MIN_STAT: self._min_aggregate().tensor, + ONNXMinMaxTensorStatistic.MAX_STAT: self._max_aggregate().tensor, + } + ) class ONNXMeanStatisticCollector(MeanStatisticCollector): @@ -180,7 +190,12 @@ def _register_input(self, x: ONNXNNCFTensor): self._register_input_common(x) def _get_statistics(self) -> ONNXMeanTensorStatistic: - return ONNXMeanTensorStatistic(self._mean_aggregate().tensor, self._shape()) + return ONNXMeanTensorStatistic( + { + ONNXMeanTensorStatistic.MEAN_STAT: self._mean_aggregate().tensor, + ONNXMeanTensorStatistic.SHAPE_STAT: self._shape(), + } + ) class ONNXRawStatisticCollector(RawStatisticCollector): @@ -192,4 +207,4 @@ def _register_input(self, x: ONNXNNCFTensor): self._register_input_common(x) def _get_statistics(self) -> ONNXRawTensorStatistic: - return ONNXRawTensorStatistic(self._all_values) + return ONNXRawTensorStatistic({ONNXRawTensorStatistic.VALUES_STATS: self._all_values}) diff --git a/nncf/onnx/statistics/statistics.py b/nncf/onnx/statistics/statistics.py index f9d5119201f..91d1169c392 100644 --- a/nncf/onnx/statistics/statistics.py +++ b/nncf/onnx/statistics/statistics.py @@ -17,18 +17,27 @@ class ONNXMinMaxTensorStatistic(MinMaxTensorStatistic): + def __init__(self, tensor_collector_output): + super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT]) + @staticmethod def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool: return bool(np.allclose(tensor1, tensor2, rtol=rtol)) class ONNXMeanTensorStatistic(MeanTensorStatistic): + def __init__(self, tensor_collector_output): + super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT]) + @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) class ONNXRawTensorStatistic(RawTensorStatistic): + def __init__(self, tensor_collector_output): + super().__init__(tensor_collector_output[self.VALUES_STATS]) + @staticmethod def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool: return bool(np.all(tensor, rtol=rtol)) diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index f5649f94ba2..a8cac7acd23 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -120,7 +120,9 @@ def unify_statistics(statistics: List[ONNXMinMaxTensorStatistic]) -> ONNXMinMaxT min_values.append(np.array(statistic.min_values).flatten()) max_values = np.max(max_values, axis=0) min_values = np.min(min_values, axis=0) - return ONNXMinMaxTensorStatistic(min_values=min_values, max_values=max_values) + return ONNXMinMaxTensorStatistic( + {ONNXMinMaxTensorStatistic.MIN_STAT: min_values, ONNXMinMaxTensorStatistic.MAX_STAT: max_values} + ) @staticmethod def _get_input_edges_mapping(nncf_graph: NNCFGraph): diff --git a/nncf/tensorflow/tensor_statistics/collectors.py b/nncf/tensorflow/tensor_statistics/collectors.py index 63d21087023..d42b15b8ff1 100644 --- a/nncf/tensorflow/tensor_statistics/collectors.py +++ b/nncf/tensorflow/tensor_statistics/collectors.py @@ -16,6 +16,7 @@ from nncf.common.tensor import NNCFTensor from nncf.common.tensor import TensorElementsType +from nncf.common.tensor_statistics.collectors import MaskedReduceFN from nncf.common.tensor_statistics.collectors import MeanMinMaxStatisticCollector from nncf.common.tensor_statistics.collectors import MeanPercentileStatisticCollector from nncf.common.tensor_statistics.collectors import MedianMADStatisticCollector diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 1d3464882fe..e80ef50644c 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -32,7 +32,9 @@ def mock_collect_statistics(mocker): - get_statistics_value = ONNXMinMaxTensorStatistic(min_values=-1, max_values=1) + get_statistics_value = ONNXMinMaxTensorStatistic( + {ONNXMinMaxTensorStatistic.MIN_STAT: -1, ONNXMinMaxTensorStatistic.MAX_STAT: 1} + ) _ = mocker.patch( "nncf.quantization.fake_quantize.calculate_quantizer_parameters", return_value=FakeQuantizeParameters(np.array(0), np.array(0), np.array(0), np.array(0), 256), diff --git a/tests/onnx/quantization/test_quantizer_config.py b/tests/onnx/quantization/test_quantizer_config.py index 374ae440f13..5b2fc84fd19 100644 --- a/tests/onnx/quantization/test_quantizer_config.py +++ b/tests/onnx/quantization/test_quantizer_config.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import pytest from nncf.common.graph.transformations.commands import TargetType +from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXAddLayerMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDepthwiseConvolutionMetatype @@ -27,6 +30,9 @@ ParamsCls = TemplateTestQuantizerConfig.TestGetStatisticsCollectorParameters +# pylint: disable=protected-access + + class TestQuantizerConfig(TemplateTestQuantizerConfig): def get_algo_backend(self): return ONNXMinMaxAlgoBackend() @@ -37,6 +43,9 @@ def check_is_min_max_statistic_collector(self, tensor_collector): def check_is_mean_min_max_statistic_collector(self, tensor_collector): assert isinstance(tensor_collector, ONNXMeanMinMaxStatisticCollector) + def get_reduction_axes(self, reducer: TensorStatisticCollectorBase) -> Tuple[int, ...]: + return reducer._reduction_shape + @pytest.fixture( params=[ pytest.param( diff --git a/tests/openvino/native/quantization/test_quantizer_config.py b/tests/openvino/native/quantization/test_quantizer_config.py index 45d41644ba4..8365c7a2ed6 100644 --- a/tests/openvino/native/quantization/test_quantizer_config.py +++ b/tests/openvino/native/quantization/test_quantizer_config.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import pytest from nncf.common.graph.transformations.commands import TargetType @@ -16,6 +18,7 @@ from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator from nncf.experimental.common.tensor_statistics.collectors import MinAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase from nncf.openvino.graph.layer_attributes import OVLayerAttributes from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype @@ -45,6 +48,9 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl assert MeanAggregator in aggrs assert aggrs[0].__class__ == aggrs[1].__class__ + def get_reduction_axes(self, reducer: TensorReducerBase) -> Tuple[int, ...]: + return reducer._reduction_axes + @pytest.fixture( params=[ pytest.param( diff --git a/tests/post_training/test_templates/test_quantizer_config.py b/tests/post_training/test_templates/test_quantizer_config.py index afd2accc285..39c380070b5 100644 --- a/tests/post_training/test_templates/test_quantizer_config.py +++ b/tests/post_training/test_templates/test_quantizer_config.py @@ -12,7 +12,7 @@ from abc import abstractmethod from copy import deepcopy from dataclasses import dataclass -from typing import List +from typing import List, Tuple import pytest @@ -53,6 +53,10 @@ def check_is_min_max_statistic_collector(self, tensor_collector): def check_is_mean_min_max_statistic_collector(self, tensor_collector): pass + @abstractmethod + def get_reduction_axes(self, reducer) -> Tuple[int, ...]: + pass + @abstractmethod @pytest.fixture def single_conv_nncf_graph(self) -> NNCFGraphToTest: @@ -278,8 +282,8 @@ def test_get_stat_collector( for reducer in reducers: if q_config_per_channel: - assert reducer._reduction_axes == params.ref_per_ch_reduction_shape + assert self.get_reduction_axes(reducer) == params.ref_per_ch_reduction_shape else: - assert reducer._reduction_axes == params.ref_per_tensor_reduction_shape + assert self.get_reduction_axes(reducer) == params.ref_per_tensor_reduction_shape assert tensor_collector.num_samples == num_samples diff --git a/tests/torch/ptq/test_quantizer_config.py b/tests/torch/ptq/test_quantizer_config.py index 152503c802b..8c2e5cbfa38 100644 --- a/tests/torch/ptq/test_quantizer_config.py +++ b/tests/torch/ptq/test_quantizer_config.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import pytest from nncf.common.graph.transformations.commands import TargetType @@ -16,6 +18,7 @@ from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator from nncf.experimental.common.tensor_statistics.collectors import MinAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from tests.post_training.test_templates.models import NNCFGraphToTest from tests.post_training.test_templates.models import NNCFGraphToTestDepthwiseConv @@ -44,6 +47,9 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl assert MeanAggregator in aggrs assert aggrs[0].__class__ == aggrs[1].__class__ + def get_reduction_axes(self, reducer: TensorReducerBase) -> Tuple[int, ...]: + return reducer._reduction_axes + @pytest.fixture( params=[ (TargetType.PRE_LAYER_OPERATION, "/Sum_1_0", (0, 2), (0, 1, 2)),