From 08661a5adb86ca5df42370373fbb3d8832b160ad Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 5 Dec 2023 11:28:41 +0100 Subject: [PATCH] Test finishing Cleanup Comments --- nncf/common/graph/transformations/commands.py | 1 - nncf/experimental/tensor/functions.py | 5 +- .../algorithms/smooth_quant/algorithm.py | 12 +- .../algorithms/smooth_quant/backend.py | 34 +++- .../smooth_quant/openvino_backend.py | 22 +-- .../algorithms/smooth_quant/torch_backend.py | 16 +- nncf/torch/dynamic_graph/context.py | 2 +- nncf/torch/nncf_network.py | 32 +--- nncf/torch/quantization/external_quantizer.py | 1 - nncf/torch/statistics/aggregator.py | 33 ---- nncf/torch/tensor.py | 7 +- tests/torch/helpers.py | 33 +++- tests/torch/test_model_transformer.py | 33 ++-- tests/torch/test_nncf_network.py | 53 ++++++ tests/torch/test_statistics_aggregator.py | 175 +++++++++++++++--- 15 files changed, 314 insertions(+), 145 deletions(-) diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py index 11734b1f152..fa26e587ec7 100644 --- a/nncf/common/graph/transformations/commands.py +++ b/nncf/common/graph/transformations/commands.py @@ -35,7 +35,6 @@ class TransformationPriority(IntEnum): FP32_TENSOR_STATISTICS_OBSERVATION = 1 PRUNING_PRIORITY = 2 SPARSIFICATION_PRIORITY = 3 - OP_INSERTION_PRIORITY = 4 QUANTIZATION_PRIORITY = 11 diff --git a/nncf/experimental/tensor/functions.py b/nncf/experimental/tensor/functions.py index 6bf73135de2..cd0647b742d 100644 --- a/nncf/experimental/tensor/functions.py +++ b/nncf/experimental/tensor/functions.py @@ -342,10 +342,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor: :return: Stacked Tensor. """ if isinstance(x, List): - unwrapped_x = [i.data for i in x] - # singledispatch cannot dispatch function by element in a list - res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis) - return Tensor(res) + return Tensor(_dispatch_list(stack, x, axis=axis)) raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}") diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index e9195405cf9..f31c2e57e20 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -129,7 +129,14 @@ def apply( if any(val.data is None for val in activations_value): empty_statistic = True break - assert len(activations_value) == 1 + if len(activations_value) != 1: + raise RuntimeError( + ( + "More than one statistic is collected for one node during" + f"Smooth Quanti algorithm: {node_to_smooth.node_name}" + ) + ) + activations_value = self._clip_statistics(activations_value) weight_value = self._backend_entity.get_weight_value(node_to_smooth, model) @@ -254,6 +261,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin for node_data in nodes_to_smooth_data: node_to_smooth = node_data["node_to_smooth"] target_point = self._backend_entity.target_point( + target_type=self._backend_entity.pre_layer_target_type(), target_node_name=node_to_smooth.node_name, port_id=node_data["input_act_port"], ) @@ -306,7 +314,7 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[ nodes_to_smooth_data.append( { "node_to_smooth": node_with_weight, - "input_act_port": self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph), + "input_act_port": activation_port_id, } ) return nodes_to_smooth_data diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py index 05aae4b0ac3..57440e1f371 100644 --- a/nncf/quantization/algorithms/smooth_quant/backend.py +++ b/nncf/quantization/algorithms/smooth_quant/backend.py @@ -11,13 +11,15 @@ from abc import ABC from abc import abstractmethod -from typing import List, Tuple, TypeVar +from typing import Callable, List, Tuple, TypeVar from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetPoint +from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.tensor import Tensor @@ -55,10 +57,20 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: @staticmethod @abstractmethod - def target_point(target_node_name: str, port_id: int) -> TargetPoint: + def pre_layer_target_type() -> TargetType: + """ + Returns backend-specific pre layer target type. + + :returns: Backend-specific pre layer target type. + """ + + @staticmethod + @abstractmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: """ Returns backend-specific target point. + :param target_type: Type of the location that should be modified. :param target_node_name: Name of the located node. :param port_id: Port ID of the tensor for the statistics distribution. :return: Backend-specific TargetPoint. @@ -184,10 +196,20 @@ def get_weight_channel_axis(node: NNCFNode) -> int: @staticmethod @abstractmethod - def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): - pass + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + """ + Returns true if given node shares constant with a different node. + + :param node: NNCFNode instance. + :param nncf_graph: NNCFGraph instance. + :return: Whether the given node is shares weights with a different node or not. + """ @staticmethod @abstractmethod - def get_filter_fn_for_statistics(activation_port_id: int): - pass + def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: + """ + Returns backend-specific callable to filter statistic containers according to its statistic point. + + :param activation_port_id: Activation port id for the statistic collection target node. + """ diff --git a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py index 20a91a36756..6746c60314f 100644 --- a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Callable, List, Tuple import numpy as np import openvino.runtime as ov @@ -52,8 +52,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: return QUANTIZE_AGNOSTIC_OPERATIONS @staticmethod - def target_point(target_node_name: str, port_id: int) -> OVTargetPoint: - return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id) + def pre_layer_target_type() -> TargetType: + return TargetType.PRE_LAYER_OPERATION + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: + return OVTargetPoint(target_type, target_node_name, port_id) @staticmethod def is_node_with_weights(node: NNCFNode) -> bool: @@ -92,15 +96,11 @@ def get_abs_max_channel_collector( @staticmethod def get_weight_value(node_with_weight: NNCFNode, model: ov.Model) -> Tensor: - port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight) + port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight) return Tensor(get_weight_value(node_with_weight, model, port_id)) @staticmethod def get_weight_tensor_port_id(node: NNCFNode) -> int: - return OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node) - - @staticmethod - def _get_weight_tensor_port_id(node: NNCFNode) -> int: const_ids = node.layer_attributes.get_const_port_ids() if len(const_ids) != 1: raise RuntimeError(f"Found more than 1 port for {node.node_name} node") @@ -108,7 +108,7 @@ def _get_weight_tensor_port_id(node: NNCFNode) -> int: @staticmethod def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand: - weight_port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight) + weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight) return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id) @staticmethod @@ -154,13 +154,13 @@ def _calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: return -2 + port_id if transpose else -1 - port_id @staticmethod - def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node) weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node return len(nncf_graph.get_next_nodes(weight_node)) > 1 @staticmethod - def get_filter_fn_for_statistics(activation_port_id: int): + def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: def filter_func(point: StatisticPoint) -> bool: return point.target_point.port_id == activation_port_id diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py index 7927461378a..a8315890771 100644 --- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple +from typing import Callable, List, Tuple import numpy as np @@ -63,8 +63,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT[QuantizationTrait.QUANTIZATION_AGNOSTIC] @staticmethod - def target_point(target_node_name: str, port_id: int) -> PTTargetPoint: - return PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name, input_port_id=port_id) + def pre_layer_target_type() -> TargetType: + return TargetType.OPERATOR_PRE_HOOK + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) @staticmethod def is_node_with_weights(node: NNCFNode) -> bool: @@ -92,7 +96,7 @@ def get_abs_max_channel_collector( def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor: node_module = model.nncf.get_containing_module(node_with_weight.node_name) if node_module.weight is None: - return None + raise RuntimeError(f"{node_module} module has no .weight attribute.") return Tensor(node_module.weight.data) @staticmethod @@ -130,11 +134,11 @@ def get_weight_channel_axis(node: NNCFNode) -> int: return 1 @staticmethod - def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: return node.is_shared() @staticmethod - def get_filter_fn_for_statistics(activation_port_id: int): + def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]: def filter_func(point: StatisticPoint) -> bool: return True diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py index c070ea4f2d6..f2504b63277 100644 --- a/nncf/torch/dynamic_graph/context.py +++ b/nncf/torch/dynamic_graph/context.py @@ -97,7 +97,7 @@ def __init__(self): self._post_hooks = defaultdict(OrderedDict) self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict) - self._hooks_counter = 0 + self._hooks_counter = -1 self._threading = CopySafeThreadingVars() diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index c50ef70101d..5268aba3640 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from enum import Enum from enum import IntEnum -from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar +from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar import torch from torch import nn @@ -364,25 +364,6 @@ def reset_nncf_modules(self): module = self.get_module_by_scope(some_scope) module.reset() - def get_shallow_copy(self) -> "NNCFNetwork": - from nncf.torch.utils import load_module_state - from nncf.torch.utils import save_module_state - - saved_state = save_module_state(self._model_ref) - new_interface = NNCFNetworkInterface( - self._model_ref, - self._input_infos, - self._user_dummy_forward_fn, - self._wrap_inputs_fn, - self._scopes_without_shape_matching, - self._ignored_scopes, - self._target_scopes, - wrap_outputs_fn=self._wrap_outputs_fn, - ) - self._model_ref._nncf = new_interface - load_module_state(self._model_ref, saved_state) - return self._model_ref - def get_clean_shallow_copy(self) -> "NNCFNetwork": # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions # and load_nncf_module_additions to preserve these, or temporary_clean_view(). @@ -414,9 +395,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc retval[nncf_module_scope + relative_scope] = target_module return retval - def update_model_ref(self, model: torch.nn.Module) -> None: - object.__setattr__(self, "__model_ref", model) - def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): hook_addresses = self.insert_at_point(point, fn_list) self._temprorary_hooks_adresses.append(hook_addresses) @@ -832,14 +810,6 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork": return self.compression_controller.strip(do_copy) -class TemporaryOp: - def __init__(self, op: Callable) -> None: - self._op = op - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self._op(*args, **kwargs) - - class NNCFNetworkMeta(type): """ Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be diff --git a/nncf/torch/quantization/external_quantizer.py b/nncf/torch/quantization/external_quantizer.py index ad065392470..68097eada08 100644 --- a/nncf/torch/quantization/external_quantizer.py +++ b/nncf/torch/quantization/external_quantizer.py @@ -14,7 +14,6 @@ from nncf.torch.quantization.debug_interface import QuantizationDebugInterface EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers" -EXTERNAL_OP_STORAGE_NAME = "external_op" EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index fa49cf3cca3..65beac840be 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from typing import Dict import numpy as np @@ -27,41 +26,9 @@ from nncf.torch.tensor_statistics.algo import create_register_input_hook -class ModelView: - def __init__(self, model: NNCFNetwork): - self.model = model - self.nncf_module_additions = self.model.nncf.save_nncf_module_additions() - - def __enter__(self): - # Model ref removed to prevent copying - self.model.nncf.update_model_ref(None) - - # nncf_replaced_models removed to prevent copying - replaced_modules = self.model.nncf._nncf_replaced_modules - self.model.nncf._nncf_replaced_modules = None - - self.nncf_interface = deepcopy(self.model.nncf) - - # Model ref is recovering - self.model.nncf.update_model_ref(self.model) - self.nncf_interface.update_model_ref(self.model) - - # nncf_replaced_models is recovering - self.model.nncf._nncf_replaced_modules = replaced_modules - self.nncf_interface._nncf_replaced_modules = replaced_modules - return self.model - - def __exit__(self, exc_type, exc_val, exc_tb): - self.model._nncf = self.nncf_interface - self.model.nncf.reset_nncf_modules() - self.model.nncf.load_nncf_module_additions(self.nncf_module_additions) - - class PTStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: with torch.no_grad(): - # with ModelView(model) as intermediate_model: - # super().collect_statistics(intermediate_model, graph) super().collect_statistics(model, graph) model.nncf.remove_temporary_ops() diff --git a/nncf/torch/tensor.py b/nncf/torch/tensor.py index 908e482f889..edf234734dc 100644 --- a/nncf/torch/tensor.py +++ b/nncf/torch/tensor.py @@ -9,12 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union - import torch from nncf.common.tensor import NNCFTensor -from nncf.torch.return_types import maybe_unwrap_from_torch_return_type class PTNNCFTensor(NNCFTensor): @@ -22,13 +19,11 @@ class PTNNCFTensor(NNCFTensor): A realisation of torch tensors wrapper for common NNCF algorithms. """ - def __init__(self, tensor: Union[torch.tensor, "PTNNCFTensor", tuple]): + def __init__(self, tensor: torch.tensor): # In case somebody attempts to wrap # tensor twice if isinstance(tensor, self.__class__): tensor = tensor.tensor - else: - tensor = maybe_unwrap_from_torch_return_type(tensor) super().__init__(tensor) diff --git a/tests/torch/helpers.py b/tests/torch/helpers.py index e15f7671627..9e45ac81e95 100644 --- a/tests/torch/helpers.py +++ b/tests/torch/helpers.py @@ -15,7 +15,7 @@ from collections import defaultdict from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union +from typing import Any, Callable, Dict, List, Tuple, TypeVar, Union import numpy as np import onnx @@ -35,6 +35,7 @@ from nncf.torch.compression_method_api import PTCompressionAlgorithmController from nncf.torch.dynamic_graph.context import PreHookId from nncf.torch.dynamic_graph.io_handling import FillerInputInfo +from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.scope import Scope from nncf.torch.initialization import PTInitializingDataLoader from nncf.torch.initialization import register_default_init_args @@ -503,6 +504,9 @@ def load_exported_onnx_version( return model_proto +HookType = TypeVar("HookType") + + class HookChecker: """ Class to check pre/post hooks and pre ops are placed correctly. @@ -535,7 +539,11 @@ def _convert_to_op_address(self, target_type: TargetType, target_node_name: str, address = address_map[target_node_name] if target_type == TargetType.OPERATOR_PRE_HOOK: address = PreHookId(address, input_port_id) - elif target_type == TargetType.OPERATION_WITH_WEIGHTS: + elif target_type in [ + TargetType.OPERATION_WITH_WEIGHTS, + TargetType.PRE_LAYER_OPERATION, + TargetType.POST_LAYER_OPERATION, + ]: address = getattr(self._target_model, self._nncf_module_attr_name) return address @@ -544,6 +552,15 @@ def check_with_reference(self): Check hooks in the target model and reference hooks are matching. """ self._check_weight_update_hooks(self._ref_hooks[TargetType.OPERATION_WITH_WEIGHTS]) + + target_module = getattr(self._target_model, self._nncf_module_attr_name) + if target_module in self._ref_hooks[TargetType.PRE_LAYER_OPERATION]: + hooks = target_module.pre_ops + self._check_pre_post_op_hooks(hooks, self._ref_hooks[TargetType.PRE_LAYER_OPERATION][target_module]) + if target_module in self._ref_hooks[TargetType.POST_LAYER_OPERATION]: + hooks = target_module.post_ops + self._check_pre_post_op_hooks(hooks, self._ref_hooks[TargetType.POST_LAYER_OPERATION][target_module]) + hooks = self._target_model.nncf._compressed_context._pre_hooks self._check_pre_post_hooks(hooks, self._ref_hooks[TargetType.OPERATOR_PRE_HOOK]) hooks = self._target_model.nncf._compressed_context._post_hooks @@ -556,7 +573,7 @@ def clear(self): self._ref_hooks.clear() @staticmethod - def _check_weight_update_hooks(ref_hooks): + def _check_weight_update_hooks(ref_hooks: Dict[torch.nn.Module, List[HookType]]): for target_module, ref_hooks_per_module in ref_hooks.items(): assert len(target_module.pre_ops) == len(ref_hooks_per_module) for actual_op, ref_op in zip(target_module.pre_ops.values(), ref_hooks_per_module): @@ -564,7 +581,15 @@ def _check_weight_update_hooks(ref_hooks): assert actual_op.op is ref_op @staticmethod - def _check_pre_post_hooks(hooks, ref_hooks): + def _check_pre_post_op_hooks(hooks: List[torch.ModuleDict], ref_hooks: List[HookType]): + assert len(hooks) == len(ref_hooks) + for actual_hook, ref_hook in zip(hooks.values(), ref_hooks): + assert actual_hook is ref_hook + + @staticmethod + def _check_pre_post_hooks( + hooks: Dict[OperationAddress, Dict[Any, HookType]], ref_hooks: Dict[OperationAddress, List[HookType]] + ): assert len(hooks) == len(ref_hooks) for op_address, ref_hooks in ref_hooks.items(): actual_hooks = hooks[op_address].values() diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py index 992a99ad914..9650cea8923 100644 --- a/tests/torch/test_model_transformer.py +++ b/tests/torch/test_model_transformer.py @@ -50,6 +50,7 @@ from nncf.torch.graph.operator_metatypes import PTReshapeMetatype from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand @@ -149,8 +150,11 @@ def setup(self): point_for_relu_inputs, ] + @pytest.mark.parametrize( + "insert_method_name,check_tmp_ops", [("insert_at_point", False), ("temporary_insert_at_point", True)] + ) @pytest.mark.parametrize("target_point", available_points) - def test_single_insertions(self, setup, target_point: PTTargetPoint): + def test_single_insertions(self, setup, target_point: PTTargetPoint, insert_method_name: str, check_tmp_ops: bool): insertion_point = PTInsertionPoint( target_point.target_type, OperationAddress.from_str(target_point.target_node_name), @@ -161,7 +165,8 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint): else: hook = BaseOp(lambda x: x) - self.compressed_model.nncf.insert_at_point(insertion_point, [hook]) + insert_at_point_method = getattr(self.compressed_model.nncf, insert_method_name) + insert_at_point_method(insertion_point, [hook]) if insertion_point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK: ctx = self.compressed_model.nncf.get_tracing_context() @@ -178,6 +183,9 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint): module = self.compressed_model.nncf.get_module_by_scope(insertion_point.module_scope) assert module.post_ops["0"] is hook + if check_tmp_ops: + assert len(self.compressed_model.nncf._temprorary_hooks_adresses) == 1 + priority_types = ["same", "different"] insertion_types = TargetType priority_test_cases = list(itertools.product(priority_types, insertion_types)) @@ -187,8 +195,9 @@ def check_order(iterable1: List, iterable2: List, ordering: List): for idx, order in enumerate(ordering): assert iterable1[idx] is iterable2[order] + @pytest.mark.parametrize("command_cls", [PTInsertionCommand, PTInsertionTemporaryCommand]) @pytest.mark.parametrize("case", priority_test_cases, ids=[x[1].name + "-" + x[0] for x in priority_test_cases]) - def test_priority(self, case, setup): + def test_priority(self, case, command_cls, setup): priority_type = case[0] insertion_type = case[1] @@ -218,14 +227,14 @@ def test_priority(self, case, setup): if priority_type == "same": # Same-priority commands will be executed in registration order - command1 = PTInsertionCommand(point, hook1, TransformationPriority.DEFAULT_PRIORITY) - command2 = PTInsertionCommand(point, hook2, TransformationPriority.DEFAULT_PRIORITY) - command3 = PTInsertionCommand(point, hook3, TransformationPriority.DEFAULT_PRIORITY) + command1 = command_cls(point, hook1, TransformationPriority.DEFAULT_PRIORITY) + command2 = command_cls(point, hook2, TransformationPriority.DEFAULT_PRIORITY) + command3 = command_cls(point, hook3, TransformationPriority.DEFAULT_PRIORITY) else: # Prioritized commands will be executed in ascending priority order - command1 = PTInsertionCommand(point, hook1, TransformationPriority.SPARSIFICATION_PRIORITY) - command2 = PTInsertionCommand(point, hook2, TransformationPriority.QUANTIZATION_PRIORITY) - command3 = PTInsertionCommand(point, hook3, TransformationPriority.DEFAULT_PRIORITY) + command1 = command_cls(point, hook1, TransformationPriority.SPARSIFICATION_PRIORITY) + command2 = command_cls(point, hook2, TransformationPriority.QUANTIZATION_PRIORITY) + command3 = command_cls(point, hook3, TransformationPriority.DEFAULT_PRIORITY) layout = PTTransformationLayout() layout.register(command1) @@ -245,10 +254,12 @@ def test_priority(self, case, setup): pre_hook_id = PreHookId( OperationAddress.from_str(point.target_node_name), input_port_id=point.input_port_id ) - self.check_order(ctx._pre_hooks[pre_hook_id], hook_list, order) + actual_pre_hooks = list(ctx._pre_hooks[pre_hook_id].values()) + self.check_order(actual_pre_hooks, hook_list, order) if insertion_type == TargetType.OPERATOR_POST_HOOK: ctx = self.compressed_model.nncf.get_tracing_context() - self.check_order(ctx._post_hooks[OperationAddress.from_str(point.target_node_name)], hook_list, order) + actual_post_hooks = list(ctx._post_hooks[OperationAddress.from_str(point.target_node_name)].values()) + self.check_order(actual_post_hooks, hook_list, order) if insertion_type == TargetType.OPERATION_WITH_WEIGHTS: module = self.compressed_model.nncf.get_containing_module(point.target_node_name) diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py index 80e96e5eefc..6ffa78e597c 100644 --- a/tests/torch/test_nncf_network.py +++ b/tests/torch/test_nncf_network.py @@ -47,6 +47,7 @@ from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config from tests.torch.helpers import BasicConvTestModel +from tests.torch.helpers import HookChecker from tests.torch.helpers import TwoConvTestModel from tests.torch.helpers import check_correct_nncf_modules_replacement from tests.torch.helpers import create_compressed_model_and_algo_for_test @@ -913,3 +914,55 @@ def test_insert_hook_after_parameter(): assert hook.forward_calls_counter == 1 assert torch.sum(result.nonzero()) > 0 assert torch.sum(result_with_hook.nonzero()) == 0 + + +@pytest.mark.parametrize( + "target_type, target_node_name, input_port_id", + [ + (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0), + (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0), + (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0), + ], +) +def test_temporary_insert_at_point(target_type, target_node_name, input_port_id): + class Hook(torch.nn.Module): + def forward(self, x): + return x + + model = SimplestModel() + example_input = torch.ones(SimplestModel.INPUT_SIZE) + input_info = ExampleInputInfo.from_example_input(example_input) + nncf_model = NNCFNetwork(model, input_info) + + node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping() + ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id) + + checker = HookChecker(nncf_model, "conv") + + def _check(ref_hooks_): + checker.clear() + checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id) + checker.check_with_reference() + + permanent_hook = Hook() + # Make temporary hook a ref to the permanent hook + # to check tmp hooks are not removed by their id() + temporary_hook = permanent_hook + nncf_model.nncf.insert_at_point(ip, [permanent_hook]) + ref_hooks = [permanent_hook] + _check(ref_hooks) + + for _ in range(2): + temporary_hook = Hook() + nncf_model.nncf.temporary_insert_at_point(ip, [temporary_hook]) + ref_hooks.append(temporary_hook) + _check(ref_hooks) + + nncf_model.nncf.insert_at_point(ip, [permanent_hook]) + ref_hooks.append(permanent_hook) + _check(ref_hooks) + + nncf_model.nncf.remove_temporary_ops() + del ref_hooks[-2] + _check(ref_hooks) diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py index 40403fe29d6..5653a28ec5a 100644 --- a/tests/torch/test_statistics_aggregator.py +++ b/tests/torch/test_statistics_aggregator.py @@ -21,13 +21,14 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout -from nncf.common.quantization.structs import QuantizationMode +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend from nncf.quantization.range_estimator import RangeEstimatorParametersSet +from nncf.torch.dynamic_graph.patch_pytorch import register_operator from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.model_transformer import PTInsertionCommand from nncf.torch.statistics.aggregator import PTStatisticsAggregator @@ -137,7 +138,7 @@ def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_ pass @pytest.mark.parametrize( - "test_parameters, ", + "test_parameters", ( MinMaxTestParameters( RangeEstimatorParametersSet.MINMAX, @@ -189,21 +190,17 @@ def test_successive_statistics_aggregation( def fn(x): return x * 2 - layout = TransformationLayout() target_point = self.get_target_point(test_parameters.target_type) - command = PTInsertionCommand(target_point, fn) - layout.register(command) - model_transformer = factory.ModelTransformerFactory.create(model) - model = model_transformer.transform(layout) - model.nncf.rebuild_graph() + model = self.__add_fn_to_model(model, target_point, fn) ### Check hook inserted correctly - self.__check_hooks(test_parameters, model, target_point, fn) + self.__check_successive_hooks(test_parameters, model, target_point, fn) ### Register and collect statistics after inserted operations - tensor_collector = self.__collect_statistics_get_collector( + statistic_points = self.__get_statistic_points( test_parameters, model, quantizer_config, dataset_samples, inplace_statistics ) + tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples) ### Check values are changed because of the inserted operation self.__check_collector( test_parameters, @@ -212,42 +209,145 @@ def fn(x): ) ### Check the inserted operation is inside the model - self.__check_hooks(test_parameters, model, target_point, fn) + self.__check_successive_hooks(test_parameters, model, target_point, fn) - def __collect_statistics_get_collector( - self, test_parameters: MinMaxTestParameters, model, quantizer_config, dataset_samples, inplace_statistics + @pytest.mark.parametrize( + "test_parameters, nested_target_node_name", + ( + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATOR_PRE_HOOK, + QuantizationMode.SYMMETRIC, + False, + 512, + -512, + ), + "PTIdentityConvModel/fn_0", + ), + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATION_WITH_WEIGHTS, + QuantizationMode.SYMMETRIC, + False, + 512, + -512, + ), + "PTIdentityConvModel/NNCFConv2d[conv]/ModuleDict[pre_ops]/UpdateWeight[0]/fn_0", + ), + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATOR_POST_HOOK, + QuantizationMode.SYMMETRIC, + False, + 512, + -512, + ), + "PTIdentityConvModel/fn_0", + ), + ), + ) + @pytest.mark.parametrize("nested_target_type", [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK]) + def test_nested_statistics_aggregation( + self, + test_parameters: MinMaxTestParameters, + nested_target_type: TargetType, + nested_target_node_name, + dataset_samples, + is_stat_in_shape_of_scale, + inplace_statistics, + is_backend_support_custom_estimators, ): + model = self.get_backend_model(dataset_samples) + quantizer_config = QuantizerConfig( + mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel + ) + + is_standard_estimator = test_parameters.range_estimator_params in [ + RangeEstimatorParametersSet.MINMAX, + RangeEstimatorParametersSet.MEAN_MINMAX, + ] + if not is_standard_estimator and not is_backend_support_custom_estimators: + pytest.skip("Custom estimators are not supported for this backend yet") + + ### Register operations before statistic collection + @register_operator() + def fn(x): + return x * 2 + + target_point = self.get_target_point(test_parameters.target_type) + model = self.__add_fn_to_model(model, target_point, fn) + nested_target_point = PTMinMaxAlgoBackend.target_point(nested_target_type, nested_target_node_name, 0) + model = self.__add_fn_to_model(model, nested_target_point, fn) + + ### Check hook inserted correctly + self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn) + + ### Register and collect statistics after inserted operations + statistic_points = self.__get_statistic_points( + test_parameters, + model, + quantizer_config, + dataset_samples, + inplace_statistics, + ) + tensor_collector = self.__collect_statistics_get_collector(statistic_points, model, dataset_samples) + ### Check values are changed because of the inserted operation + self.__check_collector( + test_parameters, + tensor_collector, + is_stat_in_shape_of_scale, + ) + + ### Check the inserted operation is inside the model + self.__check_nested_hooks(test_parameters, model, target_point, nested_target_type, nested_target_node_name, fn) + + @staticmethod + def __add_fn_to_model(model, target_point, fn): + layout = TransformationLayout() + command = PTInsertionCommand(target_point, fn) + layout.register(command) + model_transformer = factory.ModelTransformerFactory.create(model) + model = model_transformer.transform(layout) + model.nncf.rebuild_graph() + return model + + @classmethod + def __get_statistic_points( + cls, test_parameters: MinMaxTestParameters, model, quantizer_config, dataset_samples, inplace_statistics + ) -> StatisticPointsContainer: statistics_points = StatisticPointsContainer() - for target_point in [test_parameters.target_type]: - target_point = self.get_target_point(target_point) - algorithm_name = "TestAlgo" - statistic_point = self.create_statistics_point( + for target_type in [test_parameters.target_type]: + target_point = cls.get_target_point(target_type) + statistic_point = cls.create_statistics_point( model, quantizer_config, target_point, len(dataset_samples), - algorithm_name, + "TEST_ALGO", inplace_statistics, test_parameters.range_estimator_params, ) statistics_points.add_statistic_point(statistic_point) + return statistics_points + def __collect_statistics_get_collector( + self, + statistics_points: StatisticPointsContainer, + model, + dataset_samples, + ): dataset = self.get_dataset(dataset_samples) statistics_aggregator = self.get_statistics_aggregator(dataset) statistics_aggregator.register_statistic_points(statistics_points) graph = NNCFGraphFactory.create(model) statistics_aggregator.collect_statistics(model, graph) - def filter_func(point): - return ( - algorithm_name in point.algorithm_to_tensor_collectors and point.target_point.type == target_point.type - ) - - tensor_collectors = list( - statistics_points.get_algo_statistics_for_node(target_point.target_node_name, filter_func, algorithm_name) - ) + tensor_collectors = list(statistics_points.get_tensor_collectors()) assert len(tensor_collectors) == 1 - return tensor_collectors[0] + return tensor_collectors[0][2] @staticmethod def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale): @@ -272,7 +372,20 @@ def __check_collector(test_parameters, tensor_collector, stat_in_shape_of_scale) assert stat.max_values.shape == ref_shape @staticmethod - def __check_hooks(test_parameters, model, target_point, fn): + def __check_successive_hooks(test_parameters, model, target_point, fn): + checker = HookChecker(model, "conv") + checker.add_ref( + ref_hooks=[fn], + target_type=test_parameters.target_type, + target_node_name=target_point.target_node_name, + input_port_id=0, + ) + checker.check_with_reference() + + @staticmethod + def __check_nested_hooks( + test_parameters, model, target_point, nested_target_type: TargetType, nested_target_node_name: str, fn + ): checker = HookChecker(model, "conv") checker.add_ref( ref_hooks=[fn], @@ -280,4 +393,10 @@ def __check_hooks(test_parameters, model, target_point, fn): target_node_name=target_point.target_node_name, input_port_id=0, ) + checker.add_ref( + ref_hooks=[fn], + target_type=nested_target_type, + target_node_name=nested_target_node_name, + input_port_id=0, + ) checker.check_with_reference()