Skip to content

Commit

Permalink
Test finishing
Browse files Browse the repository at this point in the history
Cleanup
  • Loading branch information
daniil-lyakhov committed Dec 8, 2023
1 parent a1ea42a commit 129b412
Show file tree
Hide file tree
Showing 20 changed files with 412 additions and 172 deletions.
5 changes: 1 addition & 4 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor:
:return: Stacked Tensor.
"""
if isinstance(x, List):
unwrapped_x = [i.data for i in x]
# singledispatch cannot dispatch function by element in a list
res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis)
return Tensor(res)
return Tensor(_dispatch_list(stack, x, axis=axis))
raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}")


Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self):

self._post_hooks = defaultdict(OrderedDict)
self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict)
self._hooks_counter = 0
self._hooks_counter = -1

self._threading = CopySafeThreadingVars()

Expand Down
42 changes: 42 additions & 0 deletions nncf/torch/external_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from nncf.torch.dynamic_graph.context import TracingContext

EXTERNAL_OP_STORAGE_NAME = "external_op"


class ExternalOpCallHook:
"""
Hook which is calling operation registered in the NNCFInterface
by given storage name and storage key. Target operation should be
registered before the ExternalOpCallHook call.
Hook module could not registered as a callable hook
since a thread-local version of the module should be used during
the base module execution.
"""

def __init__(self, storage_name: str, context: TracingContext, storage_key: str):
"""
:param storage_name: Attribute name of a model NNCFInterface.
:param context: Current tracing context.
:param storage_key: Key to retrieve callable hook
"""
self._storage_name = storage_name
self._compressed_context = context
self._storage_key = storage_key

def __call__(self, *args: Any, **kwargs) -> Any:
replica = self._compressed_context.base_module_thread_local_replica
storage = getattr(replica.nncf, self._storage_name)
return storage[self._storage_key](*args, **kwargs)
10 changes: 5 additions & 5 deletions nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
Expand Down Expand Up @@ -60,9 +59,10 @@ def forward(self, x):
def multiply_insertion_command(
target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
) -> PTInsertionCommand:
commands = []
target_points = []
for target_node in target_nodes:
target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))
target_points.append(
PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
)

return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
6 changes: 4 additions & 2 deletions nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,16 @@ def requires_graph_rebuild(self):
class PTSharedFnInsertionCommand(PTTransformationCommand):
def __init__(
self,
target_commands: List[PTInsertionCommand],
target_points: List[PTTargetPoint],
fn: Callable,
op_unique_name: str,
priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT, None)
self.target_commands = target_commands
self.target_points = target_points
self.fn = fn
self.op_name = op_unique_name
self.priority = priority

def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand":
# TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead
Expand Down
16 changes: 8 additions & 8 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTInsertionTemporaryCommand
Expand All @@ -35,8 +37,6 @@
from nncf.torch.nncf_network import ExtraCompressionModuleType
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_quantizer import ExternalOpCallHook
from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice
Expand Down Expand Up @@ -149,15 +149,15 @@ def _apply_shared_nodes_insertion(

insertion_commands: List[PTInsertionCommand] = []

for command in transformations:
for shared_command in transformations:
op_id = (
command.op_name + f"[{';'.join([tp.target_point.target_node_name for tp in command.target_commands])}]"
shared_command.op_name + f"[{';'.join([tp.target_node_name for tp in shared_command.target_points])}]"
)
model.nncf.add_compression_module(op_id, command.fn, compression_model_type)
model.nncf.add_compression_module(op_id, shared_command.fn, compression_model_type)

for command in command.target_commands:
command.fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id)
insertion_commands.append(command)
for target_point in shared_command.target_points:
fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id)
insertion_commands.append(PTInsertionCommand(target_point, fn, priority=shared_command.priority))

return PTModelTransformer._apply_insertion_transformations(model, insertion_commands)

Expand Down
34 changes: 2 additions & 32 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from copy import deepcopy
from enum import Enum
from enum import IntEnum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

import torch
from torch import nn
Expand Down Expand Up @@ -53,6 +53,7 @@
from nncf.torch.dynamic_graph.scope_access import get_module_by_scope
from nncf.torch.dynamic_graph.trace_tensor import strip_traced_tensors
from nncf.torch.dynamic_graph.wrappers import wrap_module_call
from nncf.torch.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.graph.graph import PTNNCFGraph
from nncf.torch.graph.graph_builder import GraphBuilder
from nncf.torch.graph.graph_builder import GraphConverter
Expand All @@ -62,7 +63,6 @@
from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
from nncf.torch.utils import compute_FLOPs_hook
from nncf.torch.utils import get_all_modules_by_type
Expand Down Expand Up @@ -345,25 +345,6 @@ def reset_nncf_modules(self):
module = self.get_module_by_scope(some_scope)
module.reset()

def get_shallow_copy(self) -> "NNCFNetwork":
from nncf.torch.utils import load_module_state
from nncf.torch.utils import save_module_state

saved_state = save_module_state(self._model_ref)
new_interface = NNCFNetworkInterface(
self._model_ref,
self._input_infos,
self._user_dummy_forward_fn,
self._wrap_inputs_fn,
self._scopes_without_shape_matching,
self._ignored_scopes,
self._target_scopes,
wrap_outputs_fn=self._wrap_outputs_fn,
)
self._model_ref._nncf = new_interface
load_module_state(self._model_ref, saved_state)
return self._model_ref

def get_clean_shallow_copy(self) -> "NNCFNetwork":
# WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
# and load_nncf_module_additions to preserve these, or temporary_clean_view().
Expand Down Expand Up @@ -395,9 +376,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc
retval[nncf_module_scope + relative_scope] = target_module
return retval

def update_model_ref(self, model: torch.nn.Module) -> None:
object.__setattr__(self, "__model_ref", model)

def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
hook_addresses = self.insert_at_point(point, fn_list)
self._temprorary_hooks_adresses.append(hook_addresses)
Expand Down Expand Up @@ -802,14 +780,6 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork":
return self.compression_controller.strip(do_copy)


class TemporaryOp:
def __init__(self, op: Callable) -> None:
self._op = op

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._op(*args, **kwargs)


class NNCFNetworkMeta(type):
"""
Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be
Expand Down
21 changes: 4 additions & 17 deletions nncf/torch/quantization/external_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.external_hook import ExternalOpCallHook
from nncf.torch.quantization.debug_interface import QuantizationDebugInterface

EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
EXTERNAL_OP_STORAGE_NAME = "external_op"
EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME


class ExternalOpCallHook:
def __init__(self, storage_name, context, storage_key):
self._storage_name = storage_name
self._compressed_context = context
self._storage_key = storage_key

def __call__(self, *args: Any, **kwargs) -> Any:
replica = self._compressed_context.base_module_thread_local_replica
storage = getattr(replica.nncf, self._storage_name)
return storage[self._storage_key](*args, **kwargs)


class ExternalQuantizerCallHook(ExternalOpCallHook):
"""
Cannot simply register the quantizer module as a callable hook, since we need to call
a thread-local version of the quantizer module during base module execution.
External hook which is using quantization storage name and
could be constructed with a debug interface.
"""

def __init__(
Expand All @@ -48,5 +35,5 @@ def __init__(

def __call__(self, *args, **kwargs):
if self.debug_interface is not None:
self.debug_interface.register_activation_quantize_call(str(self.quantizer_storage_key))
self.debug_interface.register_activation_quantize_call(str(self._storage_key))
return super().__call__(*args, **kwargs)
33 changes: 0 additions & 33 deletions nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict

import numpy as np
Expand All @@ -27,41 +26,9 @@
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class ModelView:
def __init__(self, model: NNCFNetwork):
self.model = model
self.nncf_module_additions = self.model.nncf.save_nncf_module_additions()

def __enter__(self):
# Model ref removed to prevent copying
self.model.nncf.update_model_ref(None)

# nncf_replaced_models removed to prevent copying
replaced_modules = self.model.nncf._nncf_replaced_modules
self.model.nncf._nncf_replaced_modules = None

self.nncf_interface = deepcopy(self.model.nncf)

# Model ref is recovering
self.model.nncf.update_model_ref(self.model)
self.nncf_interface.update_model_ref(self.model)

# nncf_replaced_models is recovering
self.model.nncf._nncf_replaced_modules = replaced_modules
self.nncf_interface._nncf_replaced_modules = replaced_modules
return self.model

def __exit__(self, exc_type, exc_val, exc_tb):
self.model._nncf = self.nncf_interface
self.model.nncf.reset_nncf_modules()
self.model.nncf.load_nncf_module_additions(self.nncf_module_additions)


class PTStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
# with ModelView(model) as intermediate_model:
# super().collect_statistics(intermediate_model, graph)
super().collect_statistics(model, graph)
model.nncf.remove_temporary_ops()

Expand Down
8 changes: 4 additions & 4 deletions tests/common/experimental/test_statistic_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class BadStatContainer:

class TemplateTestStatisticCollector:
@abstractmethod
def get_nncf_tensor_cls(self):
def get_nncf_tensor(self, value: np.ndarray) -> NNCFTensor:
pass

@abstractmethod
Expand Down Expand Up @@ -366,10 +366,10 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
collector.register_statistic_branch("A", reducer, aggregator)
input_name = "input_name"
full_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([100]))}, [(hash(reducer), [input_name])]
{input_name: self.get_nncf_tensor(np.array([100]))}, [(hash(reducer), [input_name])]
)
empty_inputs = TensorCollector.get_tensor_collector_inputs(
{input_name: self.get_nncf_tensor_cls()(np.array([]))}, [(hash(reducer), [input_name])]
{input_name: self.get_nncf_tensor(np.array([]))}, [(hash(reducer), [input_name])]
)

stats = collector.get_statistics()
Expand All @@ -385,7 +385,7 @@ def test_empty_tensors_register(self, inplace, any_not_empty):
assert aggregator._collected_samples == 2
stats = collector.get_statistics()
assert len(stats) == 1
assert stats["A"] == self.get_nncf_tensor_cls()([100])
assert stats["A"] == self.get_nncf_tensor([100])
return

assert len(aggregator._container) == 0
Expand Down
9 changes: 6 additions & 3 deletions tests/common/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ class BCStatsCollectors(Enum):


class TemplateTestStatisticsAggregator:
@classmethod
@abstractmethod
def get_min_max_algo_backend_cls(self) -> Type[MinMaxAlgoBackend]:
def get_min_max_algo_backend_cls(cls) -> Type[MinMaxAlgoBackend]:
pass

@abstractmethod
Expand All @@ -73,6 +74,7 @@ def get_statistics_aggregator(self, dataset):
def get_dataset(self, samples):
pass

@staticmethod
@abstractmethod
def get_target_point(self, target_type: TargetType) -> TargetPoint:
pass
Expand Down Expand Up @@ -631,10 +633,11 @@ def filter_func(point):
assert ref.shape == val.shape
assert np.allclose(val, ref)

@classmethod
def create_statistics_point(
self, model, q_config, target_point, subset_size, algorithm_name, inplace_statistics, range_estimator
cls, model, q_config, target_point, subset_size, algorithm_name, inplace_statistics, range_estimator
):
algo_backend = self.get_min_max_algo_backend_cls()
algo_backend = cls.get_min_max_algo_backend_cls()
nncf_graph = NNCFGraphFactory.create(model)
tensor_collector = algo_backend.get_statistic_collector(
range_estimator,
Expand Down
6 changes: 4 additions & 2 deletions tests/onnx/test_statistics_aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@


class TestStatisticsAggregator(TemplateTestStatisticsAggregator):
def get_min_max_algo_backend_cls(self) -> Type[ONNXMinMaxAlgoBackend]:
@classmethod
def get_min_max_algo_backend_cls(cls) -> Type[ONNXMinMaxAlgoBackend]:
return ONNXMinMaxAlgoBackend

def get_bias_correction_algo_backend_cls(self) -> Type[ONNXBiasCorrectionAlgoBackend]:
Expand Down Expand Up @@ -65,7 +66,8 @@ def transform_fn(data_item):

return Dataset(samples, transform_fn)

def get_target_point(self, target_type: TargetType):
@staticmethod
def get_target_point(target_type: TargetType):
target_node_name = IDENTITY_NODE_NAME
port_id = 0
if target_type == TargetType.OPERATION_WITH_WEIGHTS:
Expand Down
Loading

0 comments on commit 129b412

Please sign in to comment.