From 17e86f57b94d82253c8017f27895e7b045f11515 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 19 Sep 2023 11:20:12 +0200 Subject: [PATCH] torch.median -> torch.quantile + minor fix --- nncf/torch/quantization/init_range.py | 15 ++++++++++++--- nncf/torch/tensor_statistics/collectors.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/nncf/torch/quantization/init_range.py b/nncf/torch/quantization/init_range.py index 3a593fe3ec3..8a9d7c971c8 100644 --- a/nncf/torch/quantization/init_range.py +++ b/nncf/torch/quantization/init_range.py @@ -14,6 +14,7 @@ from typing import Callable, Dict, List, Tuple import numpy as np +import torch from nncf.common.graph.layer_attributes import WeightedLayerAttributes from nncf.common.quantization.initialization.range import RangeInitCollectorParams @@ -277,6 +278,16 @@ def __init__( self.hook_handles = [] self.batch_size = batch_size + def _get_fwd_hook( + self, collector: TensorStatisticCollectorBase + ) -> Callable[[torch.Module, torch.Tensor, torch.Tensor], torch.Tensor]: + hook = register_inputs_hook_factory(collector=collector) + + def fwd_hook(module, input_, output): + hook(input_[0]) + + return fwd_hook + def _prepare_initialization(self): for name, data in self.modules_to_init.items(): quantizer_module, init_config, is_weights, input_shape = data @@ -311,9 +322,7 @@ def _prepare_initialization(self): self.collectors_and_modules_to_init[name] = collector, quantizer_module - self.hook_handles.append( - quantizer_module.register_forward_hook(register_inputs_hook_factory(collector=collector)) - ) + self.hook_handles.append(quantizer_module.register_forward_hook(self._get_fwd_hook(collector))) def _apply_initializers(self): for handle in self.hook_handles: diff --git a/nncf/torch/tensor_statistics/collectors.py b/nncf/torch/tensor_statistics/collectors.py index 0eb6afcba06..c6c760afe55 100644 --- a/nncf/torch/tensor_statistics/collectors.py +++ b/nncf/torch/tensor_statistics/collectors.py @@ -81,7 +81,7 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF # See https://github.com/pytorch/pytorch/issues/61582 if not isinstance(axis, int): return PTNNCFTensor(torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims))) - return PTNNCFTensor(x.tensor.median(dim=axis, keepdim=keepdims).values) + return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims).values) @classmethod def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple], mask: NNCFTensor, keepdims=False) -> NNCFTensor: