Skip to content

Commit

Permalink
WIP test finishing
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Dec 5, 2023
1 parent b1e45ac commit 9241766
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 66 deletions.
10 changes: 5 additions & 5 deletions nncf/torch/graph/transformations/command_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from nncf.common.graph.graph import NNCFNode
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
Expand Down Expand Up @@ -60,9 +59,10 @@ def forward(self, x):
def multiply_insertion_command(
target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
) -> PTInsertionCommand:
commands = []
target_points = []
for target_node in target_nodes:
target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))
target_points.append(
PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
)

return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
return PTSharedFnInsertionCommand(target_points, SQMultiply(scale_value), scale_node_name)
6 changes: 4 additions & 2 deletions nncf/torch/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,16 @@ def requires_graph_rebuild(self):
class PTSharedFnInsertionCommand(PTTransformationCommand):
def __init__(
self,
target_commands: List[PTInsertionCommand],
target_points: List[PTTargetPoint],
fn: Callable,
op_unique_name: str,
priority: TransformationPriority = TransformationPriority.DEFAULT_PRIORITY,
):
super().__init__(TransformationType.INSERT, None)
self.target_commands = target_commands
self.target_points = target_points
self.fn = fn
self.op_name = op_unique_name
self.priority = priority

def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand":
# TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead
Expand Down
16 changes: 8 additions & 8 deletions nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
from nncf.torch.nncf_network import ExtraCompressionModuleType
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_quantizer import ExternalOpCallHook
from nncf.torch.quantization.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_hook import ExternalOpCallHook
from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook
from nncf.torch.utils import get_model_device
from nncf.torch.utils import is_multidevice
Expand Down Expand Up @@ -149,15 +149,15 @@ def _apply_shared_nodes_insertion(

insertion_commands: List[PTInsertionCommand] = []

for command in transformations:
for shared_command in transformations:
op_id = (
command.op_name + f"[{';'.join([tp.target_point.target_node_name for tp in command.target_commands])}]"
shared_command.op_name + f"[{';'.join([tp.target_node_name for tp in shared_command.target_points])}]"
)
model.nncf.add_compression_module(op_id, command.fn, compression_model_type)
model.nncf.add_compression_module(op_id, shared_command.fn, compression_model_type)

for command in command.target_commands:
command.fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id)
insertion_commands.append(command)
for target_point in shared_command.target_points:
fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id)
insertion_commands.append(PTInsertionCommand(target_point, fn, priority=shared_command.priority))

return PTModelTransformer._apply_insertion_transformations(model, insertion_commands)

Expand Down
31 changes: 2 additions & 29 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from copy import deepcopy
from enum import Enum
from enum import IntEnum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

import torch
from torch import nn
Expand Down Expand Up @@ -62,7 +62,7 @@
from nncf.torch.knowledge_distillation.knowledge_distillation_handler import KnowledgeDistillationLossHandler
from nncf.torch.layer_utils import _NNCFModuleMixin
from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules
from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME
from nncf.torch.utils import compute_FLOPs_hook
from nncf.torch.utils import get_all_modules_by_type
Expand Down Expand Up @@ -345,25 +345,6 @@ def reset_nncf_modules(self):
module = self.get_module_by_scope(some_scope)
module.reset()

def get_shallow_copy(self) -> "NNCFNetwork":
from nncf.torch.utils import load_module_state
from nncf.torch.utils import save_module_state

saved_state = save_module_state(self._model_ref)
new_interface = NNCFNetworkInterface(
self._model_ref,
self._input_infos,
self._user_dummy_forward_fn,
self._wrap_inputs_fn,
self._scopes_without_shape_matching,
self._ignored_scopes,
self._target_scopes,
wrap_outputs_fn=self._wrap_outputs_fn,
)
self._model_ref._nncf = new_interface
load_module_state(self._model_ref, saved_state)
return self._model_ref

def get_clean_shallow_copy(self) -> "NNCFNetwork":
# WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
# and load_nncf_module_additions to preserve these, or temporary_clean_view().
Expand Down Expand Up @@ -802,14 +783,6 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork":
return self.compression_controller.strip(do_copy)


class TemporaryOp:
def __init__(self, op: Callable) -> None:
self._op = op

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._op(*args, **kwargs)


class NNCFNetworkMeta(type):
"""
Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be
Expand Down
19 changes: 3 additions & 16 deletions nncf/torch/quantization/external_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,32 +9,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

from nncf.torch.dynamic_graph.context import TracingContext
from nncf.torch.quantization.debug_interface import QuantizationDebugInterface
from nncf.torch.quantization.external_hook import ExternalOpCallHook

EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
EXTERNAL_OP_STORAGE_NAME = "external_op"
EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME


class ExternalOpCallHook:
def __init__(self, storage_name, context, storage_key):
self._storage_name = storage_name
self._compressed_context = context
self._storage_key = storage_key

def __call__(self, *args: Any, **kwargs) -> Any:
replica = self._compressed_context.base_module_thread_local_replica
storage = getattr(replica.nncf, self._storage_name)
return storage[self._storage_key](*args, **kwargs)


class ExternalQuantizerCallHook(ExternalOpCallHook):
"""
Cannot simply register the quantizer module as a callable hook, since we need to call
a thread-local version of the quantizer module during base module execution.
External hook which is using quantization storage name and
could be constructed with a debug interface.
"""

def __init__(
Expand Down
6 changes: 6 additions & 0 deletions tests/torch/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,12 @@ def check_with_reference(self):
hooks = self._target_model.nncf._compressed_context._post_hooks
self._check_pre_post_hooks(hooks, self._ref_hooks[TargetType.OPERATOR_POST_HOOK])

def clear(self):
"""
Removes all recorded references.
"""
self._ref_hooks.clear()

@staticmethod
def _check_weight_update_hooks(ref_hooks):
for target_module, ref_hooks_per_module in ref_hooks.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/ptq/test_smooth_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_node_name_map(self) -> Dict[str, str]:
@staticmethod
def get_target_node_name(command: TransformationCommand):
if isinstance(command, PTSharedFnInsertionCommand):
return command.target_commands[0].target_point.target_node_name
return command.target_points[0].target_node_name
return command.target_point.target_node_name

@staticmethod
Expand Down
82 changes: 77 additions & 5 deletions tests/torch/test_model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand
from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
from nncf.torch.graph.transformations.commands import PTTargetPoint
from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout
from nncf.torch.layers import NNCFConv2d
from nncf.torch.layers import register_module
Expand All @@ -61,6 +63,8 @@
from nncf.torch.nncf_network import NNCFNetwork
from nncf.torch.nncf_network import PTInsertionPoint
from nncf.torch.nncf_network import PTInsertionType
from nncf.torch.quantization.external_hook import EXTERNAL_OP_STORAGE_NAME
from nncf.torch.quantization.external_hook import ExternalOpCallHook
from nncf.torch.quantization.layers import AsymmetricQuantizer
from nncf.torch.quantization.layers import PTQuantizerSpec
from tests.common.quantization.mock_graphs import get_ip_graph_for_test
Expand Down Expand Up @@ -475,19 +479,22 @@ def test_extraction_with_fused_bias_transformations():
assert isinstance(extracted_model[0], NNCFConv2d)


def test_bias_correction_transformations():
@pytest.mark.parametrize(
"command_cls,attr_name,new_value",
[(PTBiasCorrectionCommand, "bias", torch.tensor([42.0])), (PTWeightUpdateCommand, "weight", torch.tensor([42.0]))],
)
def test_correction_transformations(command_cls, attr_name, new_value):
model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))
model_transformer = PTModelTransformer(model)

new_bias = torch.Tensor([42])

target_point = PTTargetPoint(TargetType.LAYER, "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0")
command = PTBiasCorrectionCommand(target_point, new_bias)
command = command_cls(target_point, new_value)

transformation_layout = PTTransformationLayout()
transformation_layout.register(command)
updated_model = model_transformer.transform(transformation_layout)
assert updated_model.conv1.bias.data == new_bias
param = getattr(updated_model.conv1, attr_name)
assert param.data == new_value


def test_rebuild_graph_after_insert_transformation():
Expand Down Expand Up @@ -549,6 +556,71 @@ def test_quantizer_insertion_transformations(target_type, node_name, input_port_
assert isinstance(op, BaseOp)


@pytest.mark.parametrize(
"priority", [TransformationPriority.OP_INSERTION_PRIORITY, TransformationPriority.DEFAULT_PRIORITY]
)
@pytest.mark.parametrize("compression_module_registered", [False, True])
def test_shared_fn_insertion_point(priority, compression_module_registered, mocker):
model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])]))

class Hook(torch.nn.Module):
def forward(self, x):
return x

tps = [
PTTargetPoint(
TargetType.OPERATOR_POST_HOOK,
"/nncf_model_input_0",
),
PTTargetPoint(
TargetType.OPERATOR_PRE_HOOK,
"InsertionPointTestModel/linear_0",
input_port_id=0,
),
PTTargetPoint(
TargetType.OPERATION_WITH_WEIGHTS,
"InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0",
),
]
OP_UNIQUE_NAME = "UNIQUE_NAME"
if compression_module_registered:
model.nncf.register_compression_module_type(ExtraCompressionModuleType.EXTERNAL_OP)
hook_instance = Hook()
command = PTSharedFnInsertionCommand(tps, hook_instance, OP_UNIQUE_NAME, priority)
transformation_layout = PTTransformationLayout()
transformation_layout.register(command)

mocker.MagicMock()
mocker.patch(
"nncf.torch.model_transformer.PTModelTransformer._apply_insertion_transformations",
return_value=mocker.MagicMock(),
)
model_transformer = PTModelTransformer(model)
_ = model_transformer.transform(transformation_layout=transformation_layout)

assert model.nncf.is_compression_module_registered(ExtraCompressionModuleType.EXTERNAL_OP)

REF_STORAGE_KEY = (
"UNIQUE_NAME[/nncf_model_input_0;InsertionPointTestModel/linear_0;"
"InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0]"
)

storage = getattr(model.nncf, EXTERNAL_OP_STORAGE_NAME)
assert storage[REF_STORAGE_KEY] is hook_instance

mock = PTModelTransformer._apply_insertion_transformations
mock.assert_called_once()

_, commands = mock.call_args.args
assert len(commands) == len(tps)
for command in commands:
assert command.target_point in tps
fn = command.fn
assert isinstance(fn, ExternalOpCallHook)
assert fn._storage_name == EXTERNAL_OP_STORAGE_NAME
assert fn._storage_key == REF_STORAGE_KEY


@pytest.mark.parametrize(
"target_type, node_name, input_port_id",
(
Expand Down
53 changes: 53 additions & 0 deletions tests/torch/test_nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from nncf.torch.nncf_network import PTInsertionType
from tests.torch.composite.test_sparsity_quantization import get_basic_sparsity_plus_quantization_config
from tests.torch.helpers import BasicConvTestModel
from tests.torch.helpers import HookChecker
from tests.torch.helpers import TwoConvTestModel
from tests.torch.helpers import check_correct_nncf_modules_replacement
from tests.torch.helpers import create_compressed_model_and_algo_for_test
Expand Down Expand Up @@ -847,3 +848,55 @@ def test_access_to_input_info():
input_info = ExampleInputInfo.from_example_input(example_input)
nncf_model = NNCFNetwork(model, input_info)
nncf_model.nncf.input_infos


@pytest.mark.parametrize(
"target_type, target_node_name, input_port_id",
[
(TargetType.OPERATOR_PRE_HOOK, "/nncf_model_output_0", 0),
(TargetType.OPERATOR_POST_HOOK, "/nncf_model_input_0", 0),
(TargetType.PRE_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0),
(TargetType.POST_LAYER_OPERATION, "SimplestModel/NNCFConv2d[conv]/conv2d_0", 0),
],
)
def test_temporary_insert_at_point(target_type, target_node_name, input_port_id):
class Hook(torch.nn.Module):
def forward(self, x):
return x

model = SimplestModel()
example_input = torch.ones(SimplestModel.INPUT_SIZE)
input_info = ExampleInputInfo.from_example_input(example_input)
nncf_model = NNCFNetwork(model, input_info)

node_name_vs_address = nncf_model.nncf.get_node_to_op_address_mapping()
ip = PTInsertionPoint(target_type, node_name_vs_address[target_node_name], input_port_id=input_port_id)

checker = HookChecker(nncf_model, "conv")

def _check(ref_hooks_):
checker.clear()
checker.add_ref(ref_hooks_, target_type, target_node_name, input_port_id)
checker.check_with_reference()

permanent_hook = Hook()
# Make temporary hook a ref to the permanent hook
# to check tmp hooks are not removed by their id()
temporary_hook = permanent_hook
nncf_model.nncf.insert_at_point(ip, [permanent_hook])
ref_hooks = [permanent_hook]
_check(ref_hooks)

for _ in range(2):
temporary_hook = Hook()
nncf_model.nncf.temporary_insert_at_point(ip, [temporary_hook])
ref_hooks.append(temporary_hook)
_check(ref_hooks)

nncf_model.nncf.insert_at_point(ip, [permanent_hook])
ref_hooks.append(permanent_hook)
_check(ref_hooks)

nncf_model.nncf.remove_temporary_ops()
del ref_hooks[-2]
_check(ref_hooks)

0 comments on commit 9241766

Please sign in to comment.