diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py index e49b6126aa4..466046fe0ca 100644 --- a/nncf/experimental/tensor/functions.py +++ b/nncf/experimental/tensor/functions.py @@ -340,10 +340,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor: :return: Stacked Tensor. """ if isinstance(x, List): - unwrapped_x = [i.data for i in x] - # singledispatch cannot dispatch function by element in a list - res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis) - return Tensor(res) + return Tensor(_dispatch_list(stack, x, axis=axis)) raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}") diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index cdaadd075d6..a4f79855179 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -376,9 +376,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc retval[nncf_module_scope + relative_scope] = target_module return retval - def update_model_ref(self, model: torch.nn.Module) -> None: - object.__setattr__(self, "__model_ref", model) - def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): hook_addresses = self.insert_at_point(point, fn_list) self._temprorary_hooks_adresses.append(hook_addresses) diff --git a/nncf/torch/quantization/external_quantizer.py b/nncf/torch/quantization/external_quantizer.py index 4f5e69a4b81..140df6de094 100644 --- a/nncf/torch/quantization/external_quantizer.py +++ b/nncf/torch/quantization/external_quantizer.py @@ -35,5 +35,5 @@ def __init__( def __call__(self, *args, **kwargs): if self.debug_interface is not None: - self.debug_interface.register_activation_quantize_call(str(self.quantizer_storage_key)) + self.debug_interface.register_activation_quantize_call(str(self._storage_key)) return super().__call__(*args, **kwargs) diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index fa49cf3cca3..65beac840be 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from typing import Dict import numpy as np @@ -27,41 +26,9 @@ from nncf.torch.tensor_statistics.algo import create_register_input_hook -class ModelView: - def __init__(self, model: NNCFNetwork): - self.model = model - self.nncf_module_additions = self.model.nncf.save_nncf_module_additions() - - def __enter__(self): - # Model ref removed to prevent copying - self.model.nncf.update_model_ref(None) - - # nncf_replaced_models removed to prevent copying - replaced_modules = self.model.nncf._nncf_replaced_modules - self.model.nncf._nncf_replaced_modules = None - - self.nncf_interface = deepcopy(self.model.nncf) - - # Model ref is recovering - self.model.nncf.update_model_ref(self.model) - self.nncf_interface.update_model_ref(self.model) - - # nncf_replaced_models is recovering - self.model.nncf._nncf_replaced_modules = replaced_modules - self.nncf_interface._nncf_replaced_modules = replaced_modules - return self.model - - def __exit__(self, exc_type, exc_val, exc_tb): - self.model._nncf = self.nncf_interface - self.model.nncf.reset_nncf_modules() - self.model.nncf.load_nncf_module_additions(self.nncf_module_additions) - - class PTStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: with torch.no_grad(): - # with ModelView(model) as intermediate_model: - # super().collect_statistics(intermediate_model, graph) super().collect_statistics(model, graph) model.nncf.remove_temporary_ops() diff --git a/tests/common/experimental/test_statistic_collector.py b/tests/common/experimental/test_statistic_collector.py index 3caad7fcd9a..d0a1344a93f 100644 --- a/tests/common/experimental/test_statistic_collector.py +++ b/tests/common/experimental/test_statistic_collector.py @@ -329,7 +329,7 @@ class BadStatContainer: class TemplateTestStatisticCollector: @abstractmethod - def get_nncf_tensor_cls(self): + def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor: pass @abstractmethod @@ -366,10 +366,10 @@ def test_empty_tensors_register(self, inplace, any_not_empty): collector.register_statistic_branch("A", reducer, aggregator) input_name = "input_name" full_inputs = TensorCollector.get_tensor_collector_inputs( - {input_name: self.get_nncf_tensor_cls()(np.array([100]))}, [(hash(reducer), [input_name])] + {input_name: self.get_nncf_tensor(np.array([100]))}, [(hash(reducer), [input_name])] ) empty_inputs = TensorCollector.get_tensor_collector_inputs( - {input_name: self.get_nncf_tensor_cls()(np.array([]))}, [(hash(reducer), [input_name])] + {input_name: self.get_nncf_tensor(np.array([]))}, [(hash(reducer), [input_name])] ) stats = collector.get_statistics() @@ -385,7 +385,7 @@ def test_empty_tensors_register(self, inplace, any_not_empty): assert aggregator._collected_samples == 2 stats = collector.get_statistics() assert len(stats) == 1 - assert stats["A"] == self.get_nncf_tensor_cls()([100]) + assert stats["A"] == self.get_nncf_tensor([100]) return assert len(aggregator._container) == 0 diff --git a/tests/openvino/native/test_statistic_collector.py b/tests/openvino/native/test_statistic_collector.py index 32123644942..73d3534bd50 100644 --- a/tests/openvino/native/test_statistic_collector.py +++ b/tests/openvino/native/test_statistic_collector.py @@ -11,8 +11,10 @@ from typing import Type +import numpy as np import pytest +from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistics import MeanTensorStatistic from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic @@ -26,8 +28,8 @@ class TestOVStatisticCollector(TemplateTestStatisticCollector): - def get_nncf_tensor_cls(self): - return OVNNCFTensor + def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor: + return OVNNCFTensor(value) @pytest.fixture def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]: diff --git a/tests/torch/ptq/test_statistic_collector.py b/tests/torch/ptq/test_statistic_collector.py index 0ab1ef2bb55..96b0963194b 100644 --- a/tests/torch/ptq/test_statistic_collector.py +++ b/tests/torch/ptq/test_statistic_collector.py @@ -11,8 +11,11 @@ from typing import Type +import numpy as np import pytest +import torch +from nncf.common.tensor import NNCFTensor from nncf.common.tensor_statistics.statistics import MeanTensorStatistic from nncf.common.tensor_statistics.statistics import MedianMADTensorStatistic from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic @@ -26,9 +29,9 @@ from tests.common.experimental.test_statistic_collector import TemplateTestStatisticCollector -class TestOVStatisticCollector(TemplateTestStatisticCollector): - def get_nncf_tensor_cls(self): - return PTNNCFTensor +class TestPTStatisticCollector(TemplateTestStatisticCollector): + def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor: + return PTNNCFTensor(torch.tensor(value)) @pytest.fixture def min_max_statistic_cls(self) -> Type[MinMaxTensorStatistic]: