From b4aced00d3e1eb52d4697ddff3f03e9144082148 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 7 Sep 2023 18:26:06 +0200 Subject: [PATCH] Fix tests --- .../common/tensor_statistics/collectors.py | 18 +++++----- nncf/openvino/statistics/collectors.py | 10 +++--- .../algorithms/min_max/openvino_backend.py | 2 +- nncf/torch/tensor_statistics/collectors.py | 11 +++--- tests/common/test_statistics_aggregator.py | 7 ++-- .../common/test_reducers_and_aggregators.py | 35 ++++++++++++++----- .../test_reducers_and_aggregators.py | 4 ++- .../ptq/test_reducers_and_aggregators.py | 7 ++++ 8 files changed, 64 insertions(+), 30 deletions(-) diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py index 99166320922..94f447732d8 100644 --- a/nncf/experimental/common/tensor_statistics/collectors.py +++ b/nncf/experimental/common/tensor_statistics/collectors.py @@ -40,7 +40,7 @@ def __init__(self, reduction_axes: Optional[ReductionShape] = None, inplace: boo :param keepdims: Should the axes which are reduced are left in the result as dimensions with size one or not. """ - self._reduction_shape = reduction_axes + self._reduction_axes = reduction_axes self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor() self._inplace = inplace self._keepdims = keepdims @@ -98,16 +98,16 @@ def __call__(self, x: List[NNCFTensor]): def __eq__(self, __o: object) -> bool: return ( isinstance(__o, self.__class__) - and self._reduction_shape == __o._reduction_shape + and self._reduction_axes == __o._reduction_axes and self._inplace == __o.inplace ) def __hash__(self) -> int: - return hash((self.__class__.__name__, self.inplace, self._reduction_shape)) + return hash((self.__class__.__name__, self.inplace, self._reduction_axes)) def _get_reduction_shape(self, tensor: NNCFTensor) -> Union[int, Tuple[int, ...]]: - if self._reduction_shape is not None: - return self._reduction_shape + if self._reduction_axes is not None: + return self._reduction_axes return tuple(range(len(tensor.shape))) @@ -481,7 +481,7 @@ def __eq__(self, __o: object) -> bool: return super().__eq__(__o) and self._quantile == __o._quantile def __hash__(self) -> int: - return hash((self.__class__.__name__, self.inplace, self._reduction_shape, tuple(self._quantile))) + return hash((self.__class__.__name__, self.inplace, self._reduction_axes, tuple(self._quantile))) class QuantileReducer(QuantileReducerBase): @@ -521,7 +521,7 @@ def __init__(self, channel_dim: int = 1, inplace: bool = False): super().__init__(channel_dim, inplace) def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]: - return [self._tensor_processor.mean_per_channel(x[0], self._reduction_shape)] + return [self._tensor_processor.mean_per_channel(x[0], self._reduction_axes)] ##################################################Aggregators################################################## @@ -567,7 +567,7 @@ def __init__( class OnlineAggregatorBase(OnlineOfflineAggregatorBase, ABC): def _online_register_reduced_input_impl(self, x: TensorType, fn) -> None: - online_aggregation_axes = [dim - 1 for dim in self._aggregation_axes if dim != 0] + online_aggregation_axes = tuple([dim - 1 for dim in self._aggregation_axes if dim != 0]) if online_aggregation_axes: reduced = fn(x, axis=online_aggregation_axes, keepdims=self._keepdims) else: @@ -636,7 +636,7 @@ def __init__( def _offline_aggregation_impl(self, fn) -> List[NNCFTensor]: stacked_val = self._tensor_processor.stack(self._container) result = self._tensor_processor.no_outliers_map( - stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile + stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile, keepdims=self._keepdims ) return result.tensor diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py index 7ce2855c636..94c3dcecc52 100644 --- a/nncf/openvino/statistics/collectors.py +++ b/nncf/openvino/statistics/collectors.py @@ -83,7 +83,7 @@ def masked_mean( if mask is None: return cls.mean(x, axis=axis, keepdims=keepdims) masked_x = np.ma.array(x.tensor, mask=mask.tensor) - return OVNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=False).data) + return OVNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=keepdims).data) @classmethod def masked_median( @@ -111,8 +111,11 @@ def no_outliers_map( alpha: float = 0.01, keepdims: bool = False, ) -> NNCFTensor: - if len(x.shape) == 1: - return fn(x, axis=None, mask=None, keepdims=keepdims) + if isinstance(axis, int): + axis = (axis,) + + if len(axis) == len(x.shape): + return fn(x, axis=axis, mask=None, keepdims=keepdims) x = x.tensor low_values, high_values = np.quantile(x, [alpha, 1 - alpha], axis=axis) @@ -249,7 +252,6 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace kwargs = { "tensor_processor": OVNNCFCollectorTensorProcessor, - "use_per_sample_stats": False, "num_samples": num_samples, "window_size": window_size, } diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index d1be5753024..eb91936cd10 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -181,7 +181,7 @@ def get_statistic_collector( f"Aggregator type: {params.aggregator_type} is not supported for OpenVino PTQ backend yet." ) - kwargs = {"reduction_shape": reduction_shape, "inplace": inplace} + kwargs = {"reduction_axes": reduction_shape, "inplace": inplace} if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: if container_key == OVMinMaxTensorStatistic.MIN_STAT: quantile = params.quantile_outlier_prob diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index 2b795552e07..472b55cf241 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -100,7 +100,7 @@ def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTen if mask is None: return cls.mean(x, axis=axis, keepdims=keepdims) masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor) - result = np.ma.mean(masked_x, axis=axis, keepdims=False).astype(masked_x.dtype) + result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype) if result.size <= 1: return PTNNCFTensor(torch.tensor(result)) return PTNNCFTensor(torch.tensor(result.data)) @@ -112,7 +112,7 @@ def masked_median( # Implemented in numy as torch.masked.median is not implemented yet if mask is None: return cls.median(x, axis=axis, keepdims=keepdims) - masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=~mask.tensor.detach().cpu().numpy()) + masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy()) result = np.ma.median(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype) if len(result) == 1: return PTNNCFTensor(torch.tensor(result)) @@ -185,8 +185,11 @@ def no_outliers_map( alpha: float = 0.01, keepdims: bool = False, ): - if len(x.shape) == 1: - return fn(x, axis=None, mask=None, keepdims=keepdims) + if isinstance(axis, int): + axis = (axis,) + + if len(x.shape) == len(axis): + return fn(x, axis=axis, mask=None, keepdims=keepdims) low_values, high_values = cls.quantile(x, [alpha, 1 - alpha], axis=axis) outliers_mask = torch.logical_or(x.tensor < low_values.tensor, high_values.tensor < x.tensor) diff --git a/tests/common/test_statistics_aggregator.py b/tests/common/test_statistics_aggregator.py index 5988ff9e114..9b7f299ffd1 100644 --- a/tests/common/test_statistics_aggregator.py +++ b/tests/common/test_statistics_aggregator.py @@ -365,7 +365,7 @@ def dataset_samples_to_conv_w(self, dataset_sample): np.array((-10, -1, -128)), ) ), - ), + )[12:], ) def test_statistics_aggregator_min_max( self, @@ -375,6 +375,7 @@ def test_statistics_aggregator_min_max( inplace_statistics, is_backend_support_custom_estimators, ): + inplace_statistics = False model = self.get_backend_model(dataset_samples) quantizer_config = QuantizerConfig( mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel @@ -814,10 +815,10 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_ model = params["model"](dataset_samples) params = {} if statistics_type in [StatisticsType.MIN, StatisticsType.MAX, StatisticsType.ABS_MAX, StatisticsType.MEAN]: - params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)] + params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)] params["inplace"] = [False, True] elif statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: - params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)] + params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)] params["quantile"] = [[0.01, 0.99], [0.001, 0.999]] elif statistics_type == "batch_mean": pytest.skip("Inplace statistic woun't work until openvino==2023.0.0 release") diff --git a/tests/experimental/common/test_reducers_and_aggregators.py b/tests/experimental/common/test_reducers_and_aggregators.py index b1d49977d98..8567cfb7a03 100644 --- a/tests/experimental/common/test_reducers_and_aggregators.py +++ b/tests/experimental/common/test_reducers_and_aggregators.py @@ -141,15 +141,21 @@ ] -def default_test_mean_no_outlier(tensor_processor, aggregation_axes): +def default_test_mean_no_outlier(tensor_processor, aggregation_axes, keepdims): return MeanNoOutliersAggregator( - tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, quantile=default_test_quantile + tensor_processor=tensor_processor, + aggregation_axes=aggregation_axes, + quantile=default_test_quantile, + keepdims=keepdims, ) -def default_test_median_no_outlier(tensor_processor, aggregation_axes): +def default_test_median_no_outlier(tensor_processor, aggregation_axes, keepdims): return MedianNoOutliersAggregator( - tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, quantile=default_test_quantile + tensor_processor=tensor_processor, + aggregation_axes=aggregation_axes, + quantile=default_test_quantile, + keepdims=keepdims, ) @@ -180,6 +186,10 @@ def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = Non def cast_tensor(self, tensor, dtype: Dtype): pass + @abstractmethod + def expand_dims(self, tensor, dims: Tuple[int, ...]): + pass + def test_noop_reducer(self, reducers): reducer = reducers["noop"]() input_ = np.arange(24).reshape((1, 2, 3, 4)) @@ -321,7 +331,10 @@ def test_min_max_aggregators(self, aggregation_axes, keepdims, min_ref, max_ref, ] @pytest.mark.parametrize("aggregator_cls,use_per_sample_stats,dims,refs", NO_OUTLIERS_TEST_PARAMS) - def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, dims, use_per_sample_stats): + @pytest.mark.parametrize("keepdims", [True, False]) + def test_mean_median_agggregators( + self, aggregator_cls, refs, tensor_processor, dims, use_per_sample_stats, keepdims + ): input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) input_with_outliers = np.array( [100_000, -100_000, 200_000, -200_000, 300_000, -300_000, 400_000, -400_000, 500_000] @@ -334,7 +347,9 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, input_with_outliers = input_with_outliers.reshape((1, 3, 3)) aggregation_axes = (0, 1) if use_per_sample_stats else (0,) - aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes) + aggregator = aggregator_cls( + tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims + ) for i in range(1, 6): aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT)) # this registration is to make diff between mean and median bigger @@ -346,6 +361,10 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, mult = 2.2 * i - 1 if not is_median else 1 aggregator.register_reduced_input(self.get_nncf_tensor(input_with_outliers * mult, Dtype.FLOAT)) ret_val = aggregator.aggregate() + + if keepdims: + refs = self.expand_dims(refs, (0, 1) if use_per_sample_stats else (0,)) + assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT)) @pytest.mark.parametrize( @@ -363,10 +382,10 @@ def test_reducers_name_hash_equal(self, reducer_name, reducers): params = {} if reducer_name in ["min", "max", "abs_max", "mean"]: - params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)] + params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)] params["inplace"] = [False, True] elif reducer_name in ["quantile", "abs_quantile"]: - params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)] + params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)] params["quantile"] = [[0.01, 0.99], [0.001, 0.999]] elif reducer_name == "batch_mean": params["inplace"] = [False, True] diff --git a/tests/openvino/native/quantization/test_reducers_and_aggregators.py b/tests/openvino/native/quantization/test_reducers_and_aggregators.py index f16be85df7b..213a64b2e84 100644 --- a/tests/openvino/native/quantization/test_reducers_and_aggregators.py +++ b/tests/openvino/native/quantization/test_reducers_and_aggregators.py @@ -15,7 +15,6 @@ import pytest from nncf.common.graph.layer_attributes import Dtype -from nncf.common.tensor import NNCFTensor from nncf.openvino.statistics.collectors import OVAbsMaxReducer from nncf.openvino.statistics.collectors import OVAbsQuantileReducer from nncf.openvino.statistics.collectors import OVBatchMeanReducer @@ -62,3 +61,6 @@ def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = Non def cast_tensor(self, tensor, dtype: Dtype): return tensor + + def expand_dims(self, tensor, dims: Tuple[int, ...]): + return np.expand_dims(np.array(tensor), dims) diff --git a/tests/torch/ptq/test_reducers_and_aggregators.py b/tests/torch/ptq/test_reducers_and_aggregators.py index e88904e2974..c6a97696b00 100644 --- a/tests/torch/ptq/test_reducers_and_aggregators.py +++ b/tests/torch/ptq/test_reducers_and_aggregators.py @@ -74,3 +74,10 @@ def cast_tensor(self, tensor, dtype: Dtype): if dtype == Dtype.INTEGER: return tensor.int() raise RuntimeError() + + def expand_dims(self, tensor, dims: Tuple[int, ...]): + tensor_ = torch.tensor(tensor) + shape = list(tensor_.shape) + for dim in dims: + shape.insert(dim, 1) + return tensor_.view(shape)