Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Dec 6, 2023
1 parent fa54ae4 commit 01c7371
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 50 deletions.
5 changes: 1 addition & 4 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor:
:return: Stacked Tensor.
"""
if isinstance(x, List):
unwrapped_x = [i.data for i in x]
# singledispatch cannot dispatch function by element in a list
res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis)
return Tensor(res)
return Tensor(_dispatch_list(stack, x, axis=axis))
raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}")


Expand Down
3 changes: 0 additions & 3 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc
retval[nncf_module_scope + relative_scope] = target_module
return retval

def update_model_ref(self, model: torch.nn.Module) -> None:
object.__setattr__(self, "__model_ref", model)

def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
hook_addresses = self.insert_at_point(point, fn_list)
self._temprorary_hooks_adresses.append(hook_addresses)
Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/quantization/external_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def __init__(

def __call__(self, *args, **kwargs):
if self.debug_interface is not None:
self.debug_interface.register_activation_quantize_call(str(self.quantizer_storage_key))
self.debug_interface.register_activation_quantize_call(str(self._storage_key))
return super().__call__(*args, **kwargs)
33 changes: 0 additions & 33 deletions nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict

import numpy as np
Expand All @@ -27,41 +26,9 @@
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class ModelView:
def __init__(self, model: NNCFNetwork):
self.model = model
self.nncf_module_additions = self.model.nncf.save_nncf_module_additions()

def __enter__(self):
# Model ref removed to prevent copying
self.model.nncf.update_model_ref(None)

# nncf_replaced_models removed to prevent copying
replaced_modules = self.model.nncf._nncf_replaced_modules
self.model.nncf._nncf_replaced_modules = None

self.nncf_interface = deepcopy(self.model.nncf)

# Model ref is recovering
self.model.nncf.update_model_ref(self.model)
self.nncf_interface.update_model_ref(self.model)

# nncf_replaced_models is recovering
self.model.nncf._nncf_replaced_modules = replaced_modules
self.nncf_interface._nncf_replaced_modules = replaced_modules
return self.model

def __exit__(self, exc_type, exc_val, exc_tb):
self.model._nncf = self.nncf_interface
self.model.nncf.reset_nncf_modules()
self.model.nncf.load_nncf_module_additions(self.nncf_module_additions)


class PTStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
# with ModelView(model) as intermediate_model:
# super().collect_statistics(intermediate_model, graph)
super().collect_statistics(model, graph)
model.nncf.remove_temporary_ops()

Expand Down
8 changes: 4 additions & 4 deletions tests/common/experimental/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class BadStatContainer:

class TemplateTestStatisticCollector:
@abstractmethod
def get_nncf_tensor_cls(self):
def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
pass

@abstractmethod
Expand Down Expand Up @@ -366,10 +366,10 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
collector.register_statistic_branch("A", reducer, aggregator)
input_name = "input_name"
full_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([100]))}, [(hash(reducer), [input_name])]
{input_name: self.get_nncf_tensor(np.array([100]))}, [(hash(reducer), [input_name])]
)
empty_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([]))}, [(hash(reducer), [input_name])]
{input_name: self.get_nncf_tensor(np.array([]))}, [(hash(reducer), [input_name])]
)

stats = collector.get_statistics()
Expand All @@ -385,7 +385,7 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
assert aggregator._collected_samples == 2
stats = collector.get_statistics()
assert len(stats) == 1
assert stats["A"] == self.get_nncf_tensor_cls()([100])
assert stats["A"] == self.get_nncf_tensor([100])
return

assert len(aggregator._container) == 0
Expand Down
6 changes: 4 additions & 2 deletions tests/openvino/native/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

from typing import Type

import numpy as np
import pytest

from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
Expand All @@ -26,8 +28,8 @@


class TestOVStatisticCollector(TemplateTestStatisticCollector):
def get_nncf_tensor_cls(self):
return OVNNCFTensor
def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
return OVNNCFTensor(value)

@pytest.fixture
def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]:
Expand Down
9 changes: 6 additions & 3 deletions tests/torch/ptq/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

from typing import Type

import numpy as np
import pytest
import torch

from nncf.common.tensor import NNCFTensor
from nncf.common.tensor_statistics.statistics import MeanTensorStatistic
from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic
from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic
Expand All @@ -26,9 +29,9 @@
from tests.common.experimental.test_statistic_collector import TemplateTestStatisticCollector


class TestOVStatisticCollector(TemplateTestStatisticCollector):
def get_nncf_tensor_cls(self):
return PTNNCFTensor
class TestPTStatisticCollector(TemplateTestStatisticCollector):
def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
return PTNNCFTensor(torch.tensor(value))

@pytest.fixture
def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]:
Expand Down

0 comments on commit 01c7371

Please sign in to comment.