diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py
index 144f5d68238..3faff888f99 100644
--- a/nncf/torch/dynamic_graph/wrappers.py
+++ b/nncf/torch/dynamic_graph/wrappers.py
@@ -55,7 +55,7 @@ def wrap_operator(operator, operator_info: PatchedOperatorInfo):
     Wraps the input callable object (`operator`) with the functionality that allows the calls to this object
     to be tracked by the currently set global TracingContext. The wrapped functions can be then intercepted,
     their arguments and return values modified arbitrarily and, for functions that correspond to operations on
-    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.
+    tensors in a DNN, their general position and address in the DNN's model control flow graph can be established.

     :param: operator: A callable object to be wrapped.
     :param: operator_info (PatchedOperatorInfo): An informational struct containing the specifics of wrapping
diff --git a/nncf/torch/return_types.py b/nncf/torch/return_types.py
index b92bf4446b0..cc7b89d03ac 100644
--- a/nncf/torch/return_types.py
+++ b/nncf/torch/return_types.py
@@ -9,21 +9,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import inspect
 from typing import Any, Optional, Tuple, Type, Union

 import torch


-def __get_supported_torch_return_types() -> Tuple[Type[object]]:
+def __get_supported_torch_return_types() -> Tuple[Type[tuple], ...]:
     """
-    Collects types from torch.return_type which can be wrapped/unwrapped by nncf.
+    Collects types from torch.return_type which can be wrapped/unwrapped by NNCF.
+    NNCF can wrap/unwrap only return types that have two attributes, one of which
+    must be the `values` attribute.

-    :return: List of types from torch.return_type which can be wrapped/unwrapped by nncf.
+    :return: List of types from torch.return_type which can be wrapped/unwrapped by NNCF.
     """
-    return_type_names = [t for t in dir(torch.return_types) if not t.startswith("_") and not t.startswith("linalg")]
-    return_types = [getattr(torch.return_types, t_name) for t_name in return_type_names]
-    return_types = [t for t in return_types if hasattr(t, "values")]
-    return tuple(return_types)
+    return tuple(t for _, t in inspect.getmembers(torch.return_types) if inspect.isclass(t) and hasattr(t, "values"))


 _TORCH_RETURN_TYPES = __get_supported_torch_return_types()
@@ -31,7 +31,7 @@ def __get_supported_torch_return_types() -> Tuple[Type[object]]:

 def maybe_unwrap_from_torch_return_type(tensor: Any) -> torch.Tensor:
     """
-    Attempts to unwrap the tensor value from one of torch.return_types instantces
+    Attempts to unwrap the tensor value from one of torch.return_types instances
     in case torch operation output is wrapped by a torch return_type.

     :param tensor: Torch tensor or torch return type instance to unwrap values from.
@@ -52,5 +52,7 @@ def maybe_wrap_to_torch_return_type(tensor: torch.Tensor, wrapped_input: Optiona
     """
     if isinstance(wrapped_input, _TORCH_RETURN_TYPES):
-        return wrapped_input.__class__([tensor] + [arg for arg in wrapped_input[1:]])
+        # We assume that the return_type has only two attributes, the first one being `values`.
+        # This assumption is checked by `test_unwrap_wrap_torch_return_type`.
+        return wrapped_input.__class__((tensor, wrapped_input[1]))
     return tensor
diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py
index 51fdba11f50..7b2299f7a3b 100644
--- a/tests/torch/test_model_transformer.py
+++ b/tests/torch/test_model_transformer.py
@@ -39,6 +39,7 @@
 from nncf.torch.dynamic_graph.io_handling import FillerInputElement
 from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
 from nncf.torch.dynamic_graph.operation_address import OperationAddress
+from nncf.torch.dynamic_graph.patch_pytorch import register_operator
 from nncf.torch.graph.operator_metatypes import PTConv2dMetatype
 from nncf.torch.graph.operator_metatypes import PTInputNoopMetatype
 from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype
@@ -51,6 +52,7 @@
 from nncf.torch.graph.transformations.commands import PTTargetPoint
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
 from nncf.torch.layers import NNCFConv2d
+from nncf.torch.layers import register_module
 from nncf.torch.model_transformer import PTModelTransformer
 from nncf.torch.module_operations import BaseOp
 from nncf.torch.module_operations import UpdateWeight
@@ -596,3 +598,156 @@ def test_successive_insertion_transformation(target_type, node_name, input_port_
     assert len(hook_ops) == 2
     for hook_op, op in zip(hook_ops, ops):
         assert hook_op is op
+
+
+GLOBAL_LIST = []
+
+
+def get_dummy_op(op_id):
+    @register_operator()
+    def dummy_op(x):
+        GLOBAL_LIST.append(op_id)
+        return x
+
+    return dummy_op
+
+
+def get_model_to_test_nested_modules():
+    @register_module()
+    class DummyModule(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.weight = torch.nn.Parameter(torch.zeros((1,)))
+
+        def forward(self, x):
+            GLOBAL_LIST.append("DummyModule")
+            return x + self.weight
+
+    class TestModel(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+            self.op1 = get_dummy_op("op1")
+            self.m = DummyModule()
+            self.op2 = get_dummy_op("op2")
+
+        def forward(self, x):
+            x = self.op1(x)
+            x = self.m(x)
+            x = self.op2(x)
+            return x
+
+    return TestModel()
+
+
+@pytest.mark.parametrize(
+    "target_type, node_name, input_port_id, ref_hooks",
+    (
+        (
+            TargetType.OPERATOR_POST_HOOK,
+            "/nncf_model_input_0",
+            None,
+            (
+                [
+                    "pre_hook_1",
+                    "op1",
+                    "DummyModule",
+                    "op2",
+                ],
+                [
+                    "pre_hook_0",
+                    "pre_hook_1",
+                    "pre_hook_2",
+                    "op1",
+                    "DummyModule",
+                    "op2",
+                ],
+            ),
+        ),
+        (
+            TargetType.OPERATOR_PRE_HOOK,
+            "TestModel/dummy_op_1",
+            0,
+            (
+                [
+                    "op1",
+                    "DummyModule",
+                    "pre_hook_1",
+                    "op2",
+                ],
+                [
+                    "op1",
+                    "DummyModule",
+                    "pre_hook_0",
+                    "pre_hook_1",
+                    "pre_hook_2",
+                    "op2",
+                ],
+            ),
+        ),
+        (
+            TargetType.OPERATION_WITH_WEIGHTS,
+            "TestModel/NNCFUserDummyModule[m]/__add___0",
+            None,
+            (
+                [
+                    "op1",
+                    "pre_hook_1",
+                    "DummyModule",
+                    "op2",
+                ],
+                [
+                    "op1",
+                    "pre_hook_0",
+                    "pre_hook_1",
+                    "pre_hook_2",
+                    "DummyModule",
+                    "op2",
+                ],
+            ),
+        ),
+    )[-1:],
+)
+def test_nested_hooks(target_type, node_name, input_port_id, ref_hooks):
+    model = NNCFNetwork(get_model_to_test_nested_modules(), FillerInputInfo([FillerInputElement([10])]))
+
+    # Check that the test model works as expected
+    GLOBAL_LIST.clear()
+    model.nncf.rebuild_graph()
+    assert GLOBAL_LIST == [
+        "op1",
+        "DummyModule",
+        "op2",
+    ]
+
+    target_point = PTTargetPoint(target_type, node_name, input_port_id=input_port_id)
+    transformed_model = model
+
+    command = PTInsertionCommand(target_point, get_dummy_op("pre_hook_1"))
+
+    model_transformer = PTModelTransformer(transformed_model)
+    transformation_layout = PTTransformationLayout()
+    transformation_layout.register(command)
+    transformed_model = model_transformer.transform(transformation_layout)
+
+    GLOBAL_LIST.clear()
+    transformed_model.nncf.rebuild_graph()
+    assert GLOBAL_LIST == ref_hooks[0]
+
+    graph = transformed_model.nncf.get_graph()
+    target_node = graph.get_node_by_name(node_name)
+    if target_type == TargetType.OPERATOR_POST_HOOK:
+        target_node = graph.get_next_nodes(target_node)[0]
+    elif target_type == TargetType.OPERATOR_PRE_HOOK:
+        target_node = graph.get_previous_nodes(target_node)[0]
+    else:
+        target_node = graph.get_previous_nodes(target_node)[1]
+    transformation_layout = PTTransformationLayout()
+    for i, target_type_ in enumerate([TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATOR_POST_HOOK]):
+        target_point_on_hook = PTTargetPoint(target_type_, target_node.node_name, input_port_id=0)
+        transformation_layout.register(PTInsertionCommand(target_point_on_hook, get_dummy_op(f"pre_hook_{i * 2}")))
+    model_transformer = PTModelTransformer(transformed_model)
+    model_with_nested_hooks = model_transformer.transform(transformation_layout)
+
+    GLOBAL_LIST.clear()
+    model_with_nested_hooks.nncf.rebuild_graph()
+    assert GLOBAL_LIST == ref_hooks[1]
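
For reference, a minimal usage sketch of the wrap/unwrap round trip the return_types change relies on. It is not part of the patch: torch.return_types.max is only one example of a two-field return type (values, indices), and the assert lines are illustrative.

    import torch

    from nncf.torch.return_types import maybe_unwrap_from_torch_return_type
    from nncf.torch.return_types import maybe_wrap_to_torch_return_type

    # torch.max with a dim argument returns torch.return_types.max(values=..., indices=...)
    wrapped = torch.max(torch.tensor([[1.0, 2.0]]), dim=1)
    values = maybe_unwrap_from_torch_return_type(wrapped)  # the plain `values` tensor
    # Re-wrap a modified tensor; the second field (indices) is carried over unchanged
    rewrapped = maybe_wrap_to_torch_return_type(values * 2, wrapped)
    assert isinstance(rewrapped, type(wrapped))
    assert torch.equal(rewrapped.indices, wrapped.indices)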