diff --git a/nncf/common/tensor_statistics/collectors.py b/nncf/common/tensor_statistics/collectors.py
index 9e9f961bbbe..84dc8746b8b 100644
--- a/nncf/common/tensor_statistics/collectors.py
+++ b/nncf/common/tensor_statistics/collectors.py
@@ -307,7 +307,9 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
 
     @classmethod
     @abstractmethod
-    def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha: float = 0.01) -> NNCFTensor:
+    def no_outliers_map(
+        cls, x: NNCFTensor, fn: MaskedReduceFN, axis: Union[int, Tuple[int, ...]] = 0, alpha: float = 0.01
+    ) -> NNCFTensor:
         """
         Computes quantiles [alpha, 1 - alpha] on given tensor, masks all elements
         that are smaller than alpha and bigger than 1 - alpha quantile and applies
diff --git a/nncf/experimental/common/tensor_statistics/collectors.py b/nncf/experimental/common/tensor_statistics/collectors.py
index 5a51691669d..99166320922 100644
--- a/nncf/experimental/common/tensor_statistics/collectors.py
+++ b/nncf/experimental/common/tensor_statistics/collectors.py
@@ -120,7 +120,7 @@ class TensorAggregatorBase:
     def __init__(
         self,
         tensor_processor: NNCFCollectorTensorProcessor,
-        aggregation_axes: Union[int, Tuple[int, ...]] = (0,),
+        aggregation_axes: Optional[Tuple[int, ...]] = None,
         keepdims: bool = False,
         num_samples: Optional[int] = None,
     ):
@@ -132,7 +132,7 @@ def __init__(
         """
         self._tensor_processor = tensor_processor
-        self._aggregation_axes = aggregation_axes
+        self._aggregation_axes = (0,) if aggregation_axes is None else aggregation_axes
         self._keepdims = keepdims
         self._num_samples = num_samples
         self._collected_samples = 0
@@ -495,15 +495,17 @@ class AbsQuantileReducer(QuantileReducerBase):
     def __init__(
         self,
         reduction_axes: Optional[ReductionShape] = None,
-        quantile: Union[float, List[float]] = 0.99,
+        quantile: Optional[Union[float, List[float]]] = None,
         inplace: bool = False,
+        keepdims: bool = True,
     ):
-        super().__init__(reduction_axes, quantile, False)
+        quantile = (0.99,) if quantile is None else quantile
+        super().__init__(reduction_axes, quantile, False, keepdims)
 
     def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
         x = self._tensor_processor.abs(x[0])
         reduction_shape = self._get_reduction_shape(x)
-        return self._tensor_processor.quantile(x, [self._quantile], reduction_shape, keepdims=self._keepdims)
+        return self._tensor_processor.quantile(x, self._quantile, reduction_shape, keepdims=self._keepdims)
 
 
 class BatchMeanReducer(TensorReducerBase):
@@ -551,7 +553,7 @@ class OnlineOfflineAggregatorBase(TensorAggregatorBase):
     def __init__(
         self,
         tensor_processor: NNCFCollectorTensorProcessor,
-        aggregation_axes: Union[int, Tuple[int, ...]] = 0,
+        aggregation_axes: Optional[Tuple[int, ...]] = None,
         keepdims: bool = False,
         num_samples: Optional[int] = None,
         window_size=None,
@@ -618,15 +620,14 @@ class NoOutliersAggregatorBase(OfflineAggregatorBase, ABC):
     def __init__(
         self,
         tensor_processor: NNCFCollectorTensorProcessor,
-        aggregation_axes: Union[int, Tuple[int, ...]] = 0,
+        aggregation_axes: Optional[Tuple[int, ...]] = None,
         keepdims: bool = False,
         num_samples: Optional[int] = None,
         window_size=None,
         quantile: float = 0.01,
     ):
-        assert len(aggregation_axes) == 1
         super().__init__(
-            tensor_processor, aggregation_axes=aggregation_axes[0], keepdims=keepdims, num_samples=num_samples
+            tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims, num_samples=num_samples
         )
         self._window_size = window_size
         self._container = deque(maxlen=window_size)
@@ -681,7 +682,7 @@ def __init__(
         self,
         tensor_processor: NNCFCollectorTensorProcessor,
         percentiles_to_collect: List[float],
-        aggregation_axes: Union[int, Tuple[int, ...]] = 0,
+        aggregation_axes: Optional[Tuple[int, ...]] = None,
         keepdims: bool = False,
         num_samples: Optional[int] = None,
         window_size=None,
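The aggregator hunks above all converge on one pattern: the public default for `aggregation_axes` becomes `None`, and the canonical `(0,)` is materialized once in `__init__`, so every subclass — including the no-outliers aggregators, which previously asserted a single axis — accepts arbitrary axis tuples. A minimal, self-contained sketch of that pattern (illustrative names only, not the real NNCF classes):

```python
# Sketch only: placeholder class, not the NNCF TensorAggregatorBase itself.
from typing import Optional, Tuple

import numpy as np


class AggregatorSketch:
    def __init__(self, aggregation_axes: Optional[Tuple[int, ...]] = None, keepdims: bool = False):
        # Resolve the default once, instead of baking (0,) into every signature.
        self._aggregation_axes = (0,) if aggregation_axes is None else aggregation_axes
        self._keepdims = keepdims
        self._samples = []

    def register(self, x: np.ndarray) -> None:
        self._samples.append(x)

    def aggregate(self) -> np.ndarray:
        # Multi-axis reduction over the stack of registered samples; the old
        # single-axis assert is no longer needed.
        stacked = np.stack(self._samples)
        return np.min(stacked, axis=self._aggregation_axes, keepdims=self._keepdims)
```

With `aggregation_axes=(0, 2)` and ten registered `(1, 3, 3)` samples, `aggregate()` reduces the `(10, 1, 3, 3)` stack to shape `(1, 3)` — the same shape family exercised by the new parametrized test cases further below.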
diff --git a/nncf/openvino/statistics/collectors.py b/nncf/openvino/statistics/collectors.py
index 4704a12a0fb..7ce2855c636 100644
--- a/nncf/openvino/statistics/collectors.py
+++ b/nncf/openvino/statistics/collectors.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Callable, Deque, List, Optional, Union
+from typing import Any, Callable, Deque, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -49,11 +49,11 @@ class OVNNCFCollectorTensorProcessor(NNCFCollectorTensorProcessor):
     """
 
     @staticmethod
-    def reduce_min(x: NNCFTensor, axis: Union[int, tuple], keepdims: bool = True) -> NNCFTensor:
+    def reduce_min(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = True) -> NNCFTensor:
         return OVNNCFTensor(np.amin(x.tensor, axis=axis, keepdims=keepdims))
 
     @staticmethod
-    def reduce_max(x: NNCFTensor, axis: Union[int, tuple], keepdims: bool = True) -> NNCFTensor:
+    def reduce_max(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = True) -> NNCFTensor:
         return OVNNCFTensor(np.amax(x.tensor, axis=axis, keepdims=keepdims))
 
     @staticmethod
@@ -69,16 +69,16 @@ def max(x1: NNCFTensor, x2: NNCFTensor) -> NNCFTensor:
         return OVNNCFTensor(np.maximum(x1.tensor, x2.tensor))
 
     @staticmethod
-    def mean(x: NNCFTensor, axis: Union[int, tuple], keepdims: bool = False) -> NNCFTensor:
+    def mean(x: NNCFTensor, axis: Union[int, Tuple], keepdims: bool = False) -> NNCFTensor:
         return OVNNCFTensor(np.mean(x.tensor, axis=axis, keepdims=keepdims))
 
     @staticmethod
-    def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims: bool = False) -> NNCFTensor:
+    def median(x: NNCFTensor, axis: Union[int, Tuple, list], keepdims: bool = False) -> NNCFTensor:
         return OVNNCFTensor(np.median(x.tensor, axis=axis, keepdims=keepdims))
 
     @classmethod
     def masked_mean(
-        cls, x: NNCFTensor, axis: Optional[Union[int, tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False
+        cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False
     ) -> NNCFTensor:
         if mask is None:
             return cls.mean(x, axis=axis, keepdims=keepdims)
@@ -87,7 +87,7 @@ def masked_mean(
 
     @classmethod
     def masked_median(
-        cls, x: NNCFTensor, axis: Optional[Union[int, tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False
+        cls, x: NNCFTensor, axis: Optional[Union[int, Tuple, list]], mask: Optional[NNCFTensor], keepdims: bool = False
     ) -> NNCFTensor:
         if mask is None:
             return cls.median(x, axis=axis, keepdims=keepdims)
@@ -107,7 +107,7 @@ def no_outliers_map(
         cls,
         x: NNCFTensor,
         fn: Callable[[NNCFTensor, int, NNCFTensor], Any],
-        axis: int = 0,
+        axis: Union[int, Tuple[int, ...]] = 0,
         alpha: float = 0.01,
         keepdims: bool = False,
     ) -> NNCFTensor:
@@ -115,12 +115,9 @@ def no_outliers_map(
             return fn(x, axis=None, mask=None, keepdims=keepdims)
 
         x = x.tensor
-        if axis:
-            x = np.moveaxis(x, axis, 0)
-
-        low_values, high_values = np.quantile(x, [alpha, 1 - alpha], 0)
+        low_values, high_values = np.quantile(x, [alpha, 1 - alpha], axis=axis)
         outliers_mask = np.logical_or(x < low_values, high_values < x)
-        return fn(OVNNCFTensor(x), axis=0, mask=OVNNCFTensor(outliers_mask), keepdims=keepdims)
+        return fn(OVNNCFTensor(x), axis=axis, mask=OVNNCFTensor(outliers_mask), keepdims=keepdims)
 
     @staticmethod
     def batch_mean(x: NNCFTensor) -> NNCFTensor:
@@ -141,7 +138,7 @@ def sum(tensor: NNCFTensor) -> TensorElementsType:
 
     @staticmethod
     def quantile(
-        tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False
+        tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, Tuple, list], keepdims: bool = False
     ) -> List[NNCFTensor]:
         result = np.quantile(tensor.tensor, quantile, axis, keepdims=keepdims)
         return [OVNNCFTensor(x) for x in result]
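A standalone NumPy sketch of the reworked masking flow above, assuming a masked mean as `fn`. One deliberate deviation: it passes `keepdims=True` to `np.quantile` so the bounds broadcast against `x` for any axis tuple, which is slightly more defensive than the axis-0 alignment the hunk relies on:

```python
from typing import Tuple, Union

import numpy as np


def no_outliers_mean(
    x: np.ndarray,
    axis: Union[int, Tuple[int, ...]] = 0,
    alpha: float = 0.01,
    keepdims: bool = False,
) -> np.ndarray:
    if x.ndim == 1:
        # Mirrors the 1D early return in no_outliers_map above.
        return np.mean(x, keepdims=keepdims)
    # Per-axis quantile bounds; keepdims=True keeps them broadcastable
    # against x no matter which axes are reduced.
    low, high = np.quantile(x, [alpha, 1 - alpha], axis=axis, keepdims=True)
    outliers = np.logical_or(x < low, high < x)
    # np.ma excludes True-masked entries, matching masked_mean above
    # (tuple axes in np.ma reductions assume a reasonably recent NumPy).
    return np.ma.getdata(np.ma.array(x, mask=outliers).mean(axis=axis, keepdims=keepdims))
```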
diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py
index e55b6ee8f10..2b795552e07 100644
--- a/nncf/torch/tensor_statistics/collectors.py
+++ b/nncf/torch/tensor_statistics/collectors.py
@@ -100,8 +100,8 @@ def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTen
         if mask is None:
             return cls.mean(x, axis=axis, keepdims=keepdims)
         masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor)
-        result = np.ma.mean(masked_x, axis=axis, keepdims=False)
-        if len(result) == 1:
+        result = np.ma.mean(masked_x, axis=axis, keepdims=False).astype(masked_x.dtype)
+        if result.size <= 1:
             return PTNNCFTensor(torch.tensor(result))
         return PTNNCFTensor(torch.tensor(result.data))
 
@@ -181,20 +181,16 @@ def no_outliers_map(
         cls,
         x: NNCFTensor,
         fn: Callable[[NNCFTensor, int, NNCFTensor], Any],
-        axis: int = 0,
+        axis: Union[int, Tuple[int, ...]] = 0,
         alpha: float = 0.01,
         keepdims: bool = False,
     ):
         if len(x.shape) == 1:
             return fn(x, axis=None, mask=None, keepdims=keepdims)
 
-        x = x.tensor
-        if axis:
-            x = torch.moveaxis(x, [axis] if isinstance(axis, int) else axis, 0)
-
-        low_values, high_values = cls.quantile(x, [alpha, 1 - alpha], 0)
-        outliers_mask = np.logical_or(x < low_values, high_values < x)
-        return fn(x, axis=0, mask=PTNNCFTensor(outliers_mask), keepdims=keepdims)
+        low_values, high_values = cls.quantile(x, [alpha, 1 - alpha], axis=axis)
+        outliers_mask = torch.logical_or(x.tensor < low_values.tensor, high_values.tensor < x.tensor)
+        return fn(x, axis=axis, mask=PTNNCFTensor(outliers_mask), keepdims=keepdims)
 
     @classmethod
     def masked_map(cls, x: NNCFTensor, fn: MaskedReduceFN, filter_fn) -> NNCFTensor:
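The `masked_mean` hunk fixes two subtle failure modes, which a few NumPy lines make visible (assumed NumPy behavior, worth re-checking on the pinned version): `np.ma.mean` promotes integer inputs to `float64`, and a full reduction returns a 0-d scalar on which `len()` raises `TypeError` — hence the `.astype(masked_x.dtype)` cast and the `result.size <= 1` check:

```python
import numpy as np

masked = np.ma.array(np.arange(6).reshape(2, 3), mask=False)  # integer data

promoted = np.ma.mean(masked, axis=0)      # means come back as float64
restored = promoted.astype(masked.dtype)   # the .astype(...) in the hunk undoes the promotion

scalar = np.ma.mean(masked, axis=None)     # full reduction yields a 0-d scalar
# len(scalar) would raise TypeError here; the size-based check handles
# both the 0-d and the single-element case:
assert np.size(scalar) <= 1
```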
diff --git a/tests/experimental/common/test_reducers_and_aggregators.py b/tests/experimental/common/test_reducers_and_aggregators.py
index cc54fe987ac..b1d49977d98 100644
--- a/tests/experimental/common/test_reducers_and_aggregators.py
+++ b/tests/experimental/common/test_reducers_and_aggregators.py
@@ -11,10 +11,12 @@
 
 from abc import abstractmethod
 from itertools import product
+from typing import Any, List, Optional, Tuple
 
 import numpy as np
 import pytest
 
+from nncf.common.graph.layer_attributes import Dtype
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MeanAggregator
 from nncf.experimental.common.tensor_statistics.collectors import MeanNoOutliersAggregator
@@ -43,17 +45,117 @@
 default_test_quantile = 0.1
 
 
-def default_test_mean_no_outlier(tp, ps):
-    return MeanNoOutliersAggregator(tp, ps, quantile=default_test_quantile)
+OFFLINE_AGGREGATORS_TEST_CASES = [
+    (
+        None,
+        False,
+        [[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]],
+        [[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]],
+    ),
+    (
+        (0,),
+        False,
+        [[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]],
+        [[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]],
+    ),
+    (
+        (0, 2),
+        False,
+        [[-50000, -28, -32]],
+        [[50000, 28, 32]],
+    ),
+    (
+        (2,),
+        False,
+        [
+            [[-50000, 5, 10]],
+            [[-40000, 4, 8]],
+            [[-30000, 3, 6]],
+            [[-20000, 2, 4]],
+            [[-10000, 1, 2]],
+            [[0, 0, 0]],
+            [[-6, -7, -8]],
+            [[-12, -14, -16]],
+            [[-18, -21, -24]],
+            [[-24, -28, -32]],
+        ],
+        [
+            [[50000, -5, -10]],
+            [[40000, -4, -8]],
+            [[30000, -3, -6]],
+            [[20000, -2, -4]],
+            [[10000, -1, -2]],
+            [[0, 0, 0]],
+            [[6, 7, 8]],
+            [[12, 14, 16]],
+            [[18, 21, 24]],
+            [[24, 28, 32]],
+        ],
+    ),
+    (
+        None,
+        True,
+        [[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]],
+        [[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]],
+    ),
+    (
+        (0,),
+        True,
+        [[[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]],
+        [[[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]],
+    ),
+    (
+        (0, 2),
+        True,
+        [[[[-50000, -28, -32]]]],
+        [[[[50000, 28, 32]]]],
+    ),
+    (
+        (2,),
+        True,
+        [
+            [[[-50000, 5, 10]]],
+            [[[-40000, 4, 8]]],
+            [[[-30000, 3, 6]]],
+            [[[-20000, 2, 4]]],
+            [[[-10000, 1, 2]]],
+            [[[0, 0, 0]]],
+            [[[-6, -7, -8]]],
+            [[[-12, -14, -16]]],
+            [[[-18, -21, -24]]],
+            [[[-24, -28, -32]]],
+        ],
+        [
+            [[[50000, -5, -10]]],
+            [[[40000, -4, -8]]],
+            [[[30000, -3, -6]]],
+            [[[20000, -2, -4]]],
+            [[[10000, -1, -2]]],
+            [[[0, 0, 0]]],
+            [[[6, 7, 8]]],
+            [[[12, 14, 16]]],
+            [[[18, 21, 24]]],
+            [[[24, 28, 32]]],
+        ],
+    ),
+]
+
+
+def default_test_mean_no_outlier(tensor_processor, aggregation_axes):
+    return MeanNoOutliersAggregator(
+        tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, quantile=default_test_quantile
+    )
 
 
-def default_test_median_no_outlier(tp, ps):
-    return MedianNoOutliersAggregator(tp, ps, quantile=default_test_quantile)
+def default_test_median_no_outlier(tensor_processor, aggregation_axes):
+    return MedianNoOutliersAggregator(
+        tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, quantile=default_test_quantile
+    )
 
 
 class TemplateTestReducersAggreagtors:
     @abstractmethod
-    def get_nncf_tensor(self, x: np.array):
+    def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None):
         pass
 
     @pytest.fixture
@@ -70,6 +172,14 @@ def reducers(self):
     def all_close(self, val, ref) -> bool:
         pass
 
+    @abstractmethod
+    def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = None):
+        pass
+
+    @abstractmethod
+    def cast_tensor(self, tensor, dtype: Dtype):
+        pass
+
     def test_noop_reducer(self, reducers):
         reducer = reducers["noop"]()
         input_ = np.arange(24).reshape((1, 2, 3, 4))
@@ -87,27 +197,31 @@ def test_noop_reducer(self, reducers):
         ],
     )
     def test_min_max_mean_reducers(self, reducer_name, ref, reducers):
-        reduction_shape = (1, 2)
+        reduction_axes = (1, 2)
         input_ = np.arange(-26, 10).reshape((4, 3, 3))
-        for i, red_shape in enumerate([reduction_shape, None]):
-            reducer = reducers[reducer_name](red_shape, False)
-            val = reducer([self.get_nncf_tensor(input_)])
-            assert len(val) == 1
-            assert self.all_close(val[0].tensor, ref[i])
+        for i, red_axes in enumerate([reduction_axes, None]):
+            for keepdims in [True, False]:
+                reducer = reducers[reducer_name](reduction_axes=red_axes, inplace=False, keepdims=keepdims)
+                val = reducer([self.get_nncf_tensor(input_, Dtype.FLOAT)])
+                assert len(val) == 1
+                ref_ = ref[i] if keepdims else self.squeeze_tensor(ref[i])
+                assert self.all_close(val[0].tensor, self.cast_tensor(ref_, Dtype.FLOAT))
 
     @pytest.mark.parametrize(
         "reducer_name,ref", [("quantile", ([[[[-20000]]]], [[[[10000]]]])), ("abs_quantile", ([[[[20000]]]],))]
     )
     def test_quantile_reducers(self, reducer_name, ref, reducers):
-        reduction_shape = (1, 2, 3)
+        reduction_axes = (1, 2, 3)
         input_ = np.arange(-26, 10).reshape((1, 4, 3, 3))
         input_[0][0][0] = -20000
         input_[0][0][1] = 10000
-        reducer = reducers[reducer_name](reduction_shape, inplace=False)
-        val = reducer([self.get_nncf_tensor(input_)])
-        assert len(val) == len(ref)
-        for i, ref_ in enumerate(ref):
-            assert self.all_close(val[i].tensor, ref_)
+        for keepdims in [True, False]:
+            reducer = reducers[reducer_name](reduction_axes=reduction_axes, inplace=False, keepdims=keepdims)
+            val = reducer([self.get_nncf_tensor(input_, dtype=Dtype.FLOAT)])
+            assert len(val) == len(ref)
+            for i, ref_ in enumerate(ref):
+                ref_ = ref[i] if keepdims else self.squeeze_tensor(ref[i], (1, 2, 3))
+                assert self.all_close(val[i].tensor, self.cast_tensor(ref_, Dtype.FLOAT))
 
     @pytest.mark.parametrize(
         "reducer_name,ref",
@@ -116,9 +230,9 @@ def test_quantile_reducers(self, reducer_name, ref, reducers):
     def test_batch_mean_mean_per_ch_reducers(self, reducer_name, ref, reducers):
         input_ = np.arange(-26, 10).reshape((4, 1, 3, 3))
         reducer = reducers[reducer_name](inplace=False)
-        val = reducer([self.get_nncf_tensor(input_)])
+        val = reducer([self.get_nncf_tensor(input_, Dtype.FLOAT)])
         assert len(val) == 1
-        assert self.all_close(val[0].tensor, ref)
+        assert self.all_close(val[0].tensor, self.cast_tensor(ref, Dtype.FLOAT))
 
     def test_noop_aggregator(self):
         aggregator = NoopAggregator(None)
@@ -146,20 +260,28 @@ def test_shape_aggregator(self):
         assert aggregator._collected_samples == 1
         assert ref_shape == aggregator.aggregate()
 
-    def test_min_max_aggregators(self, tensor_processor):
-        min_aggregator = MinAggregator(tensor_processor)
-        max_aggregator = MaxAggregator(tensor_processor)
+    @pytest.mark.parametrize(
+        "aggregation_axes,keepdims,min_ref,max_ref",
+        OFFLINE_AGGREGATORS_TEST_CASES,
+    )
+    def test_min_max_aggregators(self, aggregation_axes, keepdims, min_ref, max_ref, tensor_processor):
+        min_aggregator = MinAggregator(
+            tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims
+        )
+        max_aggregator = MaxAggregator(
+            tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims
+        )
         input_ = np.arange(3 * 3).reshape((1, 3, 3))
         input_[0, 0, 0] = -10000
         for i in range(-5, 5):
             min_aggregator.register_reduced_input(self.get_nncf_tensor(input_ * (-i)))
             max_aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i))
 
-        min_ref = [[[-50000, -4, -8], [-12, -16, -20], [-24, -28, -32]]]
-        assert self.all_close(min_ref, min_aggregator.aggregate())
-
-        max_ref = [[[50000, 4, 8], [12, 16, 20], [24, 28, 32]]]
-        assert self.all_close(max_ref, max_aggregator.aggregate())
+        assert self.all_close(
+            min_aggregator.aggregate(),
+            min_ref,
+        )
+        assert self.all_close(max_aggregator.aggregate(), max_ref)
 
     NO_OUTLIERS_TEST_PARAMS = [
         (MeanAggregator, True, 1, 1404.5138888888905),
@@ -211,19 +333,20 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,
         input_ = input_.reshape((1, 3, 3))
         input_with_outliers = input_with_outliers.reshape((1, 3, 3))
 
-        aggregator = aggregator_cls(tensor_processor, use_per_sample_stats)
+        aggregation_axes = (0, 1) if use_per_sample_stats else (0,)
+        aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
         for i in range(1, 6):
-            aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i))
+            aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT))
         # this registration is to make diff between mean and median bigger
-        aggregator.register_reduced_input(self.get_nncf_tensor(input_ * 10))
+        aggregator.register_reduced_input(self.get_nncf_tensor(input_ * 10, Dtype.FLOAT))
         is_median = isinstance(aggregator, (MedianAggregator, MedianNoOutliersAggregator))
         # Outliers registration
         for i in range(2):
             # mult is needed to make outlier and no outlier aggreagators differs
             mult = 2.2 * i - 1 if not is_median else 1
-            aggregator.register_reduced_input(self.get_nncf_tensor(input_with_outliers * mult))
+            aggregator.register_reduced_input(self.get_nncf_tensor(input_with_outliers * mult, Dtype.FLOAT))
         ret_val = aggregator.aggregate()
-        assert self.all_close(ret_val, refs)
+        assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT))
 
     @pytest.mark.parametrize(
         "reducer_name",
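The references in `OFFLINE_AGGREGATORS_TEST_CASES` can be reproduced by hand; for instance, the `((0, 2), False, ...)` case follows from stacking the same registrations `test_min_max_aggregators` feeds the min aggregator and reducing with plain NumPy:

```python
import numpy as np

input_ = np.arange(3 * 3).reshape((1, 3, 3))
input_[0, 0, 0] = -10000

# The test registers input_ * (-i) for i in range(-5, 5) into the min
# aggregator; stacking those registrations reproduces its internal state.
stacked = np.stack([input_ * (-i) for i in range(-5, 5)])  # shape (10, 1, 3, 3)

min_ref = np.min(stacked, axis=(0, 2), keepdims=False)
assert min_ref.tolist() == [[-50000, -28, -32]]
```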
diff --git a/tests/openvino/native/quantization/test_reducers_and_aggregators.py b/tests/openvino/native/quantization/test_reducers_and_aggregators.py
index 4726ac61194..f16be85df7b 100644
--- a/tests/openvino/native/quantization/test_reducers_and_aggregators.py
+++ b/tests/openvino/native/quantization/test_reducers_and_aggregators.py
@@ -9,9 +9,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, List, Optional, Tuple
+
 import numpy as np
 import pytest
 
+from nncf.common.graph.layer_attributes import Dtype
+from nncf.common.tensor import NNCFTensor
 from nncf.openvino.statistics.collectors import OVAbsMaxReducer
 from nncf.openvino.statistics.collectors import OVAbsQuantileReducer
 from nncf.openvino.statistics.collectors import OVBatchMeanReducer
@@ -31,7 +35,7 @@ class TestReducersAggregators(TemplateTestReducersAggreagtors):
     def tensor_processor(self):
         return OVNNCFCollectorTensorProcessor
 
-    def get_nncf_tensor(self, x: np.array):
+    def get_nncf_tensor(self, x: np.array, dtype: Optional[Dtype] = None):
         return OVNNCFTensor(x)
 
     @pytest.fixture(scope="module")
@@ -52,3 +56,9 @@ def all_close(self, val, ref) -> bool:
         val_ = np.array(val)
         ref_ = np.array(ref)
         return np.allclose(val_, ref_) and val_.shape == ref_.shape
+
+    def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = None):
+        return np.squeeze(np.array(ref_tensor), axes)
+
+    def cast_tensor(self, tensor, dtype: Dtype):
+        return tensor
diff --git a/tests/torch/ptq/test_reducers_and_aggregators.py b/tests/torch/ptq/test_reducers_and_aggregators.py
index fefb8347f16..e88904e2974 100644
--- a/tests/torch/ptq/test_reducers_and_aggregators.py
+++ b/tests/torch/ptq/test_reducers_and_aggregators.py
@@ -9,9 +9,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, List, Optional, Tuple
+
+import numpy as np
 import pytest
 import torch
 
+from nncf.common.graph.layer_attributes import Dtype
 from nncf.torch.tensor import PTNNCFTensor
 from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer
 from nncf.torch.tensor_statistics.collectors import PTAbsQuantileReducer
@@ -31,8 +35,13 @@ class TestReducersAggregators(TemplateTestReducersAggreagtors):
     def tensor_processor(self):
         return PTNNCFCollectorTensorProcessor
 
-    def get_nncf_tensor(self, x: torch.Tensor):
-        return PTNNCFTensor(x)
+    def get_nncf_tensor(self, x: np.ndarray, dtype: Optional[Dtype] = None):
+        torch_tensor = torch.tensor(x)
+        if dtype == Dtype.FLOAT:
+            torch_tensor = torch_tensor.float()
+        elif dtype == Dtype.INTEGER:
+            torch_tensor = torch_tensor.int()
+        return PTNNCFTensor(torch_tensor)
 
     @pytest.fixture(scope="module")
     def reducers(self):
@@ -52,3 +61,16 @@ def all_close(self, val, ref) -> bool:
         val_ = torch.tensor(val)
         ref_ = torch.tensor(ref)
         return torch.allclose(val_, ref_) and val_.shape == ref_.shape
+
+    def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = None):
+        if axes is None:
+            return torch.tensor(ref_tensor).squeeze()
+        return torch.tensor(ref_tensor).squeeze(axes)
+
+    def cast_tensor(self, tensor, dtype: Dtype):
+        tensor = torch.tensor(tensor)
+        if dtype == Dtype.FLOAT:
+            return tensor.float()
+        if dtype == Dtype.INTEGER:
+            return tensor.int()
+        raise RuntimeError()
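One portability note on the torch helpers above: `squeeze_tensor` passes a tuple straight to `Tensor.squeeze`, which — as far as I know — only accepts a tuple of dims on newer PyTorch releases (the 2.x line). If older versions need to pass, an iterative fallback with the same semantics is straightforward (hypothetical helper, not part of the diff):

```python
from typing import Optional, Tuple

import torch


def squeeze_dims(x: torch.Tensor, axes: Optional[Tuple[int, ...]] = None) -> torch.Tensor:
    # Squeeze the requested dims one by one, highest index first, so removing
    # a size-1 dim never shifts the indices of dims still to be processed.
    if axes is None:
        return x.squeeze()
    for dim in sorted(axes, reverse=True):
        x = x.squeeze(dim)
    return x
```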