From 7982bf48bad3cb3a851b1a07579bb0649ef4f4b4 Mon Sep 17 00:00:00 2001 From: Aleksei Kashapov Date: Fri, 18 Oct 2024 13:24:02 +0200 Subject: [PATCH] comments --- .../common/tensor_statistics/statistics.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/nncf/experimental/common/tensor_statistics/statistics.py b/nncf/experimental/common/tensor_statistics/statistics.py index e5aefe9254c..1cc236298b5 100644 --- a/nncf/experimental/common/tensor_statistics/statistics.py +++ b/nncf/experimental/common/tensor_statistics/statistics.py @@ -115,12 +115,8 @@ def __eq__(self, other: TensorStatistic): @classmethod def from_kwargs(cls, kwargs: Dict[str, Any]) -> TensorStatistic: return cls( - median_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][ - MedianMADTensorStatistic.MEDIAN_VALUES_STAT - ], - mad_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][ - MedianMADTensorStatistic.MAD_VALUES_STAT - ], + median_values=kwargs[cls.TENSOR_STATISTIC_OUTPUT_KEY][cls.MEDIAN_VALUES_STAT], + mad_values=kwargs[cls.TENSOR_STATISTIC_OUTPUT_KEY][cls.MAD_VALUES_STAT], ) @@ -146,8 +142,8 @@ def __eq__(self, other: TensorStatistic): @classmethod def from_kwargs(cls, kwargs: Dict[str, Any]) -> TensorStatistic: - if PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY in kwargs: - percentile_vs_values_dict = kwargs[PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY] + if cls.TENSOR_STATISTIC_OUTPUT_KEY in kwargs: + percentile_vs_values_dict = kwargs[cls.TENSOR_STATISTIC_OUTPUT_KEY] else: percentile_vs_values_dict = {} for (_, percentile), value in kwargs.items(): @@ -259,8 +255,8 @@ def __eq__(self, other: Any) -> bool: @classmethod def from_kwargs(cls, kwargs: Dict[str, Any]) -> TensorStatistic: mean_values, shape_values = None, None - if WCTensorStatistic.MEAN_STAT in kwargs and kwargs[WCTensorStatistic.MEAN_STAT] is not None: - mean_values = [fns.squeeze(it) for it in kwargs[WCTensorStatistic.MEAN_STAT]] - if WCTensorStatistic.SHAPE_STAT in kwargs and kwargs[WCTensorStatistic.SHAPE_STAT] is not None: - shape_values = [tuple(it.data) for it in kwargs[WCTensorStatistic.SHAPE_STAT]] + if cls.MEAN_STAT in kwargs and kwargs[cls.MEAN_STAT] is not None: + mean_values = [fns.squeeze(it) for it in kwargs[cls.MEAN_STAT]] + if cls.SHAPE_STAT in kwargs and kwargs[cls.SHAPE_STAT] is not None: + shape_values = [tuple(it.data) for it in kwargs[cls.SHAPE_STAT]] return cls(mean_values=mean_values, shape_values=shape_values)