reducers_axes -> reduction_axes, aggregators_axes -> aggregation_axes
daniil-lyakhov committed Sep 20, 2023
1 parent cdfffa0 commit 75a18fb
Showing 4 changed files with 43 additions and 45 deletions.
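For callers of the statistic collector builders, this commit is a pure keyword rename: the two axis arguments now use the same reduction_axes / aggregation_axes naming as the rest of the collector API. A minimal before/after sketch (the import path follows the changed file path below; axis and shape values are illustrative only):

from nncf.torch.tensor_statistics.collectors import get_min_max_statistic_collector

# Before: get_min_max_statistic_collector(..., reducers_axes=(0, 2, 3), aggregators_axes=(0,), ...)
# After the rename, the same call reads:
collector = get_min_max_statistic_collector(
    use_abs_max=True,
    reduction_axes=(0, 2, 3),   # axes reduced within each collected tensor
    aggregation_axes=(0,),      # axes aggregated across collected samples
    scale_shape=(1, 8, 1, 1),   # illustrative per-channel target shape
    num_samples=None,
)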
4 changes: 2 additions & 2 deletions nncf/torch/quantization/init_range.py
@@ -178,8 +178,8 @@ def generate_stat_collector_for_range_init_config(

use_per_sample_stats = collector_params.use_per_sample_stats(init_config.init_type == "mixed_min_max")
collector_kwargs = {
- "reducers_axes": collector_params.get_reduction_axes(use_per_sample_stats),
- "aggregators_axes": collector_params.get_aggregation_axes(use_per_sample_stats),
+ "reduction_axes": collector_params.get_reduction_axes(use_per_sample_stats),
+ "aggregation_axes": collector_params.get_aggregation_axes(use_per_sample_stats),
"scale_shape": scale_shape,
}

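Because collector_kwargs is presumably expanded with ** into one of the builder functions later in generate_stat_collector_for_range_init_config (the dispatch itself sits outside this hunk), the dictionary keys have to match the renamed parameters exactly. A hedged sketch of that consumption pattern, with num_samples as a hypothetical local:

# Assumed consumer of collector_kwargs; the real dispatch on init_config is not shown in this diff.
# get_min_max_statistic_collector is assumed to be imported from nncf.torch.tensor_statistics.collectors.
collector = get_min_max_statistic_collector(
    use_abs_max=True,         # illustrative; the real flag depends on the init config
    num_samples=num_samples,  # hypothetical local variable
    **collector_kwargs,       # keys must now be reduction_axes / aggregation_axes
)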
2 changes: 0 additions & 2 deletions nncf/torch/quantization/layers.py
@@ -997,8 +997,6 @@ def get_quantizer_config(self) -> QuantizerConfig:


def get_per_channel_scale_shape(input_shape, is_weights, channel_idx: int = None):
- # TODO: case channel_ids=0, is_weights=True and per_sample_stats=True
- # leads to dimension error in statistic calculation
scale_shape = [1 for _ in input_shape]
if channel_idx is None:
if is_weights:
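For context on the deleted TODO: get_per_channel_scale_shape builds a broadcastable scale shape by keeping a single channel dimension and collapsing every other dimension to 1. The rest of the function is cut off in this hunk, so the snippet below is only an illustration of that behaviour, with the channel index chosen by assumption:

# Illustrative only; mirrors the visible first line of get_per_channel_scale_shape.
input_shape = (16, 8, 3, 3)             # e.g. a conv weight tensor
channel_idx = 0                         # assumed channel index for weights
scale_shape = [1 for _ in input_shape]
scale_shape[channel_idx] = input_shape[channel_idx]
print(scale_shape)                      # [16, 1, 1, 1]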
74 changes: 37 additions & 37 deletions nncf/torch/tensor_statistics/collectors.py
@@ -237,17 +237,17 @@ class PTMeanPerChanelReducer(PTReducerMixIn, MeanPerChReducer):

def get_min_max_statistic_collector(
use_abs_max: bool,
- reducers_axes: Tuple[int, ...],
- aggregators_axes: Tuple[int, ...],
+ reduction_axes: Tuple[int, ...],
+ aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
num_samples: int,
) -> TensorCollector:
"""
Min max statistic collector builder.
:param use_abs_max: Whether to use abs max reducer or max reducer.
- :param reducers_axes: Axes to use in reduction functions.
- :param aggregators_axes: Axes to use in aggregation functions.
+ :param reduction_axes: Axes to use in reduction functions.
+ :param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param num_samples: Maximum number of samples to collect.
:return: Min max statistic collector.
@@ -257,23 +257,23 @@ def get_min_max_statistic_collector(
aggregator_kwargs = {
"tensor_processor": PTNNCFCollectorTensorProcessor,
"num_samples": num_samples,
- "aggregation_axes": aggregators_axes,
+ "aggregation_axes": aggregation_axes,
}
- min_reducer = PTMinReducer(reducers_axes)
+ min_reducer = PTMinReducer(reduction_axes)
min_aggregator = MinAggregator(**aggregator_kwargs)
tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
- max_reducer = max_reducer_cls(reducers_axes)
+ max_reducer = max_reducer_cls(reduction_axes)
max_aggregator = MaxAggregator(**aggregator_kwargs)
tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
return tensor_collector


def get_mixed_min_max_statistic_collector(
use_abs_max: bool,
- reducers_axes: Tuple[int, ...],
- aggregators_axes: Tuple[int, ...],
+ reduction_axes: Tuple[int, ...],
+ aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
use_means_of_mins: bool,
use_means_of_maxs: bool,
@@ -284,8 +284,8 @@ def get_mixed_min_max_statistic_collector(
Mixed min max statistic collector builder.
:param use_abs_max: Whether to use abs max reducer or max reducer.
- :param reducers_axes: Axes to use in reduction functions.
- :param aggregators_axes: Axes to use in aggregation functions.
+ :param reduction_axes: Axes to use in reduction functions.
+ :param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param use_means_of_mins: Whether to use mean or min aggregator for minimum statistic branch.
:param use_means_of_maxs: Whether to use mean or max aggregator for maximum statistic branch.
@@ -295,20 +295,20 @@ def get_mixed_min_max_statistic_collector(
:return: Mixed min max statistic collector.
"""
tensor_collector = TensorCollector(partial(PTMinMaxTensorStatistic, target_shape=scale_shape))
- min_reducer = PTMinReducer(reducers_axes)
+ min_reducer = PTMinReducer(reduction_axes)

kwargs = {
"tensor_processor": PTNNCFCollectorTensorProcessor,
"num_samples": num_samples,
- "aggregation_axes": aggregators_axes,
+ "aggregation_axes": aggregation_axes,
"window_size": window_size,
}
min_aggregator_cls = MeanAggregator if use_means_of_mins else MinAggregator
min_aggregator = min_aggregator_cls(**kwargs)
tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MIN_STAT, min_reducer, min_aggregator)

max_reducer_cls = PTAbsMaxReducer if use_abs_max else PTMaxReducer
- max_reducer = max_reducer_cls(reducers_axes)
+ max_reducer = max_reducer_cls(reduction_axes)
max_aggregator_cls = MeanAggregator if use_means_of_maxs else MaxAggregator
max_aggregator = max_aggregator_cls(**kwargs)
tensor_collector.register_statistic_branch(PTMinMaxTensorStatistic.MAX_STAT, max_reducer, max_aggregator)
@@ -317,17 +317,17 @@ def get_mixed_min_max_statistic_collector(


def get_median_mad_statistic_collector(
- reducers_axes: Tuple[int, ...],
- aggregators_axes: Tuple[int, ...],
+ reduction_axes: Tuple[int, ...],
+ aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
num_samples: int,
window_size: Optional[int] = None,
) -> TensorCollector:
"""
Median Absolute Deviation statistic collector builder.
- :param reducers_axes: Axes to use in reduction functions.
- :param aggregators_axes: Axes to use in aggregation functions.
+ :param reduction_axes: Axes to use in reduction functions.
+ :param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param num_samples: Maximum number of samples to collect.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
@@ -338,8 +338,8 @@ def get_median_mad_statistic_collector(
return _get_collection_without_reduction(
MedianAbsoluteDeviationAggregator,
PTMedianMADTensorStatistic,
- reducers_axes=reducers_axes,
- aggregators_axes=aggregators_axes,
+ reduction_axes=reduction_axes,
+ aggregation_axes=aggregation_axes,
scale_shape=scale_shape,
num_samples=num_samples,
window_size=window_size,
@@ -348,8 +348,8 @@

def get_percentile_tensor_collector(
percentiles_to_collect: Tuple[int, ...],
- reducers_axes: Tuple[int, ...],
- aggregators_axes: Tuple[int, ...],
+ reduction_axes: Tuple[int, ...],
+ aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
num_samples: int,
window_size: Optional[int] = None,
@@ -358,8 +358,8 @@ def get_percentile_tensor_collector(
Percentile statistic collector builder.
:param percentiles_to_collect: Percetiles to use on aggregation phase.
- :param reducers_axes: Axes to use in reduction functions.
- :param aggregators_axes: Axes to use in aggregation functions.
+ :param reduction_axes: Axes to use in reduction functions.
+ :param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param num_samples: Maximum number of samples to collect.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
@@ -369,8 +369,8 @@ def get_percentile_tensor_collector(
return _get_collection_without_reduction(
partial(percentileAggregator, percentiles_to_collect=percentiles_to_collect),
PTPercentileTensorStatistic,
- reducers_axes=reducers_axes,
- aggregators_axes=aggregators_axes,
+ reduction_axes=reduction_axes,
+ aggregation_axes=aggregation_axes,
scale_shape=scale_shape,
num_samples=num_samples,
window_size=window_size,
@@ -380,8 +380,8 @@
def _get_collection_without_reduction(
aggregator_cls: TensorAggregatorBase,
statistic_cls: TensorAggregatorBase,
- reducers_axes: Tuple[int, ...],
- aggregators_axes: Tuple[int, ...],
+ reduction_axes: Tuple[int, ...],
+ aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
num_samples: int,
window_size: Optional[int] = None,
@@ -391,8 +391,8 @@ def _get_collection_without_reduction(
:param aggregator_cls: Aggregator class to build the tensor collector.
:param aggregator_cls: Statistic class to build the tensor collector.
- :param reducers_axes: Axes to use in reduction functions.
- :param aggregators_axes: Axes to use in aggregation functions.
+ :param reduction_axes: Axes to use in reduction functions.
+ :param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param num_samples: Maximum number of samples to collect.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
@@ -401,7 +401,7 @@ def _get_collection_without_reduction(
"""
tensor_collector = TensorCollector(partial(statistic_cls, target_shape=scale_shape))
reducer = PTNoopReducer()
- aggregation_axes = list(set(list(aggregators_axes) + [dim + 1 for dim in reducers_axes]))
+ aggregation_axes = list(set(list(aggregation_axes) + [dim + 1 for dim in reduction_axes]))
aggregator = aggregator_cls(
PTNNCFCollectorTensorProcessor,
aggregation_axes=aggregation_axes,
@@ -417,8 +417,8 @@

def get_mean_percentile_statistic_collector(
percentiles_to_collect: Tuple[int, ...],
- reducers_axes: Tuple[int, ...],
- aggregators_axes: Tuple[int, ...],
+ reduction_axes: Tuple[int, ...],
+ aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
num_samples: int,
window_size: Optional[int] = None,
@@ -427,8 +427,8 @@ def get_mean_percentile_statistic_collector(
Mean percentile statistic collector builder.
:param percentiles_to_collect: Percetiles to use on reduction phase.
- :param reducers_axes: Axes to use in reduction functions.
- :param aggregators_axes: Axes to use in aggregation functions.
+ :param reduction_axes: Axes to use in reduction functions.
+ :param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param num_samples: Maximum number of samples to collect.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
@@ -437,11 +437,11 @@ def get_mean_percentile_statistic_collector(
"""
tensor_collector = TensorCollector(partial(PTPercentileTensorStatistic, target_shape=scale_shape))
quantiles_to_collect = np.true_divide(percentiles_to_collect, 100)
- reducer = PTQuantileReducer(reduction_axes=reducers_axes, quantile=quantiles_to_collect)
+ reducer = PTQuantileReducer(reduction_axes=reduction_axes, quantile=quantiles_to_collect)
for output_port_id, p in enumerate(percentiles_to_collect):
aggregator = MeanAggregator(
PTNNCFCollectorTensorProcessor,
- aggregation_axes=aggregators_axes,
+ aggregation_axes=aggregation_axes,
num_samples=num_samples,
window_size=window_size,
)
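Beyond the mechanical renames in this file, the one line worth a closer look is the axis merge in _get_collection_without_reduction above: since PTNoopReducer leaves every tensor unreduced, the requested reduction axes are folded into the aggregation axes instead, shifted by one, presumably to skip the leading dimension along which the aggregator stacks collected samples. A small numeric check of that expression (values are illustrative):

# Reproduces the merge performed in _get_collection_without_reduction.
reduction_axes = (1, 2)     # per-tensor axes that would normally be reduced
aggregation_axes = (0,)     # axis over the stack of collected samples
merged = list(set(list(aggregation_axes) + [dim + 1 for dim in reduction_axes]))
print(sorted(merged))       # [0, 2, 3]: aggregate over samples and the shifted tensor axes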
8 changes: 4 additions & 4 deletions tests/torch/tensor_statistics/test_tensor_statistics.py
@@ -132,8 +132,8 @@ def test_collected_statistics_with_shape_convert(
collector_obj = collector(
scale_shape=scale_shape,
use_abs_max=True,
- reducers_axes=reducer_axes,
- aggregators_axes=(0,),
+ reduction_axes=reducer_axes,
+ aggregation_axes=(0,),
num_samples=None,
)
for input_ in TestCollectedStatistics.REF_INPUTS:
@@ -246,8 +246,8 @@ def test_collected_statistics(

collector_obj = collector(
scale_shape=reduction_shape,
- reducers_axes=reducer_axes,
- aggregators_axes=(0,),
+ reduction_axes=reducer_axes,
+ aggregation_axes=(0,),
num_samples=None,
)
for input_ in TestCollectedStatistics.REF_INPUTS:

