diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index f82877da854..e55b6ee8f10 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -254,7 +254,7 @@ class PTBatchMeanReducer(PTReducerMixIn, BatchMeanReducer): pass -class PTMeanPerChReducer(PTReducerMixIn, MeanPerChReducer): +class PTMeanPerChanelReducer(PTReducerMixIn, MeanPerChReducer): pass @@ -424,7 +424,7 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None): if channel_axis == 0: reducer = PTBatchMeanReducer() else: - reducer = PTMeanPerChReducer(channel_axis) + reducer = PTMeanPerChanelReducer(channel_axis) noop_reducer = PTNoopReducer() kwargs = { diff --git a/tests/torch/ptq/test_reducers_and_aggregators.py b/tests/torch/ptq/test_reducers_and_aggregators.py new file mode 100644 index 00000000000..8abf176289f --- /dev/null +++ b/tests/torch/ptq/test_reducers_and_aggregators.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from nncf.torch.tensor import PTNNCFTensor +from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer +from nncf.torch.tensor_statistics.collectors import PTAbsQuantileReducer +from nncf.torch.tensor_statistics.collectors import PTBatchMeanReducer +from nncf.torch.tensor_statistics.collectors import PTMaxReducer +from nncf.torch.tensor_statistics.collectors import PTMeanPerChanelReducer +from nncf.torch.tensor_statistics.collectors import PTMeanReducer +from nncf.torch.tensor_statistics.collectors import PTMinReducer +from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor +from nncf.torch.tensor_statistics.collectors import PTNoopReducer +from nncf.torch.tensor_statistics.collectors import PTQuantileReducer +from tests.experimental.common.test_reducers_and_aggregators import TemplateTestReducersAggreagtors + + +class TestReducersAggregators(TemplateTestReducersAggreagtors): + @pytest.fixture + def tensor_processor(self): + return PTNNCFCollectorTensorProcessor + + def get_nncf_tensor(self, x: torch.Tensor): + return PTNNCFTensor(x) + + @pytest.fixture(scope="module") + def reducers(self): + return { + "noop": PTNoopReducer, + "min": PTMinReducer, + "max": PTMaxReducer, + "abs_max": PTAbsMaxReducer, + "mean": PTMeanReducer, + "quantile": PTQuantileReducer, + "abs_quantile": PTAbsQuantileReducer, + "batch_mean": PTBatchMeanReducer, + "mean_per_ch": PTMeanPerChanelReducer, + } + + def all_close(self, val, ref) -> bool: + val_ = torch.tensor(val) + ref_ = torch.tensor(ref) + return torch.allclose(val_, ref_) and val_.shape == ref_.shape