Skip to content

Commit

Permalink
Fix onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 8, 2023
1 parent 5ff5953 commit f97299f
Show file tree
Hide file tree
Showing 10 changed files with 64 additions and 10 deletions.
2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/statistic_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __eq__(self, other):
def register_tensor(self, x: TensorType):
for tensor_collectors in self.algorithm_to_tensor_collectors.values():
for tensor_collector in tensor_collectors:
tensor_collector.register_unnamed_inputs(x)
tensor_collector.register_inputs(x)


class StatisticPointsContainer(UserDict):
Expand Down
23 changes: 19 additions & 4 deletions nncf/onnx/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,12 @@ def _register_input(self, x: ONNXNNCFTensor):
self._register_input_common(x)

def _get_statistics(self) -> ONNXMinMaxTensorStatistic:
return ONNXMinMaxTensorStatistic(self._min_values.tensor, self._max_values.tensor)
return ONNXMinMaxTensorStatistic(
{
ONNXMinMaxTensorStatistic.MIN_STAT: self._min_values.tensor,
ONNXMinMaxTensorStatistic.MAX_STAT: self._max_values.tensor,
}
)


class ONNXMeanMinMaxStatisticCollector(MeanMinMaxStatisticCollector):
Expand All @@ -168,7 +173,12 @@ def _register_input(self, x: ONNXNNCFTensor):
self._register_input_common(x)

def _get_statistics(self) -> ONNXMinMaxTensorStatistic:
return ONNXMinMaxTensorStatistic(self._min_aggregate().tensor, self._max_aggregate().tensor)
return ONNXMinMaxTensorStatistic(
{
ONNXMinMaxTensorStatistic.MIN_STAT: self._min_aggregate().tensor,
ONNXMinMaxTensorStatistic.MAX_STAT: self._max_aggregate().tensor,
}
)


class ONNXMeanStatisticCollector(MeanStatisticCollector):
Expand All @@ -180,7 +190,12 @@ def _register_input(self, x: ONNXNNCFTensor):
self._register_input_common(x)

def _get_statistics(self) -> ONNXMeanTensorStatistic:
return ONNXMeanTensorStatistic(self._mean_aggregate().tensor, self._shape())
return ONNXMeanTensorStatistic(
{
ONNXMeanTensorStatistic.MEAN_STAT: self._mean_aggregate().tensor,
ONNXMeanTensorStatistic.SHAPE_STAT: self._shape(),
}
)


class ONNXRawStatisticCollector(RawStatisticCollector):
Expand All @@ -192,4 +207,4 @@ def _register_input(self, x: ONNXNNCFTensor):
self._register_input_common(x)

def _get_statistics(self) -> ONNXRawTensorStatistic:
return ONNXRawTensorStatistic(self._all_values)
return ONNXRawTensorStatistic({ONNXRawTensorStatistic.VALUES_STATS: self._all_values})
9 changes: 9 additions & 0 deletions nncf/onnx/statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,27 @@


class ONNXMinMaxTensorStatistic(MinMaxTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT])

@staticmethod
def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool:
return bool(np.allclose(tensor1, tensor2, rtol=rtol))


class ONNXMeanTensorStatistic(MeanTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT])

@staticmethod
def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool:
return bool(np.all(tensor, rtol=rtol))


class ONNXRawTensorStatistic(RawTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.VALUES_STATS])

@staticmethod
def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool:
return bool(np.all(tensor, rtol=rtol))
4 changes: 3 additions & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ def unify_statistics(statistics: List[ONNXMinMaxTensorStatistic]) -> ONNXMinMaxT
min_values.append(np.array(statistic.min_values).flatten())
max_values = np.max(max_values, axis=0)
min_values = np.min(min_values, axis=0)
return ONNXMinMaxTensorStatistic(min_values=min_values, max_values=max_values)
return ONNXMinMaxTensorStatistic(
{ONNXMinMaxTensorStatistic.MIN_STAT: min_values, ONNXMinMaxTensorStatistic.MAX_STAT: max_values}
)

@staticmethod
def _get_input_edges_mapping(nncf_graph: NNCFGraph):
Expand Down
1 change: 1 addition & 0 deletions nncf/tensorflow/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from nncf.common.tensor import NNCFTensor
from nncf.common.tensor import TensorElementsType
from nncf.common.tensor_statistics.collectors import MaskedReduceFN
from nncf.common.tensor_statistics.collectors import MeanMinMaxStatisticCollector
from nncf.common.tensor_statistics.collectors import MeanPercentileStatisticCollector
from nncf.common.tensor_statistics.collectors import MedianMADStatisticCollector
Expand Down
4 changes: 3 additions & 1 deletion tests/onnx/quantization/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@


def mock_collect_statistics(mocker):
get_statistics_value = ONNXMinMaxTensorStatistic(min_values=-1, max_values=1)
get_statistics_value = ONNXMinMaxTensorStatistic(
{ONNXMinMaxTensorStatistic.MIN_STAT: -1, ONNXMinMaxTensorStatistic.MAX_STAT: 1}
)
_ = mocker.patch(
"nncf.quantization.fake_quantize.calculate_quantizer_parameters",
return_value=FakeQuantizeParameters(np.array(0), np.array(0), np.array(0), np.array(0), 256),
Expand Down
9 changes: 9 additions & 0 deletions tests/onnx/quantization/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import pytest

from nncf.common.graph.transformations.commands import TargetType
from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXAddLayerMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXConvolutionMetatype
from nncf.onnx.graph.metatypes.onnx_metatypes import ONNXDepthwiseConvolutionMetatype
Expand All @@ -27,6 +30,9 @@
ParamsCls = TemplateTestQuantizerConfig.TestGetStatisticsCollectorParameters


# pylint: disable=protected-access


class TestQuantizerConfig(TemplateTestQuantizerConfig):
def get_algo_backend(self):
return ONNXMinMaxAlgoBackend()
Expand All @@ -37,6 +43,9 @@ def check_is_min_max_statistic_collector(self, tensor_collector):
def check_is_mean_min_max_statistic_collector(self, tensor_collector):
assert isinstance(tensor_collector, ONNXMeanMinMaxStatisticCollector)

def get_reduction_axes(self, reducer: TensorStatisticCollectorBase) -> Tuple[int, ...]:
return reducer._reduction_shape

@pytest.fixture(
params=[
pytest.param(
Expand Down
6 changes: 6 additions & 0 deletions tests/openvino/native/quantization/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import pytest

from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.openvino.graph.layer_attributes import OVLayerAttributes
from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
from nncf.openvino.graph.metatypes.openvino_metatypes import OVDepthwiseConvolutionMetatype
Expand Down Expand Up @@ -45,6 +48,9 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl
assert MeanAggregator in aggrs
assert aggrs[0].__class__ == aggrs[1].__class__

def get_reduction_axes(self, reducer: TensorReducerBase) -> Tuple[int, ...]:
return reducer._reduction_axes

@pytest.fixture(
params=[
pytest.param(
Expand Down
10 changes: 7 additions & 3 deletions tests/post_training/test_templates/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from abc import abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import List
from typing import List, Tuple

import pytest

Expand Down Expand Up @@ -53,6 +53,10 @@ def check_is_min_max_statistic_collector(self, tensor_collector):
def check_is_mean_min_max_statistic_collector(self, tensor_collector):
pass

@abstractmethod
def get_reduction_axes(self, reducer) -> Tuple[int, ...]:
pass

@abstractmethod
@pytest.fixture
def single_conv_nncf_graph(self) -> NNCFGraphToTest:
Expand Down Expand Up @@ -278,8 +282,8 @@ def test_get_stat_collector(

for reducer in reducers:
if q_config_per_channel:
assert reducer._reduction_axes == params.ref_per_ch_reduction_shape
assert self.get_reduction_axes(reducer) == params.ref_per_ch_reduction_shape
else:
assert reducer._reduction_axes == params.ref_per_tensor_reduction_shape
assert self.get_reduction_axes(reducer) == params.ref_per_tensor_reduction_shape

assert tensor_collector.num_samples == num_samples
6 changes: 6 additions & 0 deletions tests/torch/ptq/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import pytest

from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
from nncf.experimental.common.tensor_statistics.collectors import MinAggregator
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase
from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
from tests.post_training.test_templates.models import NNCFGraphToTest
from tests.post_training.test_templates.models import NNCFGraphToTestDepthwiseConv
Expand Down Expand Up @@ -44,6 +47,9 @@ def check_is_mean_min_max_statistic_collector(self, tensor_collector: TensorColl
assert MeanAggregator in aggrs
assert aggrs[0].__class__ == aggrs[1].__class__

def get_reduction_axes(self, reducer: TensorReducerBase) -> Tuple[int, ...]:
return reducer._reduction_axes

@pytest.fixture(
params=[
(TargetType.PRE_LAYER_OPERATION, "/Sum_1_0", (0, 2), (0, 1, 2)),
Expand Down

0 comments on commit f97299f

Please sign in to comment.