Skip to content

Commit

Permalink
torch.median -> torch.quantile + minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 19, 2023
1 parent f6b2d71 commit 17e86f5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
15 changes: 12 additions & 3 deletions nncf/torch/quantization/init_range.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Callable, Dict, List, Tuple

import numpy as np
import torch

from nncf.common.graph.layer_attributes import WeightedLayerAttributes
from nncf.common.quantization.initialization.range import RangeInitCollectorParams
Expand Down Expand Up @@ -277,6 +278,16 @@ def __init__(
self.hook_handles = []
self.batch_size = batch_size

def _get_fwd_hook(
self, collector: TensorStatisticCollectorBase
) -> Callable[[torch.Module, torch.Tensor, torch.Tensor], torch.Tensor]:
hook = register_inputs_hook_factory(collector=collector)

def fwd_hook(module, input_, output):
hook(input_[0])

return fwd_hook

def _prepare_initialization(self):
for name, data in self.modules_to_init.items():
quantizer_module, init_config, is_weights, input_shape = data
Expand Down Expand Up @@ -311,9 +322,7 @@ def _prepare_initialization(self):

self.collectors_and_modules_to_init[name] = collector, quantizer_module

self.hook_handles.append(
quantizer_module.register_forward_hook(register_inputs_hook_factory(collector=collector))
)
self.hook_handles.append(quantizer_module.register_forward_hook(self._get_fwd_hook(collector)))

def _apply_initializers(self):
for handle in self.hook_handles:
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/tensor_statistics/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def median(x: NNCFTensor, axis: Union[int, tuple, list], keepdims=False) -> NNCF
# See https://github.com/pytorch/pytorch/issues/61582
if not isinstance(axis, int):
return PTNNCFTensor(torch.tensor(np.median(x.tensor.detach().cpu().numpy(), axis=axis, keepdims=keepdims)))
return PTNNCFTensor(x.tensor.median(dim=axis, keepdim=keepdims).values)
return PTNNCFTensor(torch.quantile(x.tensor, q=0.5, dim=axis, keepdim=keepdims).values)

@classmethod
def masked_mean(cls, x: NNCFTensor, axis: Union[int, tuple], mask: NNCFTensor, keepdims=False) -> NNCFTensor:
Expand Down

0 comments on commit 17e86f5

Please sign in to comment.