Refactor TensorCollector to check StatisticContainer before building
daniil-lyakhov committed Sep 25, 2023
1 parent 3f44e98 commit 786e6d6
Showing 17 changed files with 368 additions and 167 deletions.
45 changes: 43 additions & 2 deletions nncf/experimental/common/tensor_statistics/collectors.py
@@ -19,6 +19,11 @@
from nncf.common.tensor_statistics.collectors import NNCFCollectorTensorProcessor
from nncf.common.tensor_statistics.collectors import NNCFTensor
from nncf.common.tensor_statistics.collectors import ReductionAxes
from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.common.tensor_statistics.statistics import PercentileTensorStatistic
from nncf.common.tensor_statistics.statistics import RawTensorStatistic
from nncf.common.tensor_statistics.statistics import TensorStatistic
from nncf.quantization.advanced_parameters import AggregatorType

@@ -334,7 +339,7 @@ def get_statistics(self) -> Union[TensorStatistic, Dict[str, Any]]:

if not self._stat_container:
return kwargs
return self._stat_container(kwargs)
return self._build_statistic_container(self._stat_container, kwargs)

def get_inplace_fn_info(self) -> List[Tuple[Any, int]]:
"""
@@ -392,6 +397,39 @@ def get_tensor_collector_inputs(
target_inputs[reducer] = [outputs[name] for name in names]
return target_inputs

@staticmethod
def _build_statistic_container(statistic_container_cls: TensorStatistic, kwargs: Dict[Any, Any]):
if issubclass(statistic_container_cls, MinMaxTensorStatistic):
return statistic_container_cls(
min_values=kwargs[MinMaxTensorStatistic.MIN_STAT], max_values=kwargs[MinMaxTensorStatistic.MAX_STAT]
)
if issubclass(statistic_container_cls, MeanTensorStatistic):
return statistic_container_cls(
mean_values=kwargs[MeanTensorStatistic.MEAN_STAT], shape=kwargs[MeanTensorStatistic.SHAPE_STAT]
)
if issubclass(statistic_container_cls, RawTensorStatistic):
return statistic_container_cls(values=kwargs[RawTensorStatistic.VALUES_STATS])
if issubclass(statistic_container_cls, MedianMADTensorStatistic):
return statistic_container_cls(
median_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][
MedianMADTensorStatistic.MEDIAN_VALUES_STAT
],
mad_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][
MedianMADTensorStatistic.MAD_VALUES_STAT
],
)
if issubclass(statistic_container_cls, PercentileTensorStatistic):
if PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY in kwargs:
percentile_vs_values_dict = kwargs[PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY]
else:
percentile_vs_values_dict = {}
for (_, percentile), value in kwargs.items():
percentile_vs_values_dict[percentile] = value
return statistic_container_cls(percentile_vs_values_dict=percentile_vs_values_dict)
raise RuntimeError(
f"Statistic collector class {statistic_container_cls} is not supported by the TensorCollector class."
)


class MergedTensorCollector(TensorCollector):
"""
@@ -708,7 +746,10 @@ def _aggregate_impl(self) -> Dict[str, NNCFTensor]:
)
if not self._keepdims:
median_per_ch = self._tensor_processor.squeeze(median_per_ch, self._aggregation_axes)
return {"median_values": median_per_ch.tensor, "mad_values": mad_values.tensor}
return {
MedianMADTensorStatistic.MEDIAN_VALUES_STAT: median_per_ch.tensor,
MedianMADTensorStatistic.MAD_VALUES_STAT: mad_values.tensor,
}


class PercentileAggregator(TensorAggregatorBase):
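
The new `_build_statistic_container` centralizes the unpacking that every backend container previously did in its own `__init__`: aggregators register their outputs under the statistic's `*_STAT` keys (for median/MAD now the `MEDIAN_VALUES_STAT`/`MAD_VALUES_STAT` constants rather than bare strings), and the collector dispatches on the container class. A minimal sketch of the new path, with illustrative numpy values:

```python
import numpy as np

from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector

# Aggregators fill kwargs keyed by the statistic's *_STAT constants;
# the collector now unpacks them into keyword arguments exactly once.
kwargs = {
    MinMaxTensorStatistic.MIN_STAT: np.zeros(3),
    MinMaxTensorStatistic.MAX_STAT: np.ones(3),
}
stat = TensorCollector._build_statistic_container(MinMaxTensorStatistic, kwargs)
assert np.array_equal(stat.min_values, np.zeros(3))
```

Percentile containers are built either from the single `TENSOR_STATISTIC_OUTPUT_KEY` entry or from per-percentile tuple keys, whichever form the aggregators produced.
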
20 changes: 7 additions & 13 deletions nncf/onnx/statistics/collectors.py
@@ -151,10 +151,8 @@ def _register_input(self, x: ONNXNNCFTensor):

def _get_statistics(self) -> ONNXMinMaxTensorStatistic:
return ONNXMinMaxTensorStatistic(
{
ONNXMinMaxTensorStatistic.MIN_STAT: self._min_values.tensor,
ONNXMinMaxTensorStatistic.MAX_STAT: self._max_values.tensor,
}
min_values=self._min_values.tensor,
max_values=self._max_values.tensor,
)


@@ -168,10 +166,8 @@ def _register_input(self, x: ONNXNNCFTensor):

def _get_statistics(self) -> ONNXMinMaxTensorStatistic:
return ONNXMinMaxTensorStatistic(
{
ONNXMinMaxTensorStatistic.MIN_STAT: self._min_aggregate().tensor,
ONNXMinMaxTensorStatistic.MAX_STAT: self._max_aggregate().tensor,
}
min_values=self._min_aggregate().tensor,
max_values=self._max_aggregate().tensor,
)


@@ -185,10 +181,8 @@ def _register_input(self, x: ONNXNNCFTensor):

def _get_statistics(self) -> ONNXMeanTensorStatistic:
return ONNXMeanTensorStatistic(
{
ONNXMeanTensorStatistic.MEAN_STAT: self._mean_aggregate().tensor,
ONNXMeanTensorStatistic.SHAPE_STAT: self._shape(),
}
mean_values=self._mean_aggregate().tensor,
shape=self._shape(),
)


@@ -201,4 +195,4 @@ def _register_input(self, x: ONNXNNCFTensor):
self._register_input_common(x)

def _get_statistics(self) -> ONNXRawTensorStatistic:
return ONNXRawTensorStatistic({ONNXRawTensorStatistic.VALUES_STATS: self._all_values})
return ONNXRawTensorStatistic(self._all_values)
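
The ONNX collectors previously wrapped their outputs in a dict keyed by the `*_STAT` constants and relied on each container's `__init__` to unpack it; they now pass keyword arguments straight to the common base classes. A before/after sketch with illustrative values:

```python
import numpy as np

from nncf.onnx.statistics.statistics import ONNXMinMaxTensorStatistic

min_values, max_values = np.zeros(3), np.ones(3)

# Before: a raw collector-output dict, unpacked by the subclass __init__.
# ONNXMinMaxTensorStatistic({
#     ONNXMinMaxTensorStatistic.MIN_STAT: min_values,
#     ONNXMinMaxTensorStatistic.MAX_STAT: max_values,
# })

# After: plain keyword arguments handled by MinMaxTensorStatistic itself.
stat = ONNXMinMaxTensorStatistic(min_values=min_values, max_values=max_values)
assert np.array_equal(stat.max_values, max_values)
```
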
9 changes: 0 additions & 9 deletions nncf/onnx/statistics/statistics.py
@@ -17,27 +17,18 @@


class ONNXMinMaxTensorStatistic(MinMaxTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT])

@staticmethod
def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool:
return bool(np.allclose(tensor1, tensor2, rtol=rtol))


class ONNXMeanTensorStatistic(MeanTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT])

@staticmethod
def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool:
return bool(np.all(tensor, rtol=rtol))


class ONNXRawTensorStatistic(RawTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.VALUES_STATS])

@staticmethod
def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool:
return bool(np.all(tensor, rtol=rtol))
9 changes: 0 additions & 9 deletions nncf/openvino/statistics/statistics.py
@@ -17,27 +17,18 @@


class OVMinMaxTensorStatistic(MinMaxTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.MIN_STAT], tensor_collector_output[self.MAX_STAT])

@staticmethod
def tensor_eq(tensor1: np.ndarray, tensor2: np.ndarray, rtol=1e-6) -> bool:
return bool(np.allclose(tensor1, tensor2, rtol=rtol))


class OVMeanTensorStatistic(MeanTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.MEAN_STAT], tensor_collector_output[self.SHAPE_STAT])

@staticmethod
def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool:
return bool(np.all(tensor, rtol=rtol))


class OVRawTensorStatistic(RawTensorStatistic):
def __init__(self, tensor_collector_output):
super().__init__(tensor_collector_output[self.VALUES_STATS])

@staticmethod
def tensor_eq(tensor: np.ndarray, rtol=1e-6) -> bool:
return bool(np.all(tensor, rtol=rtol))
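
With the dict-unpacking `__init__` overrides deleted from both the ONNX and OpenVINO variants above, the two backends construct containers through the shared base-class signatures, leaving only the backend-specific `tensor_eq` helpers. A sketch, assuming both backends are available:

```python
import numpy as np

from nncf.onnx.statistics.statistics import ONNXMeanTensorStatistic
from nncf.openvino.statistics.statistics import OVMeanTensorStatistic

# Both subclasses now inherit MeanTensorStatistic.__init__(mean_values, shape).
for cls in (ONNXMeanTensorStatistic, OVMeanTensorStatistic):
    stat = cls(mean_values=np.zeros((1, 3)), shape=(1, 3))
    assert stat.shape == (1, 3)
```
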
4 changes: 1 addition & 3 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
@@ -120,9 +120,7 @@ def unify_statistics(statistics: List[ONNXMinMaxTensorStatistic]) -> ONNXMinMaxTensorStatistic:
min_values.append(np.array(statistic.min_values).flatten())
max_values = np.max(max_values, axis=0)
min_values = np.min(min_values, axis=0)
return ONNXMinMaxTensorStatistic(
{ONNXMinMaxTensorStatistic.MIN_STAT: min_values, ONNXMinMaxTensorStatistic.MAX_STAT: max_values}
)
return ONNXMinMaxTensorStatistic(min_values=min_values, max_values=max_values)

@staticmethod
def _get_input_edges_mapping(nncf_graph: NNCFGraph):
4 changes: 1 addition & 3 deletions nncf/quantization/algorithms/min_max/openvino_backend.py
@@ -119,9 +119,7 @@ def unify_statistics(statistics: List[OVMinMaxTensorStatistic]) -> OVMinMaxTensorStatistic:
min_values.append(np.array(statistic.min_values).flatten())
max_values = np.max(max_values, axis=0)
min_values = np.min(min_values, axis=0)
return OVMinMaxTensorStatistic(
{OVMinMaxTensorStatistic.MIN_STAT: min_values, OVMinMaxTensorStatistic.MAX_STAT: max_values}
)
return OVMinMaxTensorStatistic(min_values=min_values, max_values=max_values)

@staticmethod
def _get_reduction_shape_and_use_abs_max(
4 changes: 1 addition & 3 deletions nncf/quantization/algorithms/min_max/torch_backend.py
@@ -144,9 +144,7 @@ def unify_statistics(statistics: List[PTMinMaxTensorStatistic]) -> PTMinMaxTensorStatistic:
min_values.append(torch.tensor(statistic.min_values).flatten())
max_values = torch.max(torch.tensor(max_values))
min_values = torch.min(torch.tensor(min_values))
return PTMinMaxTensorStatistic(
{PTMinMaxTensorStatistic.MIN_STAT: min_values, PTMinMaxTensorStatistic.MAX_STAT: max_values}
)
return PTMinMaxTensorStatistic(min_values=min_values, max_values=max_values)

@staticmethod
def get_statistic_collector(
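
All three `unify_statistics` implementations switch to the same keyword form. A sketch of the ONNX variant, assuming the backend class is `ONNXMinMaxAlgoBackend` (values illustrative):

```python
import numpy as np

from nncf.onnx.statistics.statistics import ONNXMinMaxTensorStatistic
from nncf.quantization.algorithms.min_max.onnx_backend import ONNXMinMaxAlgoBackend

stats = [
    ONNXMinMaxTensorStatistic(min_values=np.array([0.0, -1.0]), max_values=np.array([1.0, 2.0])),
    ONNXMinMaxTensorStatistic(min_values=np.array([-2.0, 0.0]), max_values=np.array([0.5, 3.0])),
]
unified = ONNXMinMaxAlgoBackend.unify_statistics(stats)
# Element-wise min over the mins and max over the maxes:
assert np.array_equal(unified.min_values, np.array([-2.0, -1.0]))
assert np.array_equal(unified.max_values, np.array([1.0, 3.0]))
```
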
3 changes: 3 additions & 0 deletions nncf/torch/tensor.py
@@ -30,3 +30,6 @@ def __init__(self, tensor: torch.tensor):
@property
def device(self) -> torch.device:
return self._tensor.device

def is_empty(self) -> bool:
return self.tensor.size == 0
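
One caveat: `torch.Tensor.size` is a method, so `self.tensor.size == 0` compares a bound method with 0 and always evaluates to False for torch tensors. The apparent intent is an element-count check; a sketch of that semantics, not the committed code:

```python
import torch

def is_empty(tensor: torch.Tensor) -> bool:
    # numel() returns the number of elements, which is 0 for an empty tensor.
    return tensor.numel() == 0

assert is_empty(torch.empty(0))
assert not is_empty(torch.zeros(2, 3))
```
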
63 changes: 52 additions & 11 deletions nncf/torch/tensor_statistics/collectors.py
@@ -10,7 +10,7 @@
# limitations under the License.

from functools import partial
from typing import Deque, List, Optional, Tuple, Union
from typing import Deque, List, Optional, Tuple, Type, Union

import numpy as np
import torch
@@ -229,6 +229,44 @@ class PTMeanPerChanelReducer(PTReducerMixIn, MeanPerChReducer):
pass


def _reshape_all(targets: Tuple[torch.Tensor, ...], target_shape: Tuple[int, ...]):
return map(lambda stat: torch.reshape(stat, target_shape), targets)


def _get_wrapped_min_max_tensor_statistic(target_shape: Tuple[int, ...]) -> Type[PTMinMaxTensorStatistic]:
"""
Returns PTMinMaxTensorStatistic type but all statistics are reshaped to target_shape.
:param target_shape: Target shape of the tensor statistic
:return: PTMinMaxTensorStatistic type but all statistics are reshaped to target_shape.
"""

class WrappedPTMinMaxTensorStatistic(PTMinMaxTensorStatistic):
def __init__(self, min_values, max_values):
min_values, max_values = _reshape_all((min_values, max_values), target_shape)
super().__init__(min_values, max_values)

return WrappedPTMinMaxTensorStatistic


def _get_wrapped_percentile_tensor_statistic(target_shape: Tuple[int, ...]) -> Type[PTPercentileTensorStatistic]:
"""
Returns PTPercentileTensorStatistic type but all statistics are reshaped to target_shape.
:param target_shape: Target shape of the tensor statistic
:return: PTPercentileTensorStatistic type but all statistics are reshaped to target_shape.
"""

class WrappedPTPercentileTensorStatistic(PTPercentileTensorStatistic):
def __init__(self, percentile_vs_values_dict):
reshaped_percentiles = {}
for k, v in percentile_vs_values_dict.items():
reshaped_percentiles[k] = torch.reshape(v, target_shape)
super().__init__(reshaped_percentiles)

return WrappedPTPercentileTensorStatistic


def get_min_max_statistic_collector(
use_abs_max: bool,
reduction_axes: Tuple[int, ...],
@@ -246,7 +284,8 @@ def get_min_max_statistic_collector(
:param num_samples: Maximum number of samples to collect.
:return: Min max statistic collector.
"""
tensor_collector = TensorCollector(partial(PTMinMaxTensorStatistic, target_shape=scale_shape))

tensor_collector = TensorCollector(_get_wrapped_min_max_tensor_statistic(target_shape=scale_shape))

aggregator_kwargs = {
"tensor_processor": PTNNCFCollectorTensorProcessor,
@@ -288,7 +327,7 @@ def get_mixed_min_max_statistic_collector(
Aggregates all available collected statistics in case parameter is None.
:return: Mixed min max statistic collector.
"""
tensor_collector = TensorCollector(partial(PTMinMaxTensorStatistic, target_shape=scale_shape))
tensor_collector = TensorCollector(_get_wrapped_min_max_tensor_statistic(target_shape=scale_shape))
min_reducer = PTMinReducer(reduction_axes)

kwargs = {
@@ -329,12 +368,17 @@ def get_median_mad_statistic_collector(
:return: Median Absolute Deviation statistic collector.
"""

class WrappedPTMedianMADTensorStatistic(PTMedianMADTensorStatistic):
def __init__(self, median_values, mad_values):
median_values, mad_values = _reshape_all((median_values, mad_values), scale_shape)
super().__init__(median_values, mad_values)

return _get_collection_without_reduction(
MedianAbsoluteDeviationAggregator,
PTMedianMADTensorStatistic,
WrappedPTMedianMADTensorStatistic,
reduction_axes=reduction_axes,
aggregation_axes=aggregation_axes,
scale_shape=scale_shape,
num_samples=num_samples,
window_size=window_size,
)
@@ -362,10 +406,9 @@ def get_percentile_tensor_collector(
"""
return _get_collection_without_reduction(
partial(PercentileAggregator, percentiles_to_collect=percentiles_to_collect),
PTPercentileTensorStatistic,
_get_wrapped_percentile_tensor_statistic(target_shape=scale_shape),
reduction_axes=reduction_axes,
aggregation_axes=aggregation_axes,
scale_shape=scale_shape,
num_samples=num_samples,
window_size=window_size,
)
@@ -376,7 +419,6 @@ def _get_collection_without_reduction(
statistic_cls: TensorAggregatorBase,
reduction_axes: Tuple[int, ...],
aggregation_axes: Tuple[int, ...],
scale_shape: Tuple[int, ...],
num_samples: int,
window_size: Optional[int] = None,
) -> TensorCollector:
@@ -387,13 +429,12 @@
:param aggregator_cls: Statistic class to build the tensor collector.
:param reduction_axes: Axes to use in reduction functions.
:param aggregation_axes: Axes to use in aggregation functions.
:param scale_shape: Target shape for collected statistics.
:param num_samples: Maximum number of samples to collect.
:param window_size: Number of samples from the end of the list of collected samples to aggregate.
Aggregates all available collected statistics in case parameter is None.
:return: Target statistic collector.
"""
tensor_collector = TensorCollector(partial(statistic_cls, target_shape=scale_shape))
tensor_collector = TensorCollector(statistic_cls)
reducer = PTNoopReducer()
aggregation_axes = list(set(list(aggregation_axes) + [dim + 1 for dim in reduction_axes]))
aggregator = aggregator_cls(
@@ -429,7 +470,7 @@ def get_mean_percentile_statistic_collector(
Aggregates all available collected statistics in case parameter is None.
:return: Mean percentile statistic collector.
"""
tensor_collector = TensorCollector(partial(PTPercentileTensorStatistic, target_shape=scale_shape))
tensor_collector = TensorCollector(_get_wrapped_percentile_tensor_statistic(target_shape=scale_shape))
quantiles_to_collect = np.true_divide(percentiles_to_collect, 100)
reducer = PTQuantileReducer(reduction_axes=reduction_axes, quantile=quantiles_to_collect)
for output_port_id, p in enumerate(percentiles_to_collect):
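
The `WrappedPT...TensorStatistic` types replace the old `partial(PT...TensorStatistic, target_shape=scale_shape)` pattern: a `functools.partial` is not a class, so the collector's new `issubclass`-based dispatch cannot see through it, while a dynamically created subclass keeps the subclass relationship and still bakes the reshape into construction. A self-contained sketch of the pattern with stand-in names:

```python
from typing import Tuple, Type

import torch

class MinMaxStat:  # stand-in for PTMinMaxTensorStatistic
    def __init__(self, min_values: torch.Tensor, max_values: torch.Tensor):
        self.min_values, self.max_values = min_values, max_values

def wrap_with_reshape(target_shape: Tuple[int, ...]) -> Type[MinMaxStat]:
    # A subclass, not a functools.partial: issubclass() still works, which is
    # what TensorCollector._build_statistic_container relies on for dispatch.
    class Wrapped(MinMaxStat):
        def __init__(self, min_values: torch.Tensor, max_values: torch.Tensor):
            super().__init__(
                min_values.reshape(target_shape), max_values.reshape(target_shape)
            )
    return Wrapped

wrapped_cls = wrap_with_reshape((1, 3, 1, 1))
assert issubclass(wrapped_cls, MinMaxStat)
stat = wrapped_cls(torch.zeros(3), torch.ones(3))
assert tuple(stat.min_values.shape) == (1, 3, 1, 1)
```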