Skip to content

Commit

Permalink
Refactor Torch/Torch PTQ to use experimental TensorCollector
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 7, 2023
1 parent c07810b commit 681c357
Show file tree
Hide file tree
Showing 30 changed files with 1,255 additions and 545 deletions.
38 changes: 36 additions & 2 deletions nncf/common/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, reduction_shape: Optional[ReductionShape] = None, num_samples
def num_samples(self) -> int:
return self._num_samples

def register_input(self, x: TensorType) -> TensorType:
def register_inputs(self, x: TensorType) -> TensorType:
"""Registers input tensor"""
if not self._enabled:
return x
Expand Down Expand Up @@ -251,6 +251,11 @@ def unstack(x: NNCFTensor, axis: int = 0) -> List[NNCFTensor]:
:return: List of NNCFTensor.
"""

@staticmethod
@abstractmethod
def squeeze(x: NNCFTensor, dim: Optional[int] = None) -> NNCFTensor:
""""""

@staticmethod
@abstractmethod
def sum(tensor: NNCFTensor) -> TensorElementsType:
Expand Down Expand Up @@ -278,6 +283,17 @@ def quantile(
:returns: List of the quantile-th percentile(s) of the tensor elements.
"""

@classmethod
@abstractmethod
def precentile(
cls,
tensor: NNCFTensor,
precentile: Union[float, List[float]],
axis: Union[int, tuple, list],
keepdims: bool = False,
) -> List[TensorElementsType]:
""""""

@staticmethod
@abstractmethod
def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:
Expand All @@ -291,7 +307,9 @@ def mean_per_channel(x: NNCFTensor, axis: int) -> NNCFTensor:

@classmethod
@abstractmethod
def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha: float = 0.01) -> NNCFTensor:
def no_outliers_map(
cls, x: NNCFTensor, fn: MaskedReduceFN, axis: Union[int, Tuple[int, ...]] = 0, alpha: float = 0.01
) -> NNCFTensor:
"""
Computes quantiles [alpha, 1 - alpha] on given tensor, masks all elements that
are smaller that alpha and bigger than 1 - alpha quantile and applies
Expand All @@ -305,6 +323,22 @@ def no_outliers_map(cls, x: NNCFTensor, fn: MaskedReduceFN, axis: int = 0, alpha
:returns: Result of given masked reduction function on filtered from outliers NNCFTensor.
"""

@classmethod
def masked_map(cls, x: NNCFTensor, fn: MaskedReduceFN, filter_fn) -> NNCFTensor:
""" """

@classmethod
def sub(cls, a: NNCFTensor, b: NNCFTensor) -> NNCFTensor:
""""""

@classmethod
def filter_by_fn(cls, x: NNCFTensor, filter_fn) -> NNCFTensor:
""" """

@classmethod
def non_zero_elements(cls, x: NNCFTensor) -> NNCFTensor:
""" """


class MinMaxStatisticCollector(OnlineTensorStatisticCollector):
"""Collector estimates min of minimum values and max of maximum values."""
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/tensor_statistics/statistic_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __eq__(self, other):
def register_tensor(self, x: TensorType):
for tensor_collectors in self.algorithm_to_tensor_collectors.values():
for tensor_collector in tensor_collectors:
tensor_collector.register_input(x)
tensor_collector.register_unnamed_inputs(x)


class StatisticPointsContainer(UserDict):
Expand Down
7 changes: 7 additions & 0 deletions nncf/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class TensorStatistic(ABC):
"""Base class that stores statistic data"""

TENSOR_STATISTIC_OUTPUT_KEY = "tensor_statistic_output"

@staticmethod
@abstractmethod
def tensor_eq(tensor1: TensorType, tensor2: TensorType, rtol=1e-6) -> bool:
Expand Down Expand Up @@ -63,6 +65,9 @@ def __eq__(self, other: "MeanTensorStatistic") -> bool:


class MedianMADTensorStatistic(TensorStatistic):
MEDIAN_VALUES_STAT = "median_values"
MAD_VALUES_STAT = "mad_values"

def __init__(self, median_values, mad_values):
self.median_values = median_values
self.mad_values = mad_values
Expand All @@ -74,6 +79,8 @@ def __eq__(self, other: "MedianMADTensorStatistic") -> bool:


class PercentileTensorStatistic(TensorStatistic):
PRECENTILE_VS_VALUE_DICT = "percentile_vs_values_dict"

def __init__(self, percentile_vs_values_dict):
self.percentile_vs_values_dict = percentile_vs_values_dict

Expand Down
Loading

0 comments on commit 681c357

Please sign in to comment.