diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py
index a5d41d4c272..2db2d776dd1 100644
--- a/nncf/torch/dynamic_graph/context.py
+++ b/nncf/torch/dynamic_graph/context.py
@@ -95,7 +95,7 @@ def __init__(self):
         self._post_hooks = defaultdict(OrderedDict)
         self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict)
-        self._hooks_counter = 0
+        self._hooks_counter = -1
 
         self._threading = CopySafeThreadingVars()
diff --git a/nncf/torch/external_hook.py b/nncf/torch/external_hook.py
new file mode 100644
index 00000000000..2dad7bd19d8
--- /dev/null
+++ b/nncf/torch/external_hook.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2023 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from nncf.torch.dynamic_graph.context import TracingContext
+
+EXTERNAL_OP_STORAGE_NAME = "external_op"
+
+
+class ExternalOpCallHook:
+    """
+    Hook that calls an operation registered in the NNCFInterface under the
+    given storage name and storage key. The target operation should be
+    registered before the ExternalOpCallHook is called.
+    The hook module cannot be registered as a callable hook directly,
+    since a thread-local version of the module should be used during
+    the base module execution.
+    """
+
+    def __init__(self, storage_name: str, context: TracingContext, storage_key: str):
+        """
+        :param storage_name: Attribute name of a model NNCFInterface.
+        :param context: Current tracing context.
+        :param storage_key: Key to retrieve the callable hook.
+        """
+        self._storage_name = storage_name
+        self._compressed_context = context
+        self._storage_key = storage_key
+
+    def __call__(self, *args: Any, **kwargs) -> Any:
+        replica = self._compressed_context.base_module_thread_local_replica
+        storage = getattr(replica.nncf, self._storage_name)
+        return storage[self._storage_key](*args, **kwargs)
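For illustration, the indirection above can be exercised directly. A minimal sketch, assuming an NNCFNetwork instance named model (the storage key "my_op" is hypothetical; register_compression_module_type, add_compression_module and get_tracing_context are the same calls this patch uses in model_transformer.py and in the tests):

    import torch

    from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
    from nncf.torch.external_hook import ExternalOpCallHook
    from nncf.torch.nncf_network import ExtraCompressionModuleType

    # Register the shared module once on the model...
    model.nncf.register_compression_module_type(ExtraCompressionModuleType.EXTERNAL_OP)
    model.nncf.add_compression_module("my_op", torch.nn.Identity(), ExtraCompressionModuleType.EXTERNAL_OP)

    # ...and let the hook resolve it lazily. The lookup goes through the
    # thread-local model replica, so the hook stays valid for the model
    # copies NNCF uses during tracing.
    hook = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), "my_op")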
diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py
index 3273a35c4a5..929e3498a6f 100644
--- a/nncf/torch/graph/transformations/command_creation.py
+++ b/nncf/torch/graph/transformations/command_creation.py
@@ -16,7 +16,6 @@
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import TargetType
-from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
 from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
@@ -60,9 +59,10 @@ def forward(self, x):
 def multiply_insertion_command(
     target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
 ) -> PTInsertionCommand:
-    commands = []
+    target_points = []
     for target_node in target_nodes:
-        target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
-        commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))
+        target_points.append(
+            PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
+        )
 
-    return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
+    return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py
index 1359fe2941d..fc64c1a3937 100644
--- a/nncf/torch/graph/transformations/commands.py
+++ b/nncf/torch/graph/transformations/commands.py
@@ -193,14 +193,16 @@ def requires_graph_rebuild(self):
 class PTSharedFnInsertionCommand(PTTransformationCommand):
     def __init__(
         self,
-        target_commands: List[PTInsertionCommand],
+        target_points: List[PTTargetPoint],
         fn: Callable,
         op_unique_name: str,
+        priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
     ):
         super().__init__(TransformationType.INSERT, None)
-        self.target_commands = target_commands
+        self.target_points = target_points
         self.fn = fn
         self.op_name = op_unique_name
+        self.priority = priority
 
     def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand":
         # TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead
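A minimal sketch of the new call convention (the node name and the shared module are illustrative):

    import torch

    from nncf.common.graph.transformations.commands import TargetType
    from nncf.common.graph.transformations.commands import TransformationPriority
    from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
    from nncf.torch.graph.transformations.commands import PTTargetPoint

    target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, "Model/linear_0", input_port_id=0)
    command = PTSharedFnInsertionCommand(
        target_points=[target_point],
        fn=torch.nn.Identity(),  # one module shared between all target points
        op_unique_name="shared_identity",
        priority=TransformationPriority.OP_INSERTION_PRIORITY,
    )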
diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py
index b57ca6cac58..afd63e2cf73 100644
--- a/nncf/torch/model_transformer.py
+++ b/nncf/torch/model_transformer.py
@@ -21,6 +21,8 @@
 from nncf.common.graph.transformations.commands import TargetType
 from nncf.common.graph.transformations.commands import TransformationPriority
 from nncf.common.quantization.structs import NonWeightQuantizerId
+from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
+from nncf.torch.external_hook import ExternalOpCallHook
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand
@@ -35,8 +37,6 @@
 from nncf.torch.nncf_network import ExtraCompressionModuleType
 from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.nncf_network import PTInsertionPoint
-from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
-from nncf.torch.quantization.external_quantizer import ExternalOpCallHook
 from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
 from nncf.torch.utils import get_model_device
 from nncf.torch.utils import is_multidevice
@@ -149,15 +149,15 @@ def _apply_shared_nodes_insertion(
         insertion_commands: List[PTInsertionCommand] = []
 
-        for command in transformations:
+        for shared_command in transformations:
             op_id = (
-                command.op_name + f"[{';'.join([tp.target_point.target_node_name for tp in command.target_commands])}]"
+                shared_command.op_name + f"[{';'.join([tp.target_node_name for tp in shared_command.target_points])}]"
             )
-            model.nncf.add_compression_module(op_id, command.fn, compression_model_type)
+            model.nncf.add_compression_module(op_id, shared_command.fn, compression_model_type)
 
-            for command in command.target_commands:
-                command.fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id)
-                insertion_commands.append(command)
+            for target_point in shared_command.target_points:
+                fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id)
+                insertion_commands.append(PTInsertionCommand(target_point, fn, priority=shared_command.priority))
 
         return PTModelTransformer._apply_insertion_transformations(model, insertion_commands)
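One shared-fn command thus expands into a single registered module plus one insertion command per target point, all pointing at the same storage key. The key itself is just the op name plus the joined target node names; with hypothetical values:

    op_name = "UNIQUE_NAME"
    node_names = ["/nncf_model_input_0", "Model/linear_0"]
    op_id = op_name + f"[{';'.join(node_names)}]"
    assert op_id == "UNIQUE_NAME[/nncf_model_input_0;Model/linear_0]"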
diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py
index c2d2bbb7e31..cdaadd075d6 100644
--- a/nncf/torch/nncf_network.py
+++ b/nncf/torch/nncf_network.py
@@ -17,7 +17,7 @@
 from copy import deepcopy
 from enum import Enum
 from enum import IntEnum
-from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
 
 import torch
 from torch import nn
@@ -53,6 +53,7 @@
 from nncf.torch.dynamic_graph.scope_access import get_module_by_scope
 from nncf.torch.dynamic_graph.trace_tensor import strip_traced_tensors
 from nncf.torch.dynamic_graph.wrappers import wrap_module_call
+from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
 from nncf.torch.graph.graph import PTNNCFGraph
 from nncf.torch.graph.graph_builder import GraphBuilder
 from nncf.torch.graph.graph_builder import GraphConverter
@@ -62,7 +63,6 @@
 from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
 from nncf.torch.layer_utils import _NNCFModuleMixin
 from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
-from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
 from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
 from nncf.torch.utils import compute_FLOPs_hook
 from nncf.torch.utils import get_all_modules_by_type
@@ -345,25 +345,6 @@ def reset_nncf_modules(self):
             module = self.get_module_by_scope(some_scope)
             module.reset()
 
-    def get_shallow_copy(self) -> "NNCFNetwork":
-        from nncf.torch.utils import load_module_state
-        from nncf.torch.utils import save_module_state
-
-        saved_state = save_module_state(self._model_ref)
-        new_interface = NNCFNetworkInterface(
-            self._model_ref,
-            self._input_infos,
-            self._user_dummy_forward_fn,
-            self._wrap_inputs_fn,
-            self._scopes_without_shape_matching,
-            self._ignored_scopes,
-            self._target_scopes,
-            wrap_outputs_fn=self._wrap_outputs_fn,
-        )
-        self._model_ref._nncf = new_interface
-        load_module_state(self._model_ref, saved_state)
-        return self._model_ref
-
     def get_clean_shallow_copy(self) -> "NNCFNetwork":
         # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
         # and load_nncf_module_additions to preserve these, or temporary_clean_view().
@@ -802,14 +783,6 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork":
         return self.compression_controller.strip(do_copy)
 
 
-class TemporaryOp:
-    def __init__(self, op: Callable) -> None:
-        self._op = op
-
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        return self._op(*args, **kwargs)
-
-
 class NNCFNetworkMeta(type):
     """
     Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be
diff --git a/nncf/torch/quantization/external_quantizer.py b/nncf/torch/quantization/external_quantizer.py
index 08848f1bfaf..4f5e69a4b81 100644
--- a/nncf/torch/quantization/external_quantizer.py
+++ b/nncf/torch/quantization/external_quantizer.py
@@ -9,32 +9,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any
-
 from nncf.torch.dynamic_graph.context import TracingContext
+from nncf.torch.external_hook import ExternalOpCallHook
 from nncf.torch.quantization.debug_interface import QuantizationDebugInterface
 
 EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
-EXTERNAL_OP_STORAGE_NAME = "external_op"
 EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME
 
 
-class ExternalOpCallHook:
-    def __init__(self, storage_name, context, storage_key):
-        self._storage_name = storage_name
-        self._compressed_context = context
-        self._storage_key = storage_key
-
-    def __call__(self, *args: Any, **kwargs) -> Any:
-        replica = self._compressed_context.base_module_thread_local_replica
-        storage = getattr(replica.nncf, self._storage_name)
-        return storage[self._storage_key](*args, **kwargs)
-
-
 class ExternalQuantizerCallHook(ExternalOpCallHook):
     """
-    Cannot simply register the quantizer module as a callable hook, since we need to call
-    a thread-local version of the quantizer module during base module execution.
+    An external hook that uses the quantizer storage name and
+    can optionally be constructed with a debug interface.
     """
 
     def __init__(
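The rest of ExternalQuantizerCallHook.__init__ falls outside this hunk; structurally, a subclass of ExternalOpCallHook of this kind can look like the following sketch (class and parameter names are illustrative, not the actual signature):

    class QuantizerCallHookSketch(ExternalOpCallHook):
        def __init__(self, context: TracingContext, storage_key: str,
                     debug_interface: QuantizationDebugInterface = None):
            # Pin the storage name; optionally keep a debug interface around.
            super().__init__(EXTERNAL_QUANTIZERS_STORAGE_NAME, context, storage_key)
            self.debug_interface = debug_interface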
+ """ + self._ref_hooks.clear() + @staticmethod def _check_weight_update_hooks(ref_hooks): for target_module, ref_hooks_per_module in ref_hooks.items(): diff --git a/tests/torch/ptq/test_smooth_quant.py b/tests/torch/ptq/test_smooth_quant.py index 7f723167e3b..2c9da2c031e 100644 --- a/tests/torch/ptq/test_smooth_quant.py +++ b/tests/torch/ptq/test_smooth_quant.py @@ -58,7 +58,7 @@ def get_node_name_map(self) -> Dict[str, str]: @staticmethod def get_target_node_name(command: TransformationCommand): if isinstance(command, PTSharedFnInsertionCommand): - return command.target_commands[0].target_point.target_node_name + return command.target_points[0].target_node_name return command.target_point.target_node_name @staticmethod diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py index a9ff2333b74..a642017cea0 100644 --- a/tests/torch/test_model_transformer.py +++ b/tests/torch/test_model_transformer.py @@ -41,6 +41,8 @@ from nncf.torch.dynamic_graph.io_handling import FillerInputInfo from nncf.torch.dynamic_graph.operation_address import OperationAddress from nncf.torch.dynamic_graph.patch_pytorch import register_operator +from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME +from nncf.torch.external_hook import ExternalOpCallHook from nncf.torch.graph.operator_metatypes import PTConv2dMetatype from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype @@ -48,9 +50,12 @@ from nncf.torch.graph.operator_metatypes import PTReshapeMetatype from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.layers import NNCFConv2d from nncf.torch.layers import register_module @@ -145,8 +150,11 @@ def setup(self): point_for_relu_inputs, ] + @pytest.mark.parametrize( + "insert_method_name,check_tmp_ops", [("insert_at_point", False), ("temporary_insert_at_point", True)] + ) @pytest.mark.parametrize("target_point", available_points) - def test_single_insertions(self, setup, target_point: PTTargetPoint): + def test_single_insertions(self, setup, target_point: PTTargetPoint, insert_method_name: str, check_tmp_ops: bool): insertion_point = PTInsertionPoint( target_point.target_type, OperationAddress.from_str(target_point.target_node_name), @@ -157,7 +165,8 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint): else: hook = BaseOp(lambda x: x) - self.compressed_model.nncf.insert_at_point(insertion_point, [hook]) + insert_at_point_method = getattr(self.compressed_model.nncf, insert_method_name) + insert_at_point_method(insertion_point, [hook]) if insertion_point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK: ctx = self.compressed_model.nncf.get_tracing_context() @@ -174,6 +183,9 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint): module = 
diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py
index a9ff2333b74..a642017cea0 100644
--- a/tests/torch/test_model_transformer.py
+++ b/tests/torch/test_model_transformer.py
@@ -41,6 +41,8 @@
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
 from nncf.torch.dynamic_graph.patch_pytorch import register_operator
+from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
+from nncf.torch.external_hook import ExternalOpCallHook
 from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
@@ -48,9 +50,12 @@
 from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
 from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
 from nncf.torch.graph.transformations.commands import PTInsertionCommand
+from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand
 from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
 from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
+from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
 from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
 from nncf.torch.layers import NNCFConv2d
 from nncf.torch.layers import register_module
@@ -145,8 +150,11 @@ def setup(self):
             point_for_relu_inputs,
         ]
 
+    @pytest.mark.parametrize(
+        "insert_method_name,check_tmp_ops", [("insert_at_point", False), ("temporary_insert_at_point", True)]
+    )
     @pytest.mark.parametrize("target_point", available_points)
-    def test_single_insertions(self, setup, target_point: PTTargetPoint):
+    def test_single_insertions(self, setup, target_point: PTTargetPoint, insert_method_name: str, check_tmp_ops: bool):
         insertion_point = PTInsertionPoint(
             target_point.target_type,
             OperationAddress.from_str(target_point.target_node_name),
@@ -157,7 +165,8 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint):
         else:
             hook = BaseOp(lambda x: x)
 
-        self.compressed_model.nncf.insert_at_point(insertion_point, [hook])
+        insert_at_point_method = getattr(self.compressed_model.nncf, insert_method_name)
+        insert_at_point_method(insertion_point, [hook])
 
         if insertion_point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK:
             ctx = self.compressed_model.nncf.get_tracing_context()
@@ -174,6 +183,9 @@ def test_single_insertions(self, setup, target_point: PTTargetPoint):
             module = self.compressed_model.nncf.get_module_by_scope(insertion_point.module_scope)
             assert module.post_ops["0"] is hook
 
+        if check_tmp_ops:
+            assert len(self.compressed_model.nncf._temprorary_hooks_adresses) == 1
+
     priority_types = ["same", "different"]
     insertion_types = TargetType
     priority_test_cases = list(itertools.product(priority_types, insertion_types))
@@ -183,8 +195,9 @@ def check_order(iterable1: List, iterable2: List, ordering: List):
         for idx, order in enumerate(ordering):
             assert iterable1[idx] is iterable2[order]
 
+    @pytest.mark.parametrize("command_cls", [PTInsertionCommand, PTInsertionTemporaryCommand])
     @pytest.mark.parametrize("case", priority_test_cases, ids=[x[1].name + "-" + x[0] for x in priority_test_cases])
-    def test_priority(self, case, setup):
+    def test_priority(self, case, command_cls, setup):
         priority_type = case[0]
         insertion_type = case[1]
@@ -214,14 +227,14 @@ def test_priority(self, case, setup):
 
         if priority_type == "same":
             # Same-priority commands will be executed in registration order
-            command1 = PTInsertionCommand(point, hook1, TransformationPriority.DEFAULT_PRIORITY)
-            command2 = PTInsertionCommand(point, hook2, TransformationPriority.DEFAULT_PRIORITY)
-            command3 = PTInsertionCommand(point, hook3, TransformationPriority.DEFAULT_PRIORITY)
+            command1 = command_cls(point, hook1, TransformationPriority.DEFAULT_PRIORITY)
+            command2 = command_cls(point, hook2, TransformationPriority.DEFAULT_PRIORITY)
+            command3 = command_cls(point, hook3, TransformationPriority.DEFAULT_PRIORITY)
         else:
             # Prioritized commands will be executed in ascending priority order
-            command1 = PTInsertionCommand(point, hook1, TransformationPriority.SPARSIFICATION_PRIORITY)
-            command2 = PTInsertionCommand(point, hook2, TransformationPriority.QUANTIZATION_PRIORITY)
-            command3 = PTInsertionCommand(point, hook3, TransformationPriority.DEFAULT_PRIORITY)
+            command1 = command_cls(point, hook1, TransformationPriority.SPARSIFICATION_PRIORITY)
+            command2 = command_cls(point, hook2, TransformationPriority.QUANTIZATION_PRIORITY)
+            command3 = command_cls(point, hook3, TransformationPriority.DEFAULT_PRIORITY)
 
         layout = PTTransformationLayout()
         layout.register(command1)
@@ -241,10 +254,12 @@ def test_priority(self, case, setup):
             pre_hook_id = PreHookId(
                 OperationAddress.from_str(point.target_node_name), input_port_id=point.input_port_id
             )
-            self.check_order(ctx._pre_hooks[pre_hook_id], hook_list, order)
+            actual_pre_hooks = list(ctx._pre_hooks[pre_hook_id].values())
+            self.check_order(actual_pre_hooks, hook_list, order)
         if insertion_type == TargetType.OPERATOR_POST_HOOK:
             ctx = self.compressed_model.nncf.get_tracing_context()
-            self.check_order(ctx._post_hooks[OperationAddress.from_str(point.target_node_name)], hook_list, order)
+            actual_post_hooks = list(ctx._post_hooks[OperationAddress.from_str(point.target_node_name)].values())
+            self.check_order(actual_post_hooks, hook_list, order)
 
         if insertion_type == TargetType.OPERATION_WITH_WEIGHTS:
             module = self.compressed_model.nncf.get_containing_module(point.target_node_name)
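The switch to list(... .values()) above reflects that the pre- and post-hook storages of the tracing context are ordered mappings keyed by a registration counter, which also explains the _hooks_counter = -1 change in context.py (presumably the counter is incremented before use, so the first hook gets key 0). A self-contained sketch of that assumed shape:

    from collections import OrderedDict

    pre_hooks = OrderedDict()  # illustrative stand-in for ctx._pre_hooks[pre_hook_id]
    hooks_counter = -1
    for hook in ("hook_a", "hook_b"):
        hooks_counter += 1  # pre-increment: the first key is 0
        pre_hooks[hooks_counter] = hook

    assert list(pre_hooks.values()) == ["hook_a", "hook_b"]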
@@ -475,19 +490,22 @@ def test_extraction_with_fused_bias_transformations():
     assert isinstance(extracted_model[0], NNCFConv2d)
 
 
-def test_bias_correction_transformations():
+@pytest.mark.parametrize(
+    "command_cls,attr_name,new_value",
+    [(PTBiasCorrectionCommand, "bias", torch.tensor([42.0])), (PTWeightUpdateCommand, "weight", torch.tensor([42.0]))],
+)
+def test_correction_transformations(command_cls, attr_name, new_value):
     model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
     model_transformer = PTModelTransformer(model)
 
-    new_bias = torch.Tensor([42])
     target_point = PTTargetPoint(TargetType.LAYER, "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0")
-    command = PTBiasCorrectionCommand(target_point, new_bias)
+    command = command_cls(target_point, new_value)
 
     transformation_layout = PTTransformationLayout()
     transformation_layout.register(command)
 
     updated_model = model_transformer.transform(transformation_layout)
-    assert updated_model.conv1.bias.data == new_bias
+    param = getattr(updated_model.conv1, attr_name)
+    assert param.data == new_value
 
 
 def test_rebuild_graph_after_insert_transformation():
@@ -550,25 +568,90 @@ def test_quantizer_insertion_transformations(target_type, node_name, input_port_
 
 
 @pytest.mark.parametrize(
-    "target_type, node_name, input_port_id",
-    (
-        (
+    "priority", [TransformationPriority.OP_INSERTION_PRIORITY, TransformationPriority.DEFAULT_PRIORITY]
+)
+@pytest.mark.parametrize("compression_module_registered", [False, True])
+def test_shared_fn_insertion_point(priority, compression_module_registered, mocker):
+    model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
+
+    class Hook(torch.nn.Module):
+        def forward(self, x):
+            return x
+
+    tps = [
+        PTTargetPoint(
             TargetType.OPERATOR_POST_HOOK,
             "/nncf_model_input_0",
-            None,
         ),
-        (
+        PTTargetPoint(
             TargetType.OPERATOR_PRE_HOOK,
             "InsertionPointTestModel/linear_0",
-            0,
+            input_port_id=0,
         ),
-        (
+        PTTargetPoint(
             TargetType.OPERATION_WITH_WEIGHTS,
             "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
-            None,
         ),
+    ]
+    OP_UNIQUE_NAME = "UNIQUE_NAME"
+    if compression_module_registered:
+        model.nncf.register_compression_module_type(ExtraCompressionModuleType.EXTERNAL_OP)
+    hook_instance = Hook()
+    command = PTSharedFnInsertionCommand(tps, hook_instance, OP_UNIQUE_NAME, priority)
+    transformation_layout = PTTransformationLayout()
+    transformation_layout.register(command)
+
+    mocker.MagicMock()
+    mocker.patch(
+        "nncf.torch.model_transformer.PTModelTransformer._apply_insertion_transformations",
+        return_value=mocker.MagicMock(),
+    )
+    model_transformer = PTModelTransformer(model)
+    _ = model_transformer.transform(transformation_layout=transformation_layout)
+
+    assert model.nncf.is_compression_module_registered(ExtraCompressionModuleType.EXTERNAL_OP)
+
+    REF_STORAGE_KEY = (
+        "UNIQUE_NAME[/nncf_model_input_0;InsertionPointTestModel/linear_0;"
+        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0]"
+    )
+
+    storage = getattr(model.nncf, EXTERNAL_OP_STORAGE_NAME)
+    assert storage[REF_STORAGE_KEY] is hook_instance
+
+    mock = PTModelTransformer._apply_insertion_transformations
+    mock.assert_called_once()
+
+    _, commands = mock.call_args.args
+    assert len(commands) == len(tps)
+    for command in commands:
+        assert command.target_point in tps
+        fn = command.fn
+        assert isinstance(fn, ExternalOpCallHook)
+        assert fn._storage_name == EXTERNAL_OP_STORAGE_NAME
+        assert fn._storage_key == REF_STORAGE_KEY
+
+
+INSERTION_POINT_TEST_MODEL_TARGET_POINTS = (
+    (
+        TargetType.OPERATOR_POST_HOOK,
+        "/nncf_model_input_0",
+        None,
+    ),
+    (
+        TargetType.OPERATOR_PRE_HOOK,
+        "InsertionPointTestModel/linear_0",
+        0,
+    ),
+    (
+        TargetType.OPERATION_WITH_WEIGHTS,
+        "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
+        None,
+    ),
 )
+
+
+@pytest.mark.parametrize("target_type, node_name, input_port_id", INSERTION_POINT_TEST_MODEL_TARGET_POINTS)
 def test_successive_insertion_transformation(target_type, node_name, input_port_id):
     model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
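With the new clear() method a single HookChecker can be reused across test steps; the test added below relies on exactly this pattern:

    checker = HookChecker(nncf_model, "conv")

    def _check(ref_hooks):
        checker.clear()  # drop previously recorded references
        checker.add_ref(ref_hooks, target_type, target_node_name, input_port_id)
        checker.check_with_reference()  # compare against the model's actual hooks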
diff --git a/tests/torch/test_nncf_network.py b/tests/torch/test_nncf_network.py
index 5b8f5b63a01..db991322103 100644
--- a/tests/torch/test_nncf_network.py
+++ b/tests/torch/test_nncf_network.py
@@ -46,6 +46,7 @@
 from nncf.torch.nncf_network import PTInsertionType
 from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config
 from tests.torch.helpers import BasicConvTestModel
+from tests.torch.helpers import HookChecker
 from tests.torch.helpers import TwoConvTestModel
 from tests.torch.helpers import check_correct_nncf_modules_replacement
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
@@ -847,3 +848,55 @@ def test_access_to_input_info():
     input_info = ExampleInputInfo.from_example_input(example_input)
     nncf_model = NNCFNetwork(model, input_info)
     nncf_model.nncf.input_infos
+
+
+@pytest.mark.parametrize(
+    "target_type, target_node_name, input_port_id",
+    [
+        (TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0),
+        (TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0),
+        (TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0),
+        (TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0),
+    ],
+)
+def test_temporary_insert_at_point(target_type, target_node_name, input_port_id):
+    class Hook(torch.nn.Module):
+        def forward(self, x):
+            return x
+
+    model = SimplestModel()
+    example_input = torch.ones(SimplestModel.INPUT_SIZE)
+    input_info = ExampleInputInfo.from_example_input(example_input)
+    nncf_model = NNCFNetwork(model, input_info)
+
+    node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping()
+    ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id)
+
+    checker = HookChecker(nncf_model, "conv")
+
+    def _check(ref_hooks_):
+        checker.clear()
+        checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id)
+        checker.check_with_reference()
+
+    permanent_hook = Hook()
+    # Make temporary hook a ref to the permanent hook
+    # to check tmp hooks are not removed by their id()
+    temporary_hook = permanent_hook
+    nncf_model.nncf.insert_at_point(ip, [permanent_hook])
+    ref_hooks = [permanent_hook]
+    _check(ref_hooks)
+
+    for _ in range(2):
+        temporary_hook = Hook()
+        nncf_model.nncf.temporary_insert_at_point(ip, [temporary_hook])
+        ref_hooks.append(temporary_hook)
+        _check(ref_hooks)
+
+        nncf_model.nncf.insert_at_point(ip, [permanent_hook])
+        ref_hooks.append(permanent_hook)
+        _check(ref_hooks)
+
+        nncf_model.nncf.remove_temporary_ops()
+        del ref_hooks[-2]
+        _check(ref_hooks)
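Taken together, the intended contract is that hooks registered via temporary_insert_at_point behave exactly like permanent ones until remove_temporary_ops is called, which removes them by the address they were registered at rather than by object identity. A condensed sketch, reusing nncf_model, ip and Hook from the test above:

    permanent = Hook()

    nncf_model.nncf.insert_at_point(ip, [permanent])            # kept
    nncf_model.nncf.temporary_insert_at_point(ip, [permanent])  # tracked for removal

    nncf_model.nncf.remove_temporary_ops()
    # Only the temporarily inserted entry is gone; the permanent registration
    # of the very same object is untouched.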