Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
kshpv committed Oct 18, 2024
1 parent 17752eb commit 7982bf4
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions nncf/experimental/common/tensor_statistics/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,8 @@ def __eq__(self, other: TensorStatistic):
@classmethod
def from_kwargs(cls, kwargs: Dict[str, Any]) -> TensorStatistic:
return cls(
median_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][
MedianMADTensorStatistic.MEDIAN_VALUES_STAT
],
mad_values=kwargs[MedianMADTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY][
MedianMADTensorStatistic.MAD_VALUES_STAT
],
median_values=kwargs[cls.TENSOR_STATISTIC_OUTPUT_KEY][cls.MEDIAN_VALUES_STAT],
mad_values=kwargs[cls.TENSOR_STATISTIC_OUTPUT_KEY][cls.MAD_VALUES_STAT],
)


Expand All @@ -146,8 +142,8 @@ def __eq__(self, other: TensorStatistic):

@classmethod
def from_kwargs(cls, kwargs: Dict[str, Any]) -> TensorStatistic:
if PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY in kwargs:
percentile_vs_values_dict = kwargs[PercentileTensorStatistic.TENSOR_STATISTIC_OUTPUT_KEY]
if cls.TENSOR_STATISTIC_OUTPUT_KEY in kwargs:
percentile_vs_values_dict = kwargs[cls.TENSOR_STATISTIC_OUTPUT_KEY]
else:
percentile_vs_values_dict = {}
for (_, percentile), value in kwargs.items():
Expand Down Expand Up @@ -259,8 +255,8 @@ def __eq__(self, other: Any) -> bool:
@classmethod
def from_kwargs(cls, kwargs: Dict[str, Any]) -> TensorStatistic:
mean_values, shape_values = None, None
if WCTensorStatistic.MEAN_STAT in kwargs and kwargs[WCTensorStatistic.MEAN_STAT] is not None:
mean_values = [fns.squeeze(it) for it in kwargs[WCTensorStatistic.MEAN_STAT]]
if WCTensorStatistic.SHAPE_STAT in kwargs and kwargs[WCTensorStatistic.SHAPE_STAT] is not None:
shape_values = [tuple(it.data) for it in kwargs[WCTensorStatistic.SHAPE_STAT]]
if cls.MEAN_STAT in kwargs and kwargs[cls.MEAN_STAT] is not None:
mean_values = [fns.squeeze(it) for it in kwargs[cls.MEAN_STAT]]
if cls.SHAPE_STAT in kwargs and kwargs[cls.SHAPE_STAT] is not None:
shape_values = [tuple(it.data) for it in kwargs[cls.SHAPE_STAT]]
return cls(mean_values=mean_values, shape_values=shape_values)

0 comments on commit 7982bf4

Please sign in to comment.