Skip to content

Commit

Permalink
[Torch] Experimental tensor collector is using for statistic collection
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 26, 2023
1 parent def8009 commit bdfd115
Show file tree
Hide file tree
Showing 51 changed files with 2,151 additions and 762 deletions.
110 changes: 81 additions & 29 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from abc import ABC
from abc import abstractmethod
from collections import deque
from typing import Callable, List, Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import numpy as np

Expand All @@ -21,14 +21,13 @@
from nncf.common.tensor import TensorType
from nncf.common.tensor_statistics.reduction import get_per_channel_history

ReductionShape = Tuple[int]
MaskedReduceFN = Callable[[NNCFTensor, Union[int, tuple, list], NNCFTensor, bool], NNCFTensor]
ReductionAxes = Tuple[int]


class TensorStatisticCollectorBase(ABC):
"""Collector estimate statistics at the quantization point based on the provided reduction shape."""

def __init__(self, reduction_shape: Optional[ReductionShape] = None, num_samples: Optional[int] = None):
def __init__(self, reduction_shape: Optional[ReductionAxes] = None, num_samples: Optional[int] = None):
"""
Initializes Tensor Statistic Collector
Expand Down Expand Up @@ -101,7 +100,7 @@ class OfflineTensorStatisticCollector(TensorStatisticCollectorBase):
"""Collects statistics in offline regime by storing the data and aggregating it afterwards."""

def __init__(
self, reduction_shape: Optional[ReductionShape] = None, num_samples: int = None, window_size: int = None
self, reduction_shape: Optional[ReductionAxes] = None, num_samples: int = None, window_size: int = None
):
super().__init__(reduction_shape, num_samples)
self._samples = deque(maxlen=window_size)
Expand Down Expand Up @@ -199,9 +198,9 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF
:return: Reduced NNCFTensor.
"""

@staticmethod
@classmethod
@abstractmethod
def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
"""
Computes the masked mean of elements across given dimensions of NNCFTensor.
Expand All @@ -214,9 +213,11 @@ def masked_mean(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor,
:return: Reduced NNCFTensor.
"""

@staticmethod
@classmethod
@abstractmethod
def masked_median(x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
def masked_median(
cls, x: NNCFTensor, axis: Union[int, tuple, list], mask: NNCFTensor, keepdims=False
) -> NNCFTensor:
"""
Computes the masked median of elements across given dimensions of NNCFTensor.
Expand Down Expand Up @@ -251,6 +252,16 @@ def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
:return: List of NNCFTensor.
"""

@staticmethod
@abstractmethod
def squeeze(x: NNCFTensor, dim: Optional[Union[int, Tuple[int, ...]]] = None) -> NNCFTensor:
"""
Remove axes of length one from x.
:param x: NNCFTensor to squeeze.
:param axis: Selects a subset of the entries of length one in the shape.
"""

@staticmethod
@abstractmethod
def sum(tensor: NNCFTensor) -> TensorElementsType:
Expand All @@ -267,15 +278,36 @@ def quantile(
tensor: NNCFTensor, quantile: Union[float, List[float]], axis: Union[int, tuple, list], keepdims: bool = False
) -> List[TensorElementsType]:
"""
Compute the quantile-th percentile(s) of the data along the specified axis.
Compute the quantile(s) of the data along the specified axis.
:param tensor: Given NNCFTensor.
:params quantile: Percentile or sequence of percentiles to compute, which must be between
:params quantile: Quantile or sequence of quantiles to compute, which must be between
0 and 1 inclusive.
:param axis: Axis or axes along which the quantiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:returns: List of the quantile(s) of the tensor elements.
"""

@classmethod
@abstractmethod
def percentile(
cls,
tensor: NNCFTensor,
percentile: Union[float, List[float]],
axis: Union[int, tuple, list],
keepdims: bool = False,
) -> List[TensorElementsType]:
"""
Compute the percentile(s) of the data along the specified axis.
:param tensor: Given NNCFTensor.
:params percentile: percentile or sequence of percentiles to compute, which must be between
0 and 100 inclusive.
:param axis: Axis or axes along which the percentiles are computed.
:param keepdims: If True, the axes which are reduced are left in the result
as dimensions with size one.
:returns: List of the quantile-th percentile(s) of the tensor elements.
:returns: List of the percentile(s) of the tensor elements.
"""

@staticmethod
Expand All @@ -289,27 +321,47 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
:return: Reduced NNCFTensor.
"""

@classmethod
@staticmethod
def logical_or(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Computes the element-wise logical OR of the given input tensors.
Zeros are treated as False and nonzeros are treated as True.
:param input_: The input tensor.
:param other: The tensor to compute or with.
:return: Result of elementwise or operation between input_ and other tensor.
"""

@staticmethod
def less(input_: NNCFTensor, other: NNCFTensor) -> NNCFTensor:
"""
Return the truth value of (x1 < x2) element-wise.
:param input_: The input tensor.
:param other: The tensor to compute or with.
:return: Result of elementwise less operation between input_ and other tensor.
"""

@staticmethod
@abstractmethod
def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha: float = 0.01) -> NNCFTensor:
def sub(a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
"""
Returns result of a substract b operation.
"""
Computes quantiles [alpha, 1 - alpha] on given tensor, masks all elements that
are smaller that alpha and bigger than 1 - alpha quantile and applies
given masked reduction function fn.

:param tensor: Given NNCFTensor.
:param fn: Masked reduce operation from the same NNCFCollectorTensorProcessor class.
:param axis: Axis along which the reduction function is computed.
:params alpha: Minimal percentile to filter outliers outside the range
[quantile(alpha), quantile(1 - alpha)]. Must be between 0 and 1. inclusive.
:returns: Result of given masked reduction function on filtered from outliers NNCFTensor.
@classmethod
@abstractmethod
def zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
"""
Returns binary mask from the input x which equal true for all elemets that are smaller than
corresponding machine epsilon.
"""


class MinMaxStatisticCollector(OnlineTensorStatisticCollector):
"""Collector estimates min of minimum values and max of maximum values."""

def __init__(self, use_abs_max: bool, reduction_shape: ReductionShape, num_samples: int = None):
def __init__(self, use_abs_max: bool, reduction_shape: ReductionAxes, num_samples: int = None):
super().__init__(reduction_shape, num_samples)
self._use_abs_max = use_abs_max
self._tensor_processor = self._get_processor()
Expand Down Expand Up @@ -353,7 +405,7 @@ def __init__(
self,
use_per_sample_stats: bool,
use_abs_max: bool,
reduction_shape: ReductionShape,
reduction_shape: ReductionAxes,
num_samples: int = None,
window_size: int = None,
):
Expand Down Expand Up @@ -407,7 +459,7 @@ def __init__(
use_abs_max: bool,
use_means_of_mins: bool,
use_means_of_maxs: bool,
reduction_shape: ReductionShape,
reduction_shape: ReductionAxes,
num_samples: int = None,
window_size: int = None,
):
Expand Down Expand Up @@ -448,7 +500,7 @@ class MeanStatisticCollector(OfflineTensorStatisticCollector):
"""

def __init__(
self, reduction_shape: ReductionShape, num_samples: Optional[int] = None, window_size: Optional[int] = None
self, reduction_shape: ReductionAxes, num_samples: Optional[int] = None, window_size: Optional[int] = None
) -> None:
"""
:param reduction_shape: The shape for the reduction while statistics collection.
Expand Down Expand Up @@ -536,7 +588,7 @@ class PercentileStatisticCollector(OfflineTensorStatisticCollector):
def __init__(
self,
percentiles_to_collect: List[float],
reduction_shape: Optional[ReductionShape] = None,
reduction_shape: Optional[ReductionAxes] = None,
num_samples: int = None,
window_size: int = None,
):
Expand All @@ -561,7 +613,7 @@ class MeanPercentileStatisticCollector(OfflineTensorStatisticCollector):
def __init__(
self,
percentiles_to_collect: List[float],
reduction_shape: Optional[ReductionShape] = None,
reduction_shape: Optional[ReductionAxes] = None,
num_samples: int = None,
window_size: int = None,
):
Expand Down
7 changes: 7 additions & 0 deletions nncf/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class TensorStatistic(ABC):
"""Base class that stores statistic data"""

TENSOR_STATISTIC_OUTPUT_KEY = "tensor_statistic_output"

@staticmethod
@abstractmethod
def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol=1e-6) -> bool:
Expand Down Expand Up @@ -63,6 +65,9 @@ def __eq__(self, other: "MeanTensorStatistic") -> bool:


class MedianMADTensorStatistic(TensorStatistic):
MEDIAN_VALUES_STAT = "median_values"
MAD_VALUES_STAT = "mad_values"

def __init__(self, median_values, mad_values):
self.median_values = median_values
self.mad_values = mad_values
Expand All @@ -74,6 +79,8 @@ def __eq__(self, other: "MedianMADTensorStatistic") -> bool:


class PercentileTensorStatistic(TensorStatistic):
PERCENTILE_VS_VALUE_DICT = "percentile_vs_values_dict"

def __init__(self, percentile_vs_values_dict):
self.percentile_vs_values_dict = percentile_vs_values_dict

Expand Down
Loading

0 comments on commit bdfd115

Please sign in to comment.