Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 7, 2023
1 parent 9b119fe commit b4aced0
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 30 deletions.
18 changes: 9 additions & 9 deletions nncf/experimental/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, reduction_axes: Optional[ReductionShape] = None, inplace: boo
:param keepdims: Should the axes which are reduced are left in the result
as dimensions with size one or not.
"""
self._reduction_shape = reduction_axes
self._reduction_axes = reduction_axes
self._tensor_processor: NNCFCollectorTensorProcessor = self._get_processor()
self._inplace = inplace
self._keepdims = keepdims
Expand Down Expand Up @@ -98,16 +98,16 @@ def __call__(self, x: List[NNCFTensor]):
def __eq__(self, __o: object) -> bool:
return (
isinstance(__o, self.__class__)
and self._reduction_shape == __o._reduction_shape
and self._reduction_axes == __o._reduction_axes
and self._inplace == __o.inplace
)

def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._reduction_shape))
return hash((self.__class__.__name__, self.inplace, self._reduction_axes))

def _get_reduction_shape(self, tensor: NNCFTensor) -> Union[int, Tuple[int, ...]]:
if self._reduction_shape is not None:
return self._reduction_shape
if self._reduction_axes is not None:
return self._reduction_axes
return tuple(range(len(tensor.shape)))


Expand Down Expand Up @@ -481,7 +481,7 @@ def __eq__(self, __o: object) -> bool:
return super().__eq__(__o) and self._quantile == __o._quantile

def __hash__(self) -> int:
return hash((self.__class__.__name__, self.inplace, self._reduction_shape, tuple(self._quantile)))
return hash((self.__class__.__name__, self.inplace, self._reduction_axes, tuple(self._quantile)))


class QuantileReducer(QuantileReducerBase):
Expand Down Expand Up @@ -521,7 +521,7 @@ def __init__(self, channel_dim: int = 1, inplace: bool = False):
super().__init__(channel_dim, inplace)

def _reduce_out_of_place(self, x: List[NNCFTensor]) -> List[NNCFTensor]:
return [self._tensor_processor.mean_per_channel(x[0], self._reduction_shape)]
return [self._tensor_processor.mean_per_channel(x[0], self._reduction_axes)]


##################################################Aggregators##################################################
Expand Down Expand Up @@ -567,7 +567,7 @@ def __init__(

class OnlineAggregatorBase(OnlineOfflineAggregatorBase, ABC):
def _online_register_reduced_input_impl(self, x: TensorType, fn) -> None:
online_aggregation_axes = [dim - 1 for dim in self._aggregation_axes if dim != 0]
online_aggregation_axes = tuple([dim - 1 for dim in self._aggregation_axes if dim != 0])
if online_aggregation_axes:
reduced = fn(x, axis=online_aggregation_axes, keepdims=self._keepdims)
else:
Expand Down Expand Up @@ -636,7 +636,7 @@ def __init__(
def _offline_aggregation_impl(self, fn) -> List[NNCFTensor]:
stacked_val = self._tensor_processor.stack(self._container)
result = self._tensor_processor.no_outliers_map(
stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile
stacked_val, fn, axis=self._aggregation_axes, alpha=self._quantile, keepdims=self._keepdims
)
return result.tensor

Expand Down
10 changes: 6 additions & 4 deletions nncf/openvino/statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def masked_mean(
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
masked_x = np.ma.array(x.tensor, mask=mask.tensor)
return OVNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=False).data)
return OVNNCFTensor(np.ma.mean(masked_x, axis=axis, keepdims=keepdims).data)

@classmethod
def masked_median(
Expand Down Expand Up @@ -111,8 +111,11 @@ def no_outliers_map(
alpha: float = 0.01,
keepdims: bool = False,
) -> NNCFTensor:
if len(x.shape) == 1:
return fn(x, axis=None, mask=None, keepdims=keepdims)
if isinstance(axis, int):
axis = (axis,)

if len(axis) == len(x.shape):
return fn(x, axis=axis, mask=None, keepdims=keepdims)

x = x.tensor
low_values, high_values = np.quantile(x, [alpha, 1 - alpha], axis=axis)
Expand Down Expand Up @@ -249,7 +252,6 @@ def get_mean_stat_collector(num_samples, channel_axis, window_size=None, inplace

kwargs = {
"tensor_processor": OVNNCFCollectorTensorProcessor,
"use_per_sample_stats": False,
"num_samples": num_samples,
"window_size": window_size,
}
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def get_statistic_collector(
f"Aggregator type: {params.aggregator_type} is not supported for OpenVino PTQ backend yet."
)

kwargs = {"reduction_shape": reduction_shape, "inplace": inplace}
kwargs = {"reduction_axes": reduction_shape, "inplace": inplace}
if params.statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
if container_key == OVMinMaxTensorStatistic.MIN_STAT:
quantile = params.quantile_outlier_prob
Expand Down
11 changes: 7 additions & 4 deletions nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTen
if mask is None:
return cls.mean(x, axis=axis, keepdims=keepdims)
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor)
result = np.ma.mean(masked_x, axis=axis, keepdims=False).astype(masked_x.dtype)
result = np.ma.mean(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
if result.size <= 1:
return PTNNCFTensor(torch.tensor(result))
return PTNNCFTensor(torch.tensor(result.data))
Expand All @@ -112,7 +112,7 @@ def masked_median(
# Implemented in numy as torch.masked.median is not implemented yet
if mask is None:
return cls.median(x, axis=axis, keepdims=keepdims)
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=~mask.tensor.detach().cpu().numpy())
masked_x = np.ma.array(x.tensor.detach().cpu().numpy(), mask=mask.tensor.detach().cpu().numpy())
result = np.ma.median(masked_x, axis=axis, keepdims=keepdims).astype(masked_x.dtype)
if len(result) == 1:
return PTNNCFTensor(torch.tensor(result))
Expand Down Expand Up @@ -185,8 +185,11 @@ def no_outliers_map(
alpha: float = 0.01,
keepdims: bool = False,
):
if len(x.shape) == 1:
return fn(x, axis=None, mask=None, keepdims=keepdims)
if isinstance(axis, int):
axis = (axis,)

if len(x.shape) == len(axis):
return fn(x, axis=axis, mask=None, keepdims=keepdims)

low_values, high_values = cls.quantile(x, [alpha, 1 - alpha], axis=axis)
outliers_mask = torch.logical_or(x.tensor < low_values.tensor, high_values.tensor < x.tensor)
Expand Down
7 changes: 4 additions & 3 deletions tests/common/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def dataset_samples_to_conv_w(self, dataset_sample):
np.array((-10, -1, -128)),
)
),
),
)[12:],
)
def test_statistics_aggregator_min_max(
self,
Expand All @@ -375,6 +375,7 @@ def test_statistics_aggregator_min_max(
inplace_statistics,
is_backend_support_custom_estimators,
):
inplace_statistics = False
model = self.get_backend_model(dataset_samples)
quantizer_config = QuantizerConfig(
mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel
Expand Down Expand Up @@ -814,10 +815,10 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_
model = params["model"](dataset_samples)
params = {}
if statistics_type in [StatisticsType.MIN, StatisticsType.MAX, StatisticsType.ABS_MAX, StatisticsType.MEAN]:
params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)]
params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
params["inplace"] = [False, True]
elif statistics_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]:
params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)]
params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
params["quantile"] = [[0.01, 0.99], [0.001, 0.999]]
elif statistics_type == "batch_mean":
pytest.skip("Inplace statistic woun't work until openvino==2023.0.0 release")
Expand Down
35 changes: 27 additions & 8 deletions tests/experimental/common/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,21 @@
]


def default_test_mean_no_outlier(tensor_processor, aggregation_axes):
def default_test_mean_no_outlier(tensor_processor, aggregation_axes, keepdims):
return MeanNoOutliersAggregator(
tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, quantile=default_test_quantile
tensor_processor=tensor_processor,
aggregation_axes=aggregation_axes,
quantile=default_test_quantile,
keepdims=keepdims,
)


def default_test_median_no_outlier(tensor_processor, aggregation_axes):
def default_test_median_no_outlier(tensor_processor, aggregation_axes, keepdims):
return MedianNoOutliersAggregator(
tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, quantile=default_test_quantile
tensor_processor=tensor_processor,
aggregation_axes=aggregation_axes,
quantile=default_test_quantile,
keepdims=keepdims,
)


Expand Down Expand Up @@ -180,6 +186,10 @@ def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = Non
def cast_tensor(self, tensor, dtype: Dtype):
pass

@abstractmethod
def expand_dims(self, tensor, dims: Tuple[int, ...]):
pass

def test_noop_reducer(self, reducers):
reducer = reducers["noop"]()
input_ = np.arange(24).reshape((1, 2, 3, 4))
Expand Down Expand Up @@ -321,7 +331,10 @@ def test_min_max_aggregators(self, aggregation_axes, keepdims, min_ref, max_ref,
]

@pytest.mark.parametrize("aggregator_cls,use_per_sample_stats,dims,refs", NO_OUTLIERS_TEST_PARAMS)
def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor, dims, use_per_sample_stats):
@pytest.mark.parametrize("keepdims", [True, False])
def test_mean_median_agggregators(
self, aggregator_cls, refs, tensor_processor, dims, use_per_sample_stats, keepdims
):
input_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
input_with_outliers = np.array(
[100_000, -100_000, 200_000, -200_000, 300_000, -300_000, 400_000, -400_000, 500_000]
Expand All @@ -334,7 +347,9 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,
input_with_outliers = input_with_outliers.reshape((1, 3, 3))

aggregation_axes = (0, 1) if use_per_sample_stats else (0,)
aggregator = aggregator_cls(tensor_processor=tensor_processor, aggregation_axes=aggregation_axes)
aggregator = aggregator_cls(
tensor_processor=tensor_processor, aggregation_axes=aggregation_axes, keepdims=keepdims
)
for i in range(1, 6):
aggregator.register_reduced_input(self.get_nncf_tensor(input_ * i, Dtype.FLOAT))
# this registration is to make diff between mean and median bigger
Expand All @@ -346,6 +361,10 @@ def test_mean_median_agggregators(self, aggregator_cls, refs, tensor_processor,
mult = 2.2 * i - 1 if not is_median else 1
aggregator.register_reduced_input(self.get_nncf_tensor(input_with_outliers * mult, Dtype.FLOAT))
ret_val = aggregator.aggregate()

if keepdims:
refs = self.expand_dims(refs, (0, 1) if use_per_sample_stats else (0,))

assert self.all_close(ret_val, self.cast_tensor(refs, Dtype.FLOAT))

@pytest.mark.parametrize(
Expand All @@ -363,10 +382,10 @@ def test_reducers_name_hash_equal(self, reducer_name, reducers):

params = {}
if reducer_name in ["min", "max", "abs_max", "mean"]:
params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)]
params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
params["inplace"] = [False, True]
elif reducer_name in ["quantile", "abs_quantile"]:
params["reduction_shape"] = [None, (0, 1, 3), (1, 2, 3)]
params["reduction_axes"] = [None, (0, 1, 3), (1, 2, 3)]
params["quantile"] = [[0.01, 0.99], [0.001, 0.999]]
elif reducer_name == "batch_mean":
params["inplace"] = [False, True]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import pytest

from nncf.common.graph.layer_attributes import Dtype
from nncf.common.tensor import NNCFTensor
from nncf.openvino.statistics.collectors import OVAbsMaxReducer
from nncf.openvino.statistics.collectors import OVAbsQuantileReducer
from nncf.openvino.statistics.collectors import OVBatchMeanReducer
Expand Down Expand Up @@ -62,3 +61,6 @@ def squeeze_tensor(self, ref_tensor: List[Any], axes: Optional[Tuple[int]] = Non

def cast_tensor(self, tensor, dtype: Dtype):
return tensor

def expand_dims(self, tensor, dims: Tuple[int, ...]):
return np.expand_dims(np.array(tensor), dims)
7 changes: 7 additions & 0 deletions tests/torch/ptq/test_reducers_and_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,10 @@ def cast_tensor(self, tensor, dtype: Dtype):
if dtype == Dtype.INTEGER:
return tensor.int()
raise RuntimeError()

def expand_dims(self, tensor, dims: Tuple[int, ...]):
tensor_ = torch.tensor(tensor)
shape = list(tensor_.shape)
for dim in dims:
shape.insert(dim, 1)
return tensor_.view(shape)

0 comments on commit b4aced0

Please sign in to comment.