Skip to content

Commit

Permalink
WIP temporary insertion command
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Jan 22, 2024
1 parent 3a4ef0e commit 3458217
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 22 deletions.
32 changes: 23 additions & 9 deletions nncf/torch/dynamic_graph/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

import threading
import weakref
from collections import OrderedDict
from collections import defaultdict
from collections import deque
from contextlib import contextmanager
from typing import Callable, DefaultDict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Union

import torch

Expand Down Expand Up @@ -94,10 +95,11 @@ class TracingContext:
def __init__(self):
self.graph = DynamicGraph()

self._post_hooks: DefaultDict[OperationAddress, List[Callable]] = defaultdict(list)
self._pre_hooks: DefaultDict[PreHookId, List[Callable]] = defaultdict(list)
self._post_hooks = defaultdict(OrderedDict)
self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict)
self._num_nested_hooks = 0
self.reused_parameters = []
self._hooks_counter = 0

self._threading = CopySafeThreadingVars()

Expand Down Expand Up @@ -282,9 +284,16 @@ def pop_scope(self):
self.relative_scopes_stack.pop()
self.module_call_stack.pop()

def register_pre_hooks(self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int):
def register_pre_hooks(
self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int
) -> List[int]:
pre_hook_id = PreHookId(op_address, input_port_id)
self._pre_hooks[pre_hook_id].extend(fn_list)
hooks_ids = []
for fn in fn_list:
self._hooks_counter += 1
self._pre_hooks[pre_hook_id][self._hooks_counter] = fn
hooks_ids.append(self._hooks_counter)
return pre_hook_id, hooks_ids

def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInput) -> OperatorInput:
in_op = getattr(self, "in_operator", False)
Expand All @@ -296,21 +305,26 @@ def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInp
for pre_hook_id in pre_hook_ids_for_curr_op:
hook_list_for_current_input_port = self._pre_hooks[pre_hook_id]
input_arg_to_process = pre_hook_id.input_port_id
for hook in hook_list_for_current_input_port:
for hook in hook_list_for_current_input_port.values():
op_inputs[input_arg_to_process] = hook(op_inputs[input_arg_to_process])
self._threading.thread_local.num_nested_hooks -= 1
self.in_operator = in_op
return op_inputs

def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress):
self._post_hooks[op_address].extend(fn_list)
def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress) -> List[int]:
hooks_ids = []
for fn in fn_list:
self._hooks_counter += 1
self._post_hooks[op_address][self._hooks_counter] = fn
hooks_ids.append(self._hooks_counter)
return op_address, hooks_ids

def execute_post_hooks(self, op_address: OperationAddress, outputs):
in_op = getattr(self, "in_operator", False)
self.in_operator = False
self._threading.thread_local.num_nested_hooks += 1
if op_address in self._post_hooks:
for hook in self._post_hooks[op_address]:
for hook in self._post_hooks[op_address].values():
outputs = hook(outputs)
self._threading.thread_local.num_nested_hooks -= 1
self.in_operator = in_op
Expand Down
29 changes: 29 additions & 0 deletions nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,35 @@ def requires_graph_rebuild(self):
return self.priority == TransformationPriority.QUANTIZATION_PRIORITY


class PTInsertionTemporaryCommand(PTTransformationCommand):
"""
Insertion operation to the models.
"""

def __init__(
self,
point: PTTargetPoint,
fn: Callable,
priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT, point)
self.fn: Callable = fn
self.priority: TransformationPriority = priority

def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand":
# TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead
raise NotImplementedError()

def requires_graph_rebuild(self):
"""
Return boolean flag to rebuild graph of model.
:return: Boolean flag.
"""
# Rebuild graph when adding quantization nodes or an op.
return False


class PTSharedFnInsertionCommand(PTTransformationCommand):
def __init__(
self,
Expand Down
34 changes: 31 additions & 3 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import copy
from collections import defaultdict
from typing import Callable, Dict, List, Tuple
from typing import Callable, Dict, Iterator, List, Tuple

from torch import Tensor
from torch import nn
Expand All @@ -24,6 +24,7 @@
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
Expand Down Expand Up @@ -52,6 +53,7 @@ def __init__(self, model: NNCFNetwork):
self._command_transformation_ordered_pairs = [
(PTModelExtractionWithFusedBiasCommand, self._apply_extraction_with_fused_bias_transformations),
(PTInsertionCommand, self._apply_insertion_transformations),
(PTInsertionTemporaryCommand, self._apply_temporary_insertion_transformation),
(PTQuantizerInsertionCommand, self._apply_quantizer_insertion_transformations),
(PTBiasCorrectionCommand, self._apply_bias_correction_transformations),
(PTSharedFnInsertionCommand, self._apply_shared_nodes_insertion),
Expand Down Expand Up @@ -82,6 +84,33 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P
"""
Applies insertion transformations to the model.
:param model: Model to apply transformations.
:param transformations: List of the bias correction transformations.
"""
for insert_args in PTModelTransformer._get_nncf_network_insert_arguments(model, transformations):
model.nncf.insert_at_point(*insert_args)
return model

@staticmethod
def _apply_temporary_insertion_transformation(
model: NNCFNetwork, transformations: List[PTInsertionCommand]
) -> NNCFNetwork:
"""
Applies temporary insertion transformations to the model.
:param model: Model to apply transformations.
:param transformations: List of the bias correction transformations.
"""
for insert_args in PTModelTransformer._get_nncf_network_insert_arguments(model, transformations):
model.nncf.temporary_insert_at_point(*insert_args)
return model

def _get_nncf_network_insert_arguments(
model: NNCFNetwork, transformations: List[PTInsertionCommand]
) -> Iterator[Tuple[PTInsertionPoint, List[Callable]]]:
"""
Applies insertion transformations to the model.
:param model: Model to apply transformations.
:param transformations: List of the bias correction transformations.
"""
Expand All @@ -107,8 +136,7 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P

for pt_ip, fn_list_with_priority in fns_grouped_by_points.items():
fn_list_with_priority = sorted(fn_list_with_priority, key=lambda x: x[1])
model.nncf.insert_at_point(pt_ip, [x[0] for x in fn_list_with_priority])
return model
yield (pt_ip, [x[0] for x in fn_list_with_priority])

@staticmethod
def _apply_shared_nodes_insertion(
Expand Down
48 changes: 43 additions & 5 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from enum import Enum
from enum import IntEnum
from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

import torch
from torch import nn
Expand Down Expand Up @@ -244,6 +244,7 @@ def __init__(
self._target_scopes = target_scopes
self._user_dummy_forward_fn = dummy_forward_fn
self._kd_loss_handler = None
self._temprorary_hooks_adresses = []

if wrap_inputs_fn is not None:
self._wrap_inputs_fn = wrap_inputs_fn
Expand Down Expand Up @@ -420,11 +421,36 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc
def update_model_ref(self, model: torch.nn.Module) -> None:
object.__setattr__(self, "__model_ref", model)

def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
hook_addresses = self.insert_at_point(point, fn_list)
self._temprorary_hooks_adresses.append(hook_addresses)
return hook_addresses

def remove_temporary_ops(self):
for point, hook_address, hook_ids in self._temprorary_hooks_adresses:
for hook_idx in hook_ids:
if point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK:
hooks = self._compressed_context._pre_hooks[hook_address]
elif point.insertion_type == PTInsertionType.OPERATOR_POST_HOOK:
hooks = self._compressed_context._post_hooks[hook_address]
else:
nncf_module = self.get_module_by_scope(point.module_scope)
if point.insertion_type == PTInsertionType.NNCF_MODULE_PRE_OP:
hooks = nncf_module.pre_ops
else:
hooks = nncf_module.post_ops

del hooks[hook_idx]
self._temprorary_hooks_adresses.clear()

def insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
hooks_ids = None
if point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK:
self._compressed_context.register_pre_hooks(fn_list, point.op_address, point.input_port_id)
hook_address, hooks_ids = self._compressed_context.register_pre_hooks(
fn_list, point.op_address, point.input_port_id
)
elif point.insertion_type == PTInsertionType.OPERATOR_POST_HOOK:
self._compressed_context.register_post_hooks(fn_list, point.op_address)
hook_address, hooks_ids = self._compressed_context.register_post_hooks(fn_list, point.op_address)
elif point.insertion_type in [PTInsertionType.NNCF_MODULE_PRE_OP, PTInsertionType.NNCF_MODULE_POST_OP]:
nncf_module = self.get_module_by_scope(point.module_scope)
if not isinstance(nncf_module, _NNCFModuleMixin):
Expand All @@ -442,14 +468,18 @@ def insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
for scope_list_for_module in self.get_nncf_module_scopes():
norm_nncf_scopes.extend([self._normalize_variable_recurrent_scope(x) for x in scope_list_for_module])
assert norm_target_scope in norm_nncf_scopes # Required for proper Recurrent/VariableRecurrent addressing
hooks_ids = []
if point.insertion_type == PTInsertionType.NNCF_MODULE_PRE_OP:
for fn in fn_list:
nncf_module.register_pre_forward_operation(fn)
hooks_ids.append(nncf_module.register_pre_forward_operation(fn))
elif point.insertion_type == PTInsertionType.NNCF_MODULE_POST_OP:
for fn in fn_list:
nncf_module.register_post_forward_operation(fn)
hooks_ids.append(nncf_module.register_post_forward_operation(fn))
hook_address = None
else:
raise RuntimeError("Unsupported insertion type: {}".format(point.insertion_type))
hook_addresses = (point, hook_address, hooks_ids)
return hook_addresses

def get_graph(self) -> PTNNCFGraph:
if self._compressed_context.graph.get_nodes_count() == 0 or self._compressed_graphs_pair.nncf_graph is None:
Expand Down Expand Up @@ -819,6 +849,14 @@ def get_reused_parameters(self):
return ret


class TemporaryOp:
def __init__(self, op: Callable) -> None:
self._op = op

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._op(*args, **kwargs)


class NNCFNetworkMeta(type):
"""
Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be
Expand Down
10 changes: 6 additions & 4 deletions nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer
from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.tensor import PTNNCFTensor
from nncf.torch.tensor_statistics.algo import create_register_input_hook
Expand Down Expand Up @@ -60,8 +60,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
class PTStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
with ModelView(model) as intermediate_model:
super().collect_statistics(intermediate_model, graph)
# with ModelView(model) as intermediate_model:
# super().collect_statistics(intermediate_model, graph)
super().collect_statistics(model, graph)
model.nncf.remove_temporary_ops()

def _register_statistics(
self, outputs: Dict[str, PTNNCFTensor], statistic_points: StatisticPointsContainer
Expand All @@ -79,7 +81,7 @@ def _get_transformation_layout_extra_outputs(
for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
for collector in collectors:
transformation_commands.append(
PTInsertionCommand(
PTInsertionTemporaryCommand(
_statistic_point.target_point,
create_register_input_hook(collector=collector),
TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def _check_weight_update_hooks(ref_hooks):
def _check_pre_post_hooks(hooks, ref_hooks):
assert len(hooks) == len(ref_hooks)
for op_address, ref_hooks in ref_hooks.items():
actual_hooks = hooks[op_address]
actual_hooks = hooks[op_address].values()
assert len(actual_hooks) == len(ref_hooks)
for actual_hook, ref_hook in zip(actual_hooks, ref_hooks):
assert actual_hook is ref_hook

0 comments on commit 3458217

Please sign in to comment.