Skip to content

Commit

Permalink
Test finishing
Browse files Browse the repository at this point in the history
Cleanup

Comments
  • Loading branch information
daniil-lyakhov committed Jan 2, 2024
1 parent c9b27f8 commit 08661a5
Show file tree
Hide file tree
Showing 15 changed files with 314 additions and 145 deletions.
1 change: 0 additions & 1 deletion nncf/common/graph/transformations/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class TransformationPriority(IntEnum):
FP32_TENSOR_STATISTICS_OBSERVATION = 1
PRUNING_PRIORITY = 2
SPARSIFICATION_PRIORITY = 3
OP_INSERTION_PRIORITY = 4
QUANTIZATION_PRIORITY = 11


Expand Down
5 changes: 1 addition & 4 deletions nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,7 @@ def stack(x: List[Tensor], axis: int = 0) -> Tensor:
:return: Stacked Tensor.
"""
if isinstance(x, List):
unwrapped_x = [i.data for i in x]
# singledispatch cannot dispatch function by element in a list
res = stack.dispatch(type(unwrapped_x[0]))(unwrapped_x, axis=axis)
return Tensor(res)
return Tensor(_dispatch_list(stack, x, axis=axis))
raise NotImplementedError(f"Function `stack` is not implemented for {type(x)}")


Expand Down
12 changes: 10 additions & 2 deletions nncf/quantization/algorithms/smooth_quant/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,14 @@ def apply(
if any(val.data is None for val in activations_value):
empty_statistic = True
break
assert len(activations_value) == 1
if len(activations_value) != 1:
raise RuntimeError(
(
"More than one statistic is collected for one node during"
f"Smooth Quanti algorithm: {node_to_smooth.node_name}"
)
)

activations_value = self._clip_statistics(activations_value)

weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
Expand Down Expand Up @@ -254,6 +261,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin
for node_data in nodes_to_smooth_data:
node_to_smooth = node_data["node_to_smooth"]
target_point = self._backend_entity.target_point(
target_type=self._backend_entity.pre_layer_target_type(),
target_node_name=node_to_smooth.node_name,
port_id=node_data["input_act_port"],
)
Expand Down Expand Up @@ -306,7 +314,7 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[
nodes_to_smooth_data.append(
{
"node_to_smooth": node_with_weight,
"input_act_port": self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph),
"input_act_port": activation_port_id,
}
)
return nodes_to_smooth_data
Expand Down
34 changes: 28 additions & 6 deletions nncf/quantization/algorithms/smooth_quant/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from abc import ABC
from abc import abstractmethod
from typing import List, Tuple, TypeVar
from typing import Callable, List, Tuple, TypeVar

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.graph.transformations.commands import TargetPoint
from nncf.common.graph.transformations.commands import TargetType
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.tensor_statistics.statistic_point import StatisticPoint
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.tensor import Tensor

Expand Down Expand Up @@ -55,10 +57,20 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:

@staticmethod
@abstractmethod
def target_point(target_node_name: str, port_id: int) -> TargetPoint:
def pre_layer_target_type() -> TargetType:
"""
Returns backend-specific pre layer target type.
:returns: Backend-specific pre layer target type.
"""

@staticmethod
@abstractmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint:
"""
Returns backend-specific target point.
:param target_type: Type of the location that should be modified.
:param target_node_name: Name of the located node.
:param port_id: Port ID of the tensor for the statistics distribution.
:return: Backend-specific TargetPoint.
Expand Down Expand Up @@ -184,10 +196,20 @@ def get_weight_channel_axis(node: NNCFNode) -> int:

@staticmethod
@abstractmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
pass
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
"""
Returns true if given node shares constant with a different node.
:param node: NNCFNode instance.
:param nncf_graph: NNCFGraph instance.
:return: Whether the given node is shares weights with a different node or not.
"""

@staticmethod
@abstractmethod
def get_filter_fn_for_statistics(activation_port_id: int):
pass
def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
"""
Returns backend-specific callable to filter statistic containers according to its statistic point.
:param activation_port_id: Activation port id for the statistic collection target node.
"""
22 changes: 11 additions & 11 deletions nncf/quantization/algorithms/smooth_quant/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import Callable, List, Tuple

import numpy as np
import openvino.runtime as ov
Expand Down Expand Up @@ -52,8 +52,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
return QUANTIZE_AGNOSTIC_OPERATIONS

@staticmethod
def target_point(target_node_name: str, port_id: int) -> OVTargetPoint:
return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id)
def pre_layer_target_type() -> TargetType:
return TargetType.PRE_LAYER_OPERATION

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint:
return OVTargetPoint(target_type, target_node_name, port_id)

@staticmethod
def is_node_with_weights(node: NNCFNode) -> bool:
Expand Down Expand Up @@ -92,23 +96,19 @@ def get_abs_max_channel_collector(

@staticmethod
def get_weight_value(node_with_weight: NNCFNode, model: ov.Model) -> Tensor:
port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight)
port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight)
return Tensor(get_weight_value(node_with_weight, model, port_id))

@staticmethod
def get_weight_tensor_port_id(node: NNCFNode) -> int:
return OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node)

@staticmethod
def _get_weight_tensor_port_id(node: NNCFNode) -> int:
const_ids = node.layer_attributes.get_const_port_ids()
if len(const_ids) != 1:
raise RuntimeError(f"Found more than 1 port for {node.node_name} node")
return const_ids[0]

@staticmethod
def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand:
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight)
weight_port_id = OVSmoothQuantAlgoBackend.get_weight_tensor_port_id(node_with_weight)
return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id)

@staticmethod
Expand Down Expand Up @@ -154,13 +154,13 @@ def _calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
return -2 + port_id if transpose else -1 - port_id

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node
return len(nncf_graph.get_next_nodes(weight_node)) > 1

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int):
def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
def filter_func(point: StatisticPoint) -> bool:
return point.target_point.port_id == activation_port_id

Expand Down
16 changes: 10 additions & 6 deletions nncf/quantization/algorithms/smooth_quant/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple
from typing import Callable, List, Tuple

import numpy as np

Expand Down Expand Up @@ -63,8 +63,12 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]:
return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT[QuantizationTrait.QUANTIZATION_AGNOSTIC]

@staticmethod
def target_point(target_node_name: str, port_id: int) -> PTTargetPoint:
return PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name, input_port_id=port_id)
def pre_layer_target_type() -> TargetType:
return TargetType.OPERATOR_PRE_HOOK

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
return PTTargetPoint(target_type, target_node_name, input_port_id=port_id)

@staticmethod
def is_node_with_weights(node: NNCFNode) -> bool:
Expand Down Expand Up @@ -92,7 +96,7 @@ def get_abs_max_channel_collector(
def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> Tensor:
node_module = model.nncf.get_containing_module(node_with_weight.node_name)
if node_module.weight is None:
return None
raise RuntimeError(f"{node_module} module has no .weight attribute.")
return Tensor(node_module.weight.data)

@staticmethod
Expand Down Expand Up @@ -130,11 +134,11 @@ def get_weight_channel_axis(node: NNCFNode) -> int:
return 1

@staticmethod
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
return node.is_shared()

@staticmethod
def get_filter_fn_for_statistics(activation_port_id: int):
def get_filter_fn_for_statistics(activation_port_id: int) -> Callable[[StatisticPoint], bool]:
def filter_func(point: StatisticPoint) -> bool:
return True

Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/dynamic_graph/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self):

self._post_hooks = defaultdict(OrderedDict)
self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(OrderedDict)
self._hooks_counter = 0
self._hooks_counter = -1

self._threading = CopySafeThreadingVars()

Expand Down
32 changes: 1 addition & 31 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from enum import Enum
from enum import IntEnum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar

import torch
from torch import nn
Expand Down Expand Up @@ -364,25 +364,6 @@ def reset_nncf_modules(self):
module = self.get_module_by_scope(some_scope)
module.reset()

def get_shallow_copy(self) -> "NNCFNetwork":
from nncf.torch.utils import load_module_state
from nncf.torch.utils import save_module_state

saved_state = save_module_state(self._model_ref)
new_interface = NNCFNetworkInterface(
self._model_ref,
self._input_infos,
self._user_dummy_forward_fn,
self._wrap_inputs_fn,
self._scopes_without_shape_matching,
self._ignored_scopes,
self._target_scopes,
wrap_outputs_fn=self._wrap_outputs_fn,
)
self._model_ref._nncf = new_interface
load_module_state(self._model_ref, saved_state)
return self._model_ref

def get_clean_shallow_copy(self) -> "NNCFNetwork":
# WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions
# and load_nncf_module_additions to preserve these, or temporary_clean_view().
Expand Down Expand Up @@ -414,9 +395,6 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc
retval[nncf_module_scope + relative_scope] = target_module
return retval

def update_model_ref(self, model: torch.nn.Module) -> None:
object.__setattr__(self, "__model_ref", model)

def temporary_insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]):
hook_addresses = self.insert_at_point(point, fn_list)
self._temprorary_hooks_adresses.append(hook_addresses)
Expand Down Expand Up @@ -832,14 +810,6 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork":
return self.compression_controller.strip(do_copy)


class TemporaryOp:
def __init__(self, op: Callable) -> None:
self._op = op

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return self._op(*args, **kwargs)


class NNCFNetworkMeta(type):
"""
Metaclass for the NNCFNetwork mixin. Has magic methods defined so that the original model object could be
Expand Down
1 change: 0 additions & 1 deletion nncf/torch/quantization/external_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from nncf.torch.quantization.debug_interface import QuantizationDebugInterface

EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
EXTERNAL_OP_STORAGE_NAME = "external_op"
EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME


Expand Down
33 changes: 0 additions & 33 deletions nncf/torch/statistics/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from typing import Dict

import numpy as np
Expand All @@ -27,41 +26,9 @@
from nncf.torch.tensor_statistics.algo import create_register_input_hook


class ModelView:
def __init__(self, model: NNCFNetwork):
self.model = model
self.nncf_module_additions = self.model.nncf.save_nncf_module_additions()

def __enter__(self):
# Model ref removed to prevent copying
self.model.nncf.update_model_ref(None)

# nncf_replaced_models removed to prevent copying
replaced_modules = self.model.nncf._nncf_replaced_modules
self.model.nncf._nncf_replaced_modules = None

self.nncf_interface = deepcopy(self.model.nncf)

# Model ref is recovering
self.model.nncf.update_model_ref(self.model)
self.nncf_interface.update_model_ref(self.model)

# nncf_replaced_models is recovering
self.model.nncf._nncf_replaced_modules = replaced_modules
self.nncf_interface._nncf_replaced_modules = replaced_modules
return self.model

def __exit__(self, exc_type, exc_val, exc_tb):
self.model._nncf = self.nncf_interface
self.model.nncf.reset_nncf_modules()
self.model.nncf.load_nncf_module_additions(self.nncf_module_additions)


class PTStatisticsAggregator(StatisticsAggregator):
def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None:
with torch.no_grad():
# with ModelView(model) as intermediate_model:
# super().collect_statistics(intermediate_model, graph)
super().collect_statistics(model, graph)
model.nncf.remove_temporary_ops()

Expand Down
7 changes: 1 addition & 6 deletions nncf/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch

from nncf.common.tensor import NNCFTensor
from nncf.torch.return_types import maybe_unwrap_from_torch_return_type


class PTNNCFTensor(NNCFTensor):
"""
A realisation of torch tensors wrapper for common NNCF algorithms.
"""

def __init__(self, tensor: Union[torch.tensor, "PTNNCFTensor", tuple]):
def __init__(self, tensor: torch.tensor):
# In case somebody attempts to wrap
# tensor twice
if isinstance(tensor, self.__class__):
tensor = tensor.tensor
else:
tensor = maybe_unwrap_from_torch_return_type(tensor)

super().__init__(tensor)

Expand Down
Loading

0 comments on commit 08661a5

Please sign in to comment.