Skip to content

Commit

Permalink
[Torch] Experimental tensor collector is using for statistic collection
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 27, 2023
1 parent b16d4a1 commit bf9c611
Show file tree
Hide file tree
Showing 59 changed files with 2,241 additions and 842 deletions.
122 changes: 86 additions & 36 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from abc import ABC
from abc import abstractmethod
from collections import deque
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np

Expand All @@ -21,14 +21,13 @@
from nncf.common.tensor import TensorType
from nncf.common.tensor_statistics.reduction import get_per_channel_history

ReductionShape = Tuple[int]
MaskedReduceFN = Callable[[NNCFTensor, Union[int, tuple, list], NNCFTensor, bool], NNCFTensor]
ReductionAxes = Tuple[int]


class TensorStatisticCollectorBase(ABC):
"""Collector estimate statistics at the quantization point based on the provided reduction shape."""

def __init__(self, reduction_shape: Optional[ReductionShape] = None, num_samples: Optional[int] = None):
def __init__(self, reduction_shape: Optional[ReductionAxes] = None, num_samples: Optional[int] = None):
"""
Initializes Tensor Statistic Collector
Expand Down Expand Up @@ -101,7 +100,7 @@ class OfflineTensorStatisticCollector(TensorStatisticCollectorBase):
"""Collects statistics in offline regime by storing the data and aggregating it afterwards."""

def __init__(
self, reduction_shape: Optional[ReductionShape] = None, num_samples: int = None, window_size: int = None
self, reduction_shape: Optional[ReductionAxes] = None, num_samples: int = None, window_size: int = None
):
super().__init__(reduction_shape, num_samples)
self._samples = deque(maxlen=window_size)
Expand Down Expand Up @@ -199,9 +198,9 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF
:return: Reduced NNCFTensor.
"""

@staticmethod
@classmethod
@abstractmethod
def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
"""
Computes the masked mean of elements across given dimensions of NNCFTensor.
Expand All @@ -214,9 +213,11 @@ def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor,
:return: Reduced NNCFTensor.
"""

@staticmethod
@classmethod
@abstractmethod
def masked_median(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
def masked_median(
cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
"""
Computes the masked median of elements across given dimensions of NNCFTensor.
Expand Down Expand Up @@ -251,6 +252,16 @@ def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
:return: List of NNCFTensor.
"""

@staticmethod
@abstractmethod
def squeeze(x: NNCFTensor, dim: Optional[Union[int, Tuple[int, ...]]] = None) -> NNCFTensor:
"""
Remove axes of length one from x.
:param x: NNCFTensor to squeeze.
:param axis: Selects a subset of the entries of length one in the shape.
"""

@staticmethod
@abstractmethod
def sum(tensor: NNCFTensor) -> TensorElementsType:
Expand All @@ -267,15 +278,36 @@ def quantile(
tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False
) -> List[TensorElementsType]:
"""
Compute the quantile-th percentile(s) of the data along the specified axis.
Compute the quantile(s) of the data along the specified axis.
:param tensor: Given NNCFTensor.
:params quantile: Percentile or sequence of percentiles to compute, which must be between
:params quantile: Quantile or sequence of quantiles to compute, which must be between
0 and 1 inclusive.
:param axis: Axis or axes along which the quantiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:returns: List of the quantile(s) of the tensor elements.
"""

@classmethod
@abstractmethod
def percentile(
cls,
tensor: NNCFTensor,
percentile: Union[float, List[float]],
axis: Union[int, tuple, list],
keepdims: bool = False,
) -> List[TensorElementsType]:
"""
Compute the percentile(s) of the data along the specified axis.
:param tensor: Given NNCFTensor.
:params percentile: percentile or sequence of percentiles to compute, which must be between
0 and 100 inclusive.
:param axis: Axis or axes along which the percentiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:returns: List of the quantile-th percentile(s) of the tensor elements.
:returns: List of the percentile(s) of the tensor elements.
"""

@staticmethod
Expand All @@ -289,27 +321,47 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
:return: Reduced NNCFTensor.
"""

@classmethod
@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Computes the element-wise logical OR of the given input tensors.
Zeros are treated as False and nonzeros are treated as True.
:param input_: The input tensor.
:param other: The tensor to compute or with.
:return: Result of elementwise or operation between input_ and other tensor.
"""

@staticmethod
def less(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Return the truth value of (x1 < x2) element-wise.
:param input_: The input tensor.
:param other: The tensor to compute or with.
:return: Result of elementwise less operation between input_ and other tensor.
"""

@staticmethod
@abstractmethod
def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha: float = 0.01) -> NNCFTensor:
def sub(a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
"""
Returns result of a substract b operation.
"""
Computes quantiles [alpha, 1 - alpha] on given tensor, masks all elements that
are smaller that alpha and bigger than 1 - alpha quantile and applies
given masked reduction function fn.

:param tensor: Given NNCFTensor.
:param fn: Masked reduce operation from the same NNCFCollectorTensorProcessor class.
:param axis: Axis along which the reduction function is computed.
:params alpha: Minimal percentile to filter outliers outside the range
[quantile(alpha), quantile(1 - alpha)]. Must be between 0 and 1. inclusive.
:returns: Result of given masked reduction function on filtered from outliers NNCFTensor.
@classmethod
@abstractmethod
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
"""
Returns binary mask from the input x which equal true for all elemets that are smaller than
corresponding machine epsilon.
"""


class MinMaxStatisticCollector(OnlineTensorStatisticCollector):
"""Collector estimates min of minimum values and max of maximum values."""

def __init__(self, use_abs_max: bool, reduction_shape: ReductionShape, num_samples: int = None):
def __init__(self, use_abs_max: bool, reduction_shape: ReductionAxes, num_samples: int = None):
super().__init__(reduction_shape, num_samples)
self._use_abs_max = use_abs_max
self._tensor_processor = self._get_processor()
Expand Down Expand Up @@ -353,7 +405,7 @@ def __init__(
self,
use_per_sample_stats: bool,
use_abs_max: bool,
reduction_shape: ReductionShape,
reduction_shape: ReductionAxes,
num_samples: int = None,
window_size: int = None,
):
Expand Down Expand Up @@ -407,7 +459,7 @@ def __init__(
use_abs_max: bool,
use_means_of_mins: bool,
use_means_of_maxs: bool,
reduction_shape: ReductionShape,
reduction_shape: ReductionAxes,
num_samples: int = None,
window_size: int = None,
):
Expand Down Expand Up @@ -447,17 +499,15 @@ class MeanStatisticCollector(OfflineTensorStatisticCollector):
Collector that aggregates statistics as mean along a pre-assigned axis.
"""

def __init__(
self, reduction_shape: ReductionShape, num_samples: Optional[int] = None, window_size: Optional[int] = None
) -> None:
def __init__(self, channel_axis: int, num_samples: Optional[int] = None, window_size: Optional[int] = None) -> None:
"""
:param reduction_shape: The shape for the reduction while statistics collection.
For the MeanStatisticCollector this parameter contains the main axis.
:param channel_axis: The main axis for the reduction while statistics collection.
:param num_samples: Optional parameter for statistic collection that regulates
the number of samples that will be processed.
:param window_size: Optional maximum length for the statistic collection
"""
super().__init__(reduction_shape, num_samples)
super().__init__(num_samples=num_samples)
self._channel_axis = channel_axis
self._tensor_processor = self._get_processor()
self._all_values = deque(maxlen=window_size)
self._all_shapes = deque(maxlen=window_size)
Expand All @@ -468,10 +518,10 @@ def _get_processor():
pass

def _register_input_common(self, x: NNCFTensor):
if self._reduction_shape == 0:
if self._channel_axis == 0:
self._all_values.append(self._tensor_processor.batch_mean(x))
else:
self._all_values.append(self._tensor_processor.mean_per_channel(x, self._reduction_shape))
self._all_values.append(self._tensor_processor.mean_per_channel(x, self._channel_axis))
self._all_shapes.append(x.shape)

def _reset(self):
Expand Down Expand Up @@ -536,7 +586,7 @@ class PercentileStatisticCollector(OfflineTensorStatisticCollector):
def __init__(
self,
percentiles_to_collect: List[float],
reduction_shape: Optional[ReductionShape] = None,
reduction_shape: Optional[ReductionAxes] = None,
num_samples: int = None,
window_size: int = None,
):
Expand All @@ -561,7 +611,7 @@ class MeanPercentileStatisticCollector(OfflineTensorStatisticCollector):
def __init__(
self,
percentiles_to_collect: List[float],
reduction_shape: Optional[ReductionShape] = None,
reduction_shape: Optional[ReductionAxes] = None,
num_samples: int = None,
window_size: int = None,
):
Expand Down
7 changes: 7 additions & 0 deletions nncf/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class TensorStatistic(ABC):
"""Base class that stores statistic data"""

TENSOR_STATISTIC_OUTPUT_KEY = "tensor_statistic_output"

@staticmethod
@abstractmethod
def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol=1e-6) -> bool:
Expand Down Expand Up @@ -63,6 +65,9 @@ def __eq__(self, other: "MeanTensorStatistic") -> bool:


class MedianMADTensorStatistic(TensorStatistic):
MEDIAN_VALUES_STAT = "median_values"
MAD_VALUES_STAT = "mad_values"

def __init__(self, median_values, mad_values):
self.median_values = median_values
self.mad_values = mad_values
Expand All @@ -74,6 +79,8 @@ def __eq__(self, other: "MedianMADTensorStatistic") -> bool:


class PercentileTensorStatistic(TensorStatistic):
PERCENTILE_VS_VALUE_DICT = "percentile_vs_values_dict"

def __init__(self, percentile_vs_values_dict):
self.percentile_vs_values_dict = percentile_vs_values_dict

Expand Down
Loading

0 comments on commit bf9c611

Please sign in to comment.