diff --git a/nncf/common/graph/transformations/commands.py b/nncf/common/graph/transformations/commands.py index fa26e587ec7..11734b1f152 100644 --- a/nncf/common/graph/transformations/commands.py +++ b/nncf/common/graph/transformations/commands.py @@ -35,6 +35,7 @@ class TransformationPriority(IntEnum): FP32_TENSOR_STATISTICS_OBSERVATION = 1 PRUNING_PRIORITY = 2 SPARSIFICATION_PRIORITY = 3 + OP_INSERTION_PRIORITY = 4 QUANTIZATION_PRIORITY = 11 diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index 9a065279f93..ed2e0564db8 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -29,7 +29,8 @@ from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype -from nncf.common.graph.transformations.commands import TargetType + +# from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track @@ -77,7 +78,7 @@ def __init__( @property def available_backends(self) -> List[BackendType]: - return [BackendType.OPENVINO] + return [BackendType.OPENVINO, BackendType.TORCH] def _set_backend_entity(self, model: TModel) -> None: """ @@ -90,6 +91,10 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend self._backend_entity = OVSmoothQuantAlgoBackend() + elif model_backend == BackendType.TORCH: + from nncf.quantization.algorithms.smooth_quant.torch_backend import PTSmoothQuantAlgoBackend + + self._backend_entity = PTSmoothQuantAlgoBackend() else: raise RuntimeError( "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) @@ -123,11 
+128,11 @@ def apply( if any(val is None for val in activations_value): empty_statistic = True break - activations_value = self._backend_entity.clip_statistics(activations_value) + assert len(activations_value) == 1 + activations_value = self._backend_entity.clip_statistics(activations_value[0]) - weight_port = self._backend_entity.get_weight_tensor_port_id(node_to_smooth) - weight_value = self._backend_entity.get_weight_value(node_to_smooth, model, weight_port) - weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value, weight_port) + weight_value = self._backend_entity.get_weight_value(node_to_smooth, model) + weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value, graph) weight_statistics = self._backend_entity.clip_statistics(weight_statistics) alpha = alpha_map[node_to_smooth.metatype] @@ -153,13 +158,12 @@ def apply( continue for node_to_smooth in nodes: - weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth) - weight_port = self._backend_entity.get_weight_tensor_port_id(node_to_smooth) - weight_value = self._backend_entity.get_weight_value(node_to_smooth, model, weight_port) + weight_value = self._backend_entity.get_weight_value(node_to_smooth, model) + weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value, graph) + ### TODO: DO it as NNCFTensor op scaled_weight = weight_value * weights_scale - weight_update_command = self._backend_entity.weight_update_command( - node_to_smooth, scaled_weight, weight_port - ) + ### + weight_update_command = self._backend_entity.weight_update_command(node_to_smooth, scaled_weight) transformation_layout.register(weight_update_command) activations_shape = graph.get_output_edges(source_node)[source_output_port_id].tensor_shape @@ -208,16 +212,11 @@ def _get_statistics_for_node( :return: List of the TTensor instances. 
""" - def filter_func(point: StatisticPoint) -> bool: - return ( - self._algorithm_key in point.algorithm_to_tensor_collectors - and point.target_point.type == TargetType.PRE_LAYER_OPERATION - and point.target_point.port_id == act_port - ) - statistics_for_node = [] for tensor_collector in statistic_points.get_algo_statistics_for_node( - node_name, filter_func, self._algorithm_key + node_name, + self._backend_entity.get_filter_fn_for_statistics(act_port), + self._algorithm_key, ): statistics_for_node.append(tensor_collector.get_statistics()[STATISTIC_BRANCH_KEY]) return statistics_for_node @@ -233,7 +232,6 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin for node_data in nodes_to_smooth_data: node_to_smooth = node_data["node_to_smooth"] target_point = self._backend_entity.target_point( - TargetType.PRE_LAYER_OPERATION, target_node_name=node_to_smooth.node_name, port_id=node_data["input_act_port"], ) @@ -267,27 +265,26 @@ def _get_nodes_to_smooth_data(self, nncf_graph: NNCFGraph, node_metatypes: List[ if not self._backend_entity.is_node_with_weights(node_with_weight): continue - ports_map = self._backend_entity.get_input_ports_map(node_with_weight, nncf_graph) + activation_port_id = self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph) input_edges = nncf_graph.get_input_edges(node_with_weight) - weight_node = input_edges[ports_map["weight"]].from_node - activation_node = input_edges[ports_map["activation"]].from_node + activation_node = input_edges[activation_port_id].from_node # Skipping agnostic layers as inputs to propagate quantizer # Only for Convolution layers if ( - node_with_weight.metatype == self._backend_entity.convolution_metatype + node_with_weight.metatype in self._backend_entity.convolution_metatypes and activation_node.metatype in self._backend_entity.quantize_agnostic_metatypes ): continue # Skipping shared weights - if len(nncf_graph.get_next_nodes(weight_node)) > 1: + if 
self._backend_entity.is_node_with_shared_weight(node_with_weight, nncf_graph): continue nodes_to_smooth_data.append( { "node_to_smooth": node_with_weight, - "input_act_port": ports_map["activation"], + "input_act_port": self._backend_entity.get_activations_port_id(node_with_weight, nncf_graph), } ) return nodes_to_smooth_data @@ -303,11 +300,10 @@ def _calculate_activation_scale( :param nodes: List of consumers for Smooth node. :return: Calculated per-channel activation scale. """ - activation_ports_map = { - node: self._backend_entity.get_input_ports_map(node, nncf_graph)["activation"] for node in nodes - } + activation_ports_map = {node: self._backend_entity.get_activations_port_id(node, nncf_graph) for node in nodes} channel_axes = [ - self._backend_entity.get_activation_channel_axis(node, port) for node, port in activation_ports_map.items() + self._backend_entity.get_activation_channel_axis(node, port, activations_shape) + for node, port in activation_ports_map.items() ] channel_axis = channel_axes[0] @@ -317,7 +313,9 @@ def _calculate_activation_scale( activations_size = len(activations_shape) return self._backend_entity.calculate_activation_scale(scale_value, activations_size, channel_axis) - def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode) -> TTensor: + def _calculate_weight_scale( + self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor, graph: NNCFGraph + ) -> TTensor: """ Calculates scale for weight tensor. @@ -325,10 +323,9 @@ def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode) -> TTens :param node: Consumer for Smooth node. :return: Calculated scale for weights. 
""" - port_id = self._backend_entity.get_weight_tensor_port_id(node) - weights_size = len(node.layer_attributes.constant_attributes[port_id]["shape"]) + weights_size = len(weights_value.shape) if weights_size > 1: - channel_axis = self._backend_entity.get_weight_channel_axis(node, port_id) + channel_axis = self._backend_entity.get_weight_channel_axis(node, graph) return self._backend_entity.calculate_weight_scale(scale_value, weights_size, channel_axis) return scale_value @@ -344,11 +341,11 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode, shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape reduction_axes = tuple([]) if len(shape) > 1: - channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port) + channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port, shape) reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape) return reduction_axes - def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, port_id: int) -> TTensor: + def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, graph: NNCFGraph) -> TTensor: """ Returns processed weight statistics for node. 
@@ -359,7 +356,7 @@ def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, port_id: """ channel_axis = 0 if len(weights.shape) > 1: - channel_axis = self._backend_entity.get_weight_channel_axis(node, port_id) + channel_axis = self._backend_entity.get_weight_channel_axis(node, graph) reduction_shape = [i for i, _ in enumerate(weights.shape)] reduction_shape.pop(channel_axis) return self._backend_entity.process_weight_statistics(weights, tuple(reduction_shape)) @@ -385,8 +382,8 @@ def _get_alpha_map(self) -> Dict[OperatorMetatype, float]: """ alpha_by_metatype_map = {} name_to_metatype = { - "convolution": self._backend_entity.convolution_metatype, - "matmul": self._backend_entity.matmul_metatype, + "convolution": self._backend_entity.convolution_metatypes, + "matmul": self._backend_entity.matmul_metatypes, } for type_name, alpha_value in self._alpha_map.items(): if alpha_value < 0: @@ -395,6 +392,7 @@ def _get_alpha_map(self) -> Dict[OperatorMetatype, float]: "Skipping these layers." 
) continue - metatype = name_to_metatype[type_name] - alpha_by_metatype_map[metatype] = alpha_value + metatypes = name_to_metatype[type_name] + for metatype in metatypes: + alpha_by_metatype_map[metatype] = alpha_value return alpha_by_metatype_map diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py index 9fab9178851..d015eb73a45 100644 --- a/nncf/quantization/algorithms/smooth_quant/backend.py +++ b/nncf/quantization/algorithms/smooth_quant/backend.py @@ -11,13 +11,12 @@ from abc import ABC from abc import abstractmethod -from typing import Dict, List, Optional, Tuple, TypeVar +from typing import List, Optional, Tuple, TypeVar from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetPoint -from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationCommand from nncf.experimental.common.tensor_statistics.collectors import TensorCollector @@ -28,20 +27,20 @@ class SmoothQuantAlgoBackend(ABC): @property @abstractmethod - def convolution_metatype(self) -> OperatorMetatype: + def convolution_metatypes(self) -> List[OperatorMetatype]: """ - Parameter for backend-specific metatype for Convolution. + Parameter for backend-specific metatypes for Convolution. - :return: OperatorMetatype + :return: OperatorMetatype list. """ @property @abstractmethod - def matmul_metatype(self) -> OperatorMetatype: + def matmul_metatypes(self) -> List[OperatorMetatype]: """ - Parameter for backend-specific metatype for MatMul. + Parameter for backend-specific metatypes for MatMul. - :return: OperatorMetatype + :return: OperatorMetatype list. 
""" @property @@ -55,11 +54,10 @@ def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: @staticmethod @abstractmethod - def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> TargetPoint: + def target_point(TargetType, target_node_name: str, port_id: int) -> TargetPoint: """ Returns backend-specific target point. - :param target_type: Type of the location that should be modified. :param target_node_name: Name of the located node. :param port_id: Port ID of the tensor for the statistics distribution. :return: Backend-specific TargetPoint. @@ -77,7 +75,7 @@ def is_node_with_weights(node: NNCFNode) -> bool: @staticmethod @abstractmethod - def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]: + def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: """ Returns map with activation & weighted ports. @@ -210,21 +208,21 @@ def weight_update_command( @staticmethod @abstractmethod def scale_insertion_command( - source_node: NNCFNode, scale_value: TTensor, port_id: int, nodes: List[NNCFNode] + source_node: NNCFNode, scale_value: TTensor, source_output_port_id: int, nodes: List[NNCFNode] ) -> TransformationCommand: """ Returns command to insert Smooth Quant node. :param source_node: NNCFNode instance. :param scale_value: Smooth Quant value. - :param port_id: Output port for source node. + :param source_output_port_id: Output port for source node. :param nodes: List of consumers for Smooth node. :return: TransformationCommand instance. """ @staticmethod @abstractmethod - def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: + def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int: """ Returns axis number of the activation tensor which correspond to it channel. 
@@ -235,7 +233,7 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: @staticmethod @abstractmethod - def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int: + def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int: """ Returns axis number of the weight tensor which correspond to it channel. @@ -254,3 +252,13 @@ def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: :param transpose: Transpose position. :return: Channel axis. """ + + @staticmethod + @abstractmethod + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + pass + + @staticmethod + @abstractmethod + def get_filter_fn_for_statistics(activation_port_id: int): + pass diff --git a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py index 0d7f9501df5..ca797899d6b 100644 --- a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py +++ b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py @@ -18,6 +18,7 @@ from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.transformations.commands import TargetType +from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.openvino.graph.metatypes.groups import QUANTIZE_AGNOSTIC_OPERATIONS @@ -36,27 +37,27 @@ class OVSmoothQuantAlgoBackend(SmoothQuantAlgoBackend): @property - def convolution_metatype(self) -> OperatorMetatype: - return OVConvolutionMetatype + def convolution_metatypes(self) -> List[OperatorMetatype]: + return [OVConvolutionMetatype] @property - def matmul_metatype(self) -> OperatorMetatype: - return OVMatMulMetatype + def matmul_metatypes(self) -> List[OperatorMetatype]: + return [OVMatMulMetatype] @property def 
quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: return QUANTIZE_AGNOSTIC_OPERATIONS @staticmethod - def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: - return OVTargetPoint(target_type, target_node_name, port_id) + def target_point(target_node_name: str, port_id: int) -> OVTargetPoint: + return OVTargetPoint(TargetType.PRE_LAYER_OPERATION, target_node_name, port_id) @staticmethod def is_node_with_weights(node: NNCFNode) -> bool: return node.layer_attributes and node.layer_attributes.constant_attributes @staticmethod - def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]: + def _get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]: weight_ports = node.layer_attributes.get_const_port_ids() activation_ports = [ e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in weight_ports @@ -67,6 +68,10 @@ def get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int] return {"activation": activation_ports[0], "weight": weight_ports[0]} + @staticmethod + def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: + return OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["activation"] + @staticmethod def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]: return get_channel_agnostic_reduction_axes([channel_axis], shape) @@ -86,11 +91,16 @@ def process_weight_statistics(weights: np.ndarray, reduction_shape: Tuple[int]) return np.max(np.abs(weights), axis=reduction_shape) @staticmethod - def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) -> np.ndarray: + def get_weight_value(node_with_weight: NNCFNode, model: ov.Model) -> np.ndarray: + port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight) return get_weight_value(node_with_weight, model, port_id) @staticmethod def get_weight_tensor_port_id(node: NNCFNode) 
-> int: + return OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node) + + @staticmethod + def _get_weight_tensor_port_id(node: NNCFNode) -> int: const_ids = node.layer_attributes.get_const_port_ids() if len(const_ids) != 1: raise RuntimeError(f"Found more than 1 port for {node.node_name} node") @@ -134,19 +144,24 @@ def calculate_weight_scale(scale_value: np.ndarray, weights_size: int, channel_a return weight_scale @staticmethod - def weight_update_command( - node_with_weight: NNCFNode, weight_value: np.ndarray, weight_port_id: int - ) -> OVWeightUpdateCommand: + def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand: + weight_port_id = OVSmoothQuantAlgoBackend._get_weight_tensor_port_id(node_with_weight) return OVCommandCreator.create_command_to_update_weight(node_with_weight, weight_value, weight_port_id) @staticmethod def scale_insertion_command( - source_node: NNCFNode, scale_value: np.ndarray, port_id: int, nodes: List[NNCFNode], scale_node_name: str + source_node: NNCFNode, + scale_value: np.ndarray, + source_output_port_id: int, + nodes: List[NNCFNode], + scale_node_name: str, ) -> OVMultiplyInsertionCommand: - return OVCommandCreator.multiply_insertion_command(source_node, nodes, port_id, scale_value, scale_node_name) + return OVCommandCreator.multiply_insertion_command( + source_node, nodes, source_output_port_id, scale_value, scale_node_name + ) @staticmethod - def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: + def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int: channel_axis = 1 if port_id > 1: @@ -164,11 +179,9 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int: return channel_axis @staticmethod - def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int: - channel_axis = 1 - - if port_id > 1: - raise RuntimeError(f"{node.metatype.name} can not take more than 2 input tensors.") + def 
get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int: + port_id = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["weight"] + channel_axis = 1 if node.metatype.const_channel_axis is None else node.metatype.const_channel_axis[0] if port_id not in node.layer_attributes.constant_attributes: raise RuntimeError(f"{node.node_name} should contain {port_id} in the attributes map.") @@ -183,3 +196,17 @@ def get_weight_channel_axis(node: NNCFNode, port_id: int) -> int: @staticmethod def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: return -2 + port_id if transpose else -1 - port_id + + @staticmethod + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + ports_map = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph) + weight_node = nncf_graph.get_input_edges(node)[ports_map["weight"]].from_node + # Skipping shared weights + return len(nncf_graph.get_next_nodes(weight_node)) > 1 + + @staticmethod + def get_filter_fn_for_statistics(activation_port_id: int): + def filter_func(point: StatisticPoint) -> bool: + return point.target_point.port_id == activation_port_id + + return filter_func diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py new file mode 100644 index 00000000000..57e81c72e72 --- /dev/null +++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import numpy as np +import torch + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait +from nncf.common.tensor_statistics.statistic_point import StatisticPoint +from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.openvino.graph.node_utils import get_channel_agnostic_reduction_axes +from nncf.openvino.graph.transformations.commands import OVMultiplyInsertionCommand +from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand +from nncf.quantization.algorithms.smooth_quant.backend import SmoothQuantAlgoBackend +from nncf.torch.graph.transformations.command_creation import create_command_to_update_weight +from nncf.torch.graph.transformations.command_creation import multiply_insertion_command +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT +from nncf.torch.tensor_statistics.collectors import PTAbsMaxReducer +from nncf.torch.tensor_statistics.collectors import PTNNCFCollectorTensorProcessor + + +class PTSmoothQuantAlgoBackend(SmoothQuantAlgoBackend): + @property + def convolution_metatypes(self) -> List[OperatorMetatype]: + return [ + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTModuleConv1dMetatype, + om.PTModuleConv2dMetatype, + om.PTModuleConv3dMetatype, + om.PTDepthwiseConv1dSubtype, + 
om.PTDepthwiseConv2dSubtype, + om.PTDepthwiseConv3dSubtype, + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, + ] + + @property + def matmul_metatypes(self) -> List[OperatorMetatype]: + return [om.PTMatMulMetatype, om.PTLinearMetatype, om.PTModuleLinearMetatype] + + @property + def quantize_agnostic_metatypes(self) -> List[OperatorMetatype]: + return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT[QuantizationTrait.QUANTIZATION_AGNOSTIC] + + @staticmethod + def target_point(target_node_name: str, port_id: int) -> PTTargetPoint: + return PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node_name, input_port_id=port_id) + + @staticmethod + def is_node_with_weights(node: NNCFNode) -> bool: + return node.layer_attributes is not None + + @staticmethod + def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int: + return 0 + + @staticmethod + def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]: + return get_channel_agnostic_reduction_axes([channel_axis], shape) + + @staticmethod + def get_abs_max_channel_collector( + num_samples: int, stats_reduction_axes: Tuple[int], inplace: bool, branch_key: str + ) -> TensorCollector: + collector = TensorCollector() + reducer = PTAbsMaxReducer(reduction_axes=stats_reduction_axes) + aggregator = MaxAggregator(tensor_processor=PTNNCFCollectorTensorProcessor, num_samples=num_samples) + collector.register_statistic_branch(branch_key, reducer, aggregator) + return collector + + @staticmethod + def process_weight_statistics(weights: np.ndarray, channel_axis: int) -> np.ndarray: + return torch.amax(torch.abs(weights), dim=channel_axis) + + @staticmethod + def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork) -> np.ndarray: + node_module = model.nncf.get_containing_module(node_with_weight.node_name) + if node_module.weight is None: + return None + return node_module.weight.data + + @staticmethod + def get_weight_tensor_port_id(node: 
NNCFNode) -> int: + # Should be refactored + const_ids = node.layer_attributes.get_const_port_ids() + if len(const_ids) != 1: + raise RuntimeError(f"Found more than 1 port for {node.node_name} node") + return const_ids[0] + + @staticmethod + def clip_statistics(statistics: torch.Tensor) -> np.ndarray: + a_min = 1e-5 + squeezed = torch.squeeze(statistics) + return torch.clip(squeezed, min=a_min, max=None) + + @staticmethod + def calculate_scale_and_ratio( + activations: np.ndarray, weights: np.ndarray, alpha: float, quantile: Optional[float] = 0.1 + ) -> np.ndarray: + scales = torch.pow(activations, alpha) / (torch.pow(weights, 1 - alpha) + torch.finfo(float).eps) + + a_min = torch.quantile(scales, quantile) + a_max = 1e2 + + scales = torch.clip(scales, min=a_min, max=a_max) + ratio = scales.min() / (scales.max() + torch.finfo(float).eps) + return scales, ratio + + @staticmethod + def calculate_activation_scale(scale_value: np.ndarray, activations_size: int, channel_axis: int) -> np.ndarray: + activation_scale = scale_value ** (-1) + if activations_size > 1: + reshape_shape = np.ones(activations_size, dtype=np.int64).tolist() + reshape_shape[channel_axis] = activation_scale.size()[0] + activation_scale = torch.reshape(activation_scale, reshape_shape) + return activation_scale + + @staticmethod + def calculate_weight_scale(scale_value: np.ndarray, weights_size: int, channel_axis: int) -> np.ndarray: + weight_scale = scale_value + if weights_size > 1: + reshape_shape = np.ones(weights_size, dtype=np.int64).tolist() + reshape_shape[channel_axis] = scale_value.size()[0] + weight_scale = torch.reshape(scale_value, reshape_shape) + return weight_scale + + @staticmethod + def weight_update_command(node_with_weight: NNCFNode, weight_value: np.ndarray) -> OVWeightUpdateCommand: + return create_command_to_update_weight(node_with_weight, weight_value) + + @staticmethod + def scale_insertion_command( + source_node: NNCFNode, + scale_value: np.ndarray, + source_output_port_id: 
int, + nodes: List[NNCFNode], + scale_node_name: str, + ) -> OVMultiplyInsertionCommand: + input_port_id = 0 + return multiply_insertion_command(nodes, scale_value, scale_node_name, input_port_id) + + @staticmethod + def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int: + return len(activations_shape) - 1 + + @staticmethod + def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int: + return 1 + + @staticmethod + def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int: + return -2 + port_id if transpose else -1 - port_id + + @staticmethod + def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph): + return node.is_shared() + + @staticmethod + def get_filter_fn_for_statistics(activation_port_id: int): + def filter_func(point: StatisticPoint) -> bool: + return True + + return filter_func diff --git a/nncf/torch/dynamic_graph/context.py b/nncf/torch/dynamic_graph/context.py index d57e042af1e..ab35b8db9b9 100644 --- a/nncf/torch/dynamic_graph/context.py +++ b/nncf/torch/dynamic_graph/context.py @@ -11,6 +11,7 @@ import threading import weakref +from collections import defaultdict from collections import deque from contextlib import contextmanager from typing import Callable, Dict, List, Optional @@ -91,8 +92,8 @@ def __init__(self): self.graph = DynamicGraph() self._save_context = None - self._post_hooks = {} - self._pre_hooks: Dict[PreHookId, List[Callable]] = {} + self._post_hooks = defaultdict(list) + self._pre_hooks: Dict[PreHookId, List[Callable]] = defaultdict(list) self._num_nested_hooks = 0 self._threading = CopySafeThreadingVars() @@ -257,9 +258,7 @@ def pop_scope(self): def register_pre_hooks(self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int): pre_hook_id = PreHookId(op_address, input_port_id) - if pre_hook_id in self._pre_hooks: - raise KeyError("Pre hook for context {} is already registered".format(str(pre_hook_id))) - 
self._pre_hooks[pre_hook_id] = fn_list + self._pre_hooks[pre_hook_id].extend(fn_list) def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInput) -> OperatorInput: in_op = getattr(self, "in_operator", False) @@ -278,9 +277,7 @@ def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInp return op_inputs def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress): - if op_address in self._post_hooks: - raise KeyError("Post hook for context {} is already registered".format(str(op_address))) - self._post_hooks[op_address] = fn_list + self._post_hooks[op_address].extend(fn_list) def execute_post_hooks(self, op_address: OperationAddress, outputs): in_op = getattr(self, "in_operator", False) diff --git a/nncf/torch/dynamic_graph/layer_attributes_handlers.py b/nncf/torch/dynamic_graph/layer_attributes_handlers.py index 151645181b0..0b839b8f18c 100644 --- a/nncf/torch/dynamic_graph/layer_attributes_handlers.py +++ b/nncf/torch/dynamic_graph/layer_attributes_handlers.py @@ -31,11 +31,13 @@ from nncf.common.graph.layer_attributes import PermuteLayerAttributes from nncf.common.graph.layer_attributes import ReshapeLayerAttributes from nncf.common.graph.layer_attributes import TransposeLayerAttributes +from nncf.common.graph.layer_attributes import MatMulLayerAttributes from nncf.common.graph.utils import get_concat_axis from nncf.common.graph.utils import get_split_axis from nncf.torch.graph.operator_metatypes import PTCatMetatype from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype from nncf.torch.graph.operator_metatypes import PTPadMetatype +from nncf.torch.graph.operator_metatypes import PTMatMulMetatype from nncf.torch.graph.operator_metatypes import PTReshapeMetatype from nncf.torch.graph.operator_metatypes import PTSplitMetatype from nncf.torch.graph.operator_metatypes import PTSqueezeMetatype diff --git a/nncf/torch/dynamic_graph/wrappers.py b/nncf/torch/dynamic_graph/wrappers.py index 
ee0a74355bb..144f5d68238 100644 --- a/nncf/torch/dynamic_graph/wrappers.py +++ b/nncf/torch/dynamic_graph/wrappers.py @@ -28,6 +28,8 @@ from nncf.torch.dynamic_graph.trace_tensor import trace_tensors from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.layers import ITERATION_MODULES +from nncf.torch.return_types import maybe_unwrap_from_torch_return_type +from nncf.torch.return_types import maybe_wrap_to_torch_return_type _IGNORED_SCOPES = [] @@ -188,8 +190,10 @@ def _execute_op( if is_debug() and node is not None: ctx.register_node_call(node) - result = trace_tensors(result, node, ctx) - result = ctx.execute_post_hooks(op_address, result) + unwrapped_result = maybe_unwrap_from_torch_return_type(result) + unwrapped_result = trace_tensors(unwrapped_result, node, ctx) + unwrapped_result = ctx.execute_post_hooks(op_address, unwrapped_result) + result = maybe_wrap_to_torch_return_type(unwrapped_result, result) return result diff --git a/nncf/torch/graph/graph_builder.py b/nncf/torch/graph/graph_builder.py index 1a6f51bfae7..e241608640e 100644 --- a/nncf/torch/graph/graph_builder.py +++ b/nncf/torch/graph/graph_builder.py @@ -27,11 +27,19 @@ class GraphBuilder: def __init__(self, custom_forward_fn: Callable[[torch.nn.Module], Any]): self.custom_forward_fn = custom_forward_fn + def build_dynamic_graph( + self, + model: torch.nn.Module, + context_to_use: Optional[TracingContext] = None, + as_eval: bool = False, + ) -> DynamicGraph: + tracer = GraphTracer(self.custom_forward_fn) + return tracer.trace_graph(model, context_to_use, as_eval) + def build_graph( self, model: torch.nn.Module, context_to_use: Optional[TracingContext] = None, as_eval: bool = False ) -> PTNNCFGraph: - tracer = GraphTracer(self.custom_forward_fn) - dynamic_graph = tracer.trace_graph(model, context_to_use, as_eval) + dynamic_graph = self.build_dynamic_graph(model, context_to_use, as_eval) return GraphConverter.convert(dynamic_graph) diff --git 
a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py index 8408c92aa4c..42fc5d0cb41 100644 --- a/nncf/torch/graph/transformations/command_creation.py +++ b/nncf/torch/graph/transformations/command_creation.py @@ -9,12 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +import torch from torch import Tensor from nncf.common.graph.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority from nncf.torch.graph.transformations.commands import PTBiasCorrectionCommand +from nncf.torch.graph.transformations.commands import PTInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand def create_bias_correction_command(node: NNCFNode, bias_value: Tensor) -> PTBiasCorrectionCommand: @@ -27,3 +34,34 @@ def create_bias_correction_command(node: NNCFNode, bias_value: Tensor) -> PTBias """ target_point = PTTargetPoint(TargetType.LAYER, node.node_name) return PTBiasCorrectionCommand(target_point, bias_value) + + +def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTWeightUpdateCommand: + """ + Creates weight update command. + + :param node: The node in the NNCF graph that corresponds to operation with weight. + :param weight_value: The new weight value that will be set. + :return: The `PTWeightUpdateCommand` command to update weight. 
+ """ + target_point = PTTargetPoint(TargetType.LAYER, node.node_name) + return PTWeightUpdateCommand(target_point, weight_value) + + +def multiply_insertion_command( + target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int +) -> PTInsertionCommand: + commands = [] + for target_node in target_nodes: + target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id) + commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY)) + + class SQMultiply(torch.nn.Module): + def __init__(self, scale_value): + super().__init__() + self._scale_value = scale_value + + def forward(self, x): + return torch.mul(x, self._scale_value) + + return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name) diff --git a/nncf/torch/graph/transformations/commands.py b/nncf/torch/graph/transformations/commands.py index 5f873539c29..2b691aa47eb 100644 --- a/nncf/torch/graph/transformations/commands.py +++ b/nncf/torch/graph/transformations/commands.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict +from typing import Any, Callable, Dict, List import torch @@ -154,8 +154,36 @@ def requires_graph_rebuild(self): :return: Boolean flag. """ - # Rebuild graph when adding quantization nodes. - return self.priority == TransformationPriority.QUANTIZATION_PRIORITY + # Rebuild graph when adding quantization nodes or an op. 
+ return self.priority in [ + TransformationPriority.QUANTIZATION_PRIORITY, + TransformationPriority.OP_INSERTION_PRIORITY, + ] + + +class PTSharedFnInsertionCommand(PTTransformationCommand): + def __init__( + self, + target_commands: List[PTInsertionCommand], + fn: Callable, + op_unique_name: str, + ): + super().__init__(TransformationType.INSERT, None) + self.target_commands = target_commands + self.fn = fn + self.op_name = op_unique_name + + def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand": + # TODO: keep all TransformationCommands atomic, refactor TransformationLayout instead + raise NotImplementedError() + + def requires_graph_rebuild(self): + """ + Return boolean flag to rebuild graph of model. + + :return: Boolean flag. + """ + return True class PTQuantizerInsertionCommand(PTTransformationCommand): @@ -209,3 +237,20 @@ def __init__(self, target_point: PTTargetPoint, bias_value: torch.Tensor): def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand": raise NotImplementedError() + + +class PTWeightUpdateCommand(PTTransformationCommand): + """ + Corrects weight value in the model based on the input value. + """ + + def __init__(self, target_point: PTTargetPoint, weight_value: torch.Tensor): + """ + :param target_point: The TargetPoint instance for the correction that contains layer's information. + :param weight_value: The new weight value that will be used instead of the original weight value. 
+ """ + super().__init__(TransformationType.CHANGE, target_point) + self.weight_value = weight_value + + def union(self, other: "PTTransformationCommand") -> "PTTransformationCommand": + raise NotImplementedError() diff --git a/nncf/torch/model_transformer.py b/nncf/torch/model_transformer.py index db4bf2fec72..5e2f3a73320 100644 --- a/nncf/torch/model_transformer.py +++ b/nncf/torch/model_transformer.py @@ -13,6 +13,7 @@ from collections import defaultdict from typing import Callable, Dict, List, Tuple +import torch from torch import Tensor from torch import nn @@ -24,13 +25,17 @@ from nncf.torch.graph.transformations.commands import PTInsertionCommand from nncf.torch.graph.transformations.commands import PTModelExtractionWithFusedBiasCommand from nncf.torch.graph.transformations.commands import PTQuantizerInsertionCommand +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTWeightUpdateCommand from nncf.torch.graph.transformations.layout import PTTransformationLayout from nncf.torch.model_analyzer import get_potential_fused_node from nncf.torch.module_operations import UpdateWeight from nncf.torch.nncf_network import ExtraCompressionModuleType from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.nncf_network import PTInsertionPoint +from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME +from nncf.torch.quantization.external_quantizer import ExternalOpCallHook from nncf.torch.quantization.external_quantizer import ExternalQuantizerCallHook @@ -47,6 +52,8 @@ def __init__(self, model: NNCFNetwork): (PTInsertionCommand, self._apply_insertion_transformations), (PTQuantizerInsertionCommand, self._apply_quantizer_insertion_transformations), (PTBiasCorrectionCommand, self._apply_bias_correction_transformations), + (PTSharedFnInsertionCommand, self._apply_shared_nodes_insertion), + 
(PTWeightUpdateCommand, self._apply_weights_update_transformations), ] def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork: @@ -102,6 +109,29 @@ def _apply_insertion_transformations(model: NNCFNetwork, transformations: List[P return model + @staticmethod + def _apply_shared_nodes_insertion( + model: NNCFNetwork, transformations: List[PTSharedFnInsertionCommand] + ) -> NNCFNetwork: + compression_model_type = ExtraCompressionModuleType.EXTERNAL_OP + + if not model.nncf.is_compression_module_registered(compression_model_type): + model.nncf.register_compression_module_type(compression_model_type) + + insertion_commands: List[PTInsertionCommand] = [] + + for command in transformations: + op_id = ( + command.op_name + f"[{';'.join([tp.target_point.target_node_name for tp in command.target_commands])}]" + ) + model.nncf.add_compression_module(op_id, command.fn, compression_model_type) + + for command in command.target_commands: + command.fn = ExternalOpCallHook(EXTERNAL_OP_STORAGE_NAME, model.nncf.get_tracing_context(), op_id) + insertion_commands.append(command) + + return PTModelTransformer._apply_insertion_transformations(model, insertion_commands) + @staticmethod def _apply_quantizer_insertion_transformations( model: NNCFNetwork, transformations: List[PTQuantizerInsertionCommand] @@ -127,7 +157,7 @@ def _apply_quantizer_insertion_transformations( if target_point.type is not TargetType.OPERATION_WITH_WEIGHTS: quantizer_id = NonWeightQuantizerId(target_point.target_node_name, target_point.input_port_id) storage_key = str(quantizer_id) - model.nncf.add_compression_module(storage_key, transformation_command.quantizer, compression_model_type) + model.nncf.add_compression_module(storage_key, fn, compression_model_type) fn = ExternalQuantizerCallHook(model.nncf.get_tracing_context(), storage_key) insertion_commands.append( @@ -169,6 +199,21 @@ def _apply_bias_correction_transformations( ) return model + @staticmethod + def 
_apply_weights_update_transformations( + model: NNCFNetwork, transformations: List[PTWeightUpdateCommand] + ) -> NNCFNetwork: + """ + Applies weight update transformations on the model. + + :param model: Model to apply transformations. + :param transformations: List of the weight update transformations. + :return: Model with updated weights. + """ + for transformation in transformations: + update_parameter(transformation.target_point.target_node_name, "weight", transformation.weight_value, model) + return model + def update_fused_bias(target_node_name: str, new_bias: Tensor, model: NNCFNetwork) -> None: """ @@ -182,9 +227,21 @@ def update_fused_bias(target_node_name: str, new_bias: Tensor, model: NNCFNetwor fused_node = get_potential_fused_node(target_node_name, nncf_graph) if fused_node: target_node_name = fused_node.node_name + update_parameter(target_node_name, "bias", new_bias, model) + + +def update_parameter(target_node_name: str, parameter_name: str, new_value: Tensor, model: NNCFNetwork) -> None: + """ + Update parameter for target module. - node = model.nncf.get_containing_module(target_node_name) - node.bias.data = new_bias + :param target_node_name: The target node name. + :param parameter_name: The name of the parameter to update. + :param new_value: New parameter value. + :param model: The model. 
+ """ + module = model.nncf.get_containing_module(target_node_name) + parameter: torch.nn.parameter.Parameter = getattr(module, parameter_name) + parameter.data = new_value def extraction_potential_fused_modules(node_name: str, model: NNCFNetwork) -> nn.Sequential: diff --git a/nncf/torch/nncf_network.py b/nncf/torch/nncf_network.py index 0cc9eda854a..b99afa46b24 100644 --- a/nncf/torch/nncf_network.py +++ b/nncf/torch/nncf_network.py @@ -63,6 +63,7 @@ from nncf.torch.layer_utils import _NNCFModuleMixin from nncf.torch.nested_objects_traversal import objwalk from nncf.torch.nncf_module_replacement import replace_modules_by_nncf_modules +from nncf.torch.quantization.external_quantizer import EXTERNAL_OP_STORAGE_NAME from nncf.torch.quantization.external_quantizer import EXTERNAL_QUANTIZERS_STORAGE_NAME from nncf.torch.utils import compute_FLOPs_hook from nncf.torch.utils import get_all_modules_by_type @@ -119,6 +120,7 @@ def __hash__(self): class ExtraCompressionModuleType(Enum): EXTERNAL_QUANTIZER = 0 + EXTERNAL_OP = 1 class NNCFNetworkInterface(torch.nn.Module): @@ -272,6 +274,7 @@ def __init__( ) self._original_graph = GraphConverter.convert(self._original_dynamic_graph) self._compressed_graph: PTNNCFGraph = None + self._compressed_traced_graph: DynamicGraph = None self._compressed_context = TracingContext() @@ -358,6 +361,25 @@ def reset_nncf_modules(self): module = self.get_module_by_scope(some_scope) module.reset() + def get_shallow_copy(self) -> "NNCFNetwork": + from nncf.torch.utils import load_module_state + from nncf.torch.utils import save_module_state + + saved_state = save_module_state(self._model_ref) + new_interface = NNCFNetworkInterface( + self._model_ref, + self._input_infos, + self._user_dummy_forward_fn, + self._wrap_inputs_fn, + self._scopes_without_shape_matching, + self._ignored_scopes, + self._target_scopes, + wrap_outputs_fn=self._wrap_outputs_fn, + ) + self._model_ref._nncf = new_interface + load_module_state(self._model_ref, saved_state) + 
return self._model_ref + def get_clean_shallow_copy(self) -> "NNCFNetwork": # WARNING: Will reset pre- and post-ops of the underlying model. Use save_nncf_module_additions # and load_nncf_module_additions to preserve these, or temporary_clean_view(). @@ -389,6 +411,9 @@ def get_modules_in_nncf_modules_by_type(self, class_names: List[str]) -> Dict[Sc retval[nncf_module_scope + relative_scope] = target_module return retval + def update_model_ref(self, model: torch.nn.Module) -> None: + object.__setattr__(self, "__model_ref", model) + def insert_at_point(self, point: PTInsertionPoint, fn_list: List[Callable]): if point.insertion_type == PTInsertionType.OPERATOR_PRE_HOOK: self._compressed_context.register_pre_hooks(fn_list, point.op_address, point.input_port_id) @@ -500,7 +525,8 @@ def rebuild_graph(self, *input_args): builder = GraphBuilder(dummy_forward_fn) with training_mode_switcher(self._model_ref, is_training=False): - self._compressed_graph = builder.build_graph(self._model_ref, self._compressed_context) + self._compressed_traced_graph = builder.build_dynamic_graph(self._model_ref, self._compressed_context) + self._compressed_graph = GraphConverter.convert(self._compressed_traced_graph) def is_scope_in_nncf_module_scope(self, scope: Scope) -> bool: norm_nncf_scopes = [] @@ -554,6 +580,8 @@ def _compression_module_type_to_attr_name(compression_module_type: ExtraCompress """ if compression_module_type == ExtraCompressionModuleType.EXTERNAL_QUANTIZER: return EXTERNAL_QUANTIZERS_STORAGE_NAME + if compression_module_type == ExtraCompressionModuleType.EXTERNAL_OP: + return EXTERNAL_OP_STORAGE_NAME raise RuntimeError("Unknown extra module type") def sort_compression_modules(self, compression_module_type: ExtraCompressionModuleType): @@ -732,13 +760,15 @@ def _collect_eval_op_scopes(self, model: nn.Module, dummy_forward_fn: Callable) return result def get_node_to_op_address_mapping(self) -> Dict[NNCFNodeName, OperationAddress]: - # The IDs of corresponding nodes of the 
original dynamic graph and original NNCF graph - # must be equal for this to work. retval = {} - for node in self._original_dynamic_graph.get_all_nodes(): + dynamic_graph = ( + self._original_dynamic_graph if self._compressed_traced_graph is None else self._compressed_traced_graph + ) + nncf_graph = self._original_graph if self._compressed_graph is None else self._compressed_graph + for node in dynamic_graph.get_all_nodes(): node_id = node.node_id op_address = node.op_exec_context.op_address - nncf_node = self._original_graph.get_node_by_id(node_id) + nncf_node = nncf_graph.get_node_by_id(node_id) retval[nncf_node.node_name] = op_address return retval diff --git a/nncf/torch/quantization/external_quantizer.py b/nncf/torch/quantization/external_quantizer.py index c027e407efa..08848f1bfaf 100644 --- a/nncf/torch/quantization/external_quantizer.py +++ b/nncf/torch/quantization/external_quantizer.py @@ -9,14 +9,29 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + from nncf.torch.dynamic_graph.context import TracingContext from nncf.torch.quantization.debug_interface import QuantizationDebugInterface EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers" +EXTERNAL_OP_STORAGE_NAME = "external_op" EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." 
+ EXTERNAL_QUANTIZERS_STORAGE_NAME -class ExternalQuantizerCallHook: +class ExternalOpCallHook: + def __init__(self, storage_name, context, storage_key): + self._storage_name = storage_name + self._compressed_context = context + self._storage_key = storage_key + + def __call__(self, *args: Any, **kwargs) -> Any: + replica = self._compressed_context.base_module_thread_local_replica + storage = getattr(replica.nncf, self._storage_name) + return storage[self._storage_key](*args, **kwargs) + + +class ExternalQuantizerCallHook(ExternalOpCallHook): """ Cannot simply register the quantizer module as a callable hook, since we need to call a thread-local version of the quantizer module during base module execution. @@ -28,13 +43,10 @@ def __init__( quantizer_storage_key: str, debug_interface: QuantizationDebugInterface = None, ): - self.compressed_context = context - self.quantizer_storage_key = quantizer_storage_key + super().__init__(EXTERNAL_QUANTIZERS_STORAGE_NAME, context, quantizer_storage_key) self.debug_interface = debug_interface def __call__(self, *args, **kwargs): if self.debug_interface is not None: self.debug_interface.register_activation_quantize_call(str(self.quantizer_storage_key)) - replica = self.compressed_context.base_module_thread_local_replica - storage = getattr(replica.nncf, EXTERNAL_QUANTIZERS_STORAGE_NAME) - return storage[self.quantizer_storage_key](*args, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/nncf/torch/return_types.py b/nncf/torch/return_types.py new file mode 100644 index 00000000000..b92bf4446b0 --- /dev/null +++ b/nncf/torch/return_types.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Tuple, Type, Union + +import torch + + +def __get_supported_torch_return_types() -> Tuple[Type[object]]: + """ + Collects types from torch.return_types which can be wrapped/unwrapped by nncf. + + :return: List of types from torch.return_types which can be wrapped/unwrapped by nncf. + """ + return_type_names = [t for t in dir(torch.return_types) if not t.startswith("_") and not t.startswith("linalg")] + return_types = [getattr(torch.return_types, t_name) for t_name in return_type_names] + return_types = [t for t in return_types if hasattr(t, "values")] + return tuple(return_types) + + +_TORCH_RETURN_TYPES = __get_supported_torch_return_types() + + +def maybe_unwrap_from_torch_return_type(tensor: Any) -> torch.Tensor: + """ + Attempts to unwrap the tensor value from one of torch.return_types instances + in case torch operation output is wrapped by a torch return_type. + + :param tensor: Torch tensor or torch return type instance to unwrap values from. + :return: Unwrapped torch tensor. + """ + if isinstance(tensor, _TORCH_RETURN_TYPES): + return tensor.values + return tensor + + +def maybe_wrap_to_torch_return_type(tensor: torch.Tensor, wrapped_input: Optional[Union[tuple, torch.Tensor]]) -> Any: + """ + Wraps tensor to wrapped_input wrapper in case wrapped_input is wrapped by a torch.return_types container. + + :param tensor: Torch tensor to wrap. + :param wrapped_input: Instance of the tensor before it was unwrapped. + :return: Wrapped tensor in case wrapped_input is wrapped by a torch.return_types container else the tensor. 
+ """ + + if isinstance(wrapped_input, _TORCH_RETURN_TYPES): + return wrapped_input.__class__([tensor] + [arg for arg in wrapped_input[1:]]) + return tensor diff --git a/nncf/torch/statistics/aggregator.py b/nncf/torch/statistics/aggregator.py index 41fdc20c4fa..e369e57345c 100644 --- a/nncf/torch/statistics/aggregator.py +++ b/nncf/torch/statistics/aggregator.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from copy import deepcopy from typing import Dict import numpy as np @@ -26,10 +27,40 @@ from nncf.torch.tensor_statistics.algo import create_register_input_hook +class ModelView: + def __init__(self, model: NNCFNetwork): + self.model = model + self.nncf_module_additions = self.model.nncf.save_nncf_module_additions() + + def __enter__(self): + # Model ref removed to prevent copying + self.model.nncf.update_model_ref(None) + + # nncf_replaced_models removed to prevent copying + replaced_modules = self.model.nncf._nncf_replaced_modules + self.model.nncf._nncf_replaced_modules = None + + self.nncf_interface = deepcopy(self.model.nncf) + + # Model ref is recovering + self.model.nncf.update_model_ref(self.model) + self.nncf_interface.update_model_ref(self.model) + + # nncf_replaced_models is recovering + self.model.nncf._nncf_replaced_modules = replaced_modules + self.nncf_interface._nncf_replaced_modules = replaced_modules + return self.model + + def __exit__(self, exc_type, exc_val, exc_tb): + self.model._nncf = self.nncf_interface + self.model.nncf.reset_nncf_modules() + self.model.nncf.load_nncf_module_additions(self.nncf_module_additions) + + class PTStatisticsAggregator(StatisticsAggregator): def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: with torch.no_grad(): - with model.nncf.temporary_clean_view() as intermediate_model: + with ModelView(model) as intermediate_model: super().collect_statistics(intermediate_model, graph) def _register_statistics( diff --git 
a/nncf/torch/tensor.py b/nncf/torch/tensor.py index 986adb46aa4..b7977ca2818 100644 --- a/nncf/torch/tensor.py +++ b/nncf/torch/tensor.py @@ -9,9 +9,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import torch from nncf.common.tensor import NNCFTensor +from nncf.torch.return_types import maybe_unwrap_from_torch_return_type class PTNNCFTensor(NNCFTensor): @@ -19,11 +22,13 @@ class PTNNCFTensor(NNCFTensor): A realisation of torch tensors wrapper for common NNCF algorithms. """ - def __init__(self, tensor: torch.tensor): + def __init__(self, tensor: Union[torch.tensor, "PTNNCFTensor", tuple]): # In case somebody attempts to wrap # tensor twice if isinstance(tensor, self.__class__): tensor = tensor.tensor + else: + tensor = maybe_unwrap_from_torch_return_type(tensor) super().__init__(tensor) diff --git a/tests/post_training/test_templates/test_smooth_quant.py b/tests/post_training/test_templates/test_smooth_quant.py index a8731d7d91e..7a14e335393 100644 --- a/tests/post_training/test_templates/test_smooth_quant.py +++ b/tests/post_training/test_templates/test_smooth_quant.py @@ -10,7 +10,7 @@ # limitations under the License. from abc import abstractmethod -from typing import Callable, Dict, TypeVar +from typing import Callable, Dict, Type, TypeVar import pytest @@ -62,7 +62,7 @@ def check_scales(model: TModel, reference_values: Dict[str, TTensor]) -> None: @staticmethod @abstractmethod - def get_backend() -> SmoothQuantAlgoBackend: + def get_backend() -> Type[SmoothQuantAlgoBackend]: """ Returns backend-specific SmoothQuantAlgoBackend. 
""" diff --git a/tests/torch/ptq/test_smooth_quant.py b/tests/torch/ptq/test_smooth_quant.py new file mode 100644 index 00000000000..c55b86b414b --- /dev/null +++ b/tests/torch/ptq/test_smooth_quant.py @@ -0,0 +1,91 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Dict, Type + +import numpy as np +import openvino.runtime as ov +import pytest +import torch + +from nncf.openvino.graph.layer_attributes import OVLayerAttributes +from nncf.quantization.algorithms.smooth_quant.torch_backend import PTSmoothQuantAlgoBackend +from nncf.torch.graph.operator_metatypes import PTModuleConv2dMetatype +from nncf.torch.graph.operator_metatypes import PTModuleLinearMetatype +from nncf.torch.model_creation import wrap_model +from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm + + +class TestTorchSQAlgorithm(TemplateTestSQAlgorithm): + @staticmethod + def fn_to_type(tensor) -> torch.Tensor: + return torch.tensor(tensor) + + @staticmethod + def get_transform_fn() -> Callable: + def transform_fn(data_item): + return data_item[0] + + return transform_fn + + @staticmethod + def get_backend() -> Type[PTSmoothQuantAlgoBackend]: + return PTSmoothQuantAlgoBackend() + + @staticmethod + def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model: + return wrap_model(model.eval(), torch.rand(model.INPUT_SIZE)) + + @staticmethod + def check_scales(model: torch.nn.Module, 
reference_values: Dict[str, np.ndarray]) -> None: + ops_list = {op.get_friendly_name(): op for op in model.get_ops()} + for ref_name, ref_value in reference_values.items(): + node = ops_list[ref_name] + const_node = node.input(1).get_source_output().get_node() + + assert const_node.get_type_name() == "Constant" + + value = const_node.data + ref_value = np.array(ref_value) + assert value.shape == ref_value.shape + assert np.all(np.isclose(value, ref_value, atol=0.0001)), f"{value} != {ref_value}" + + @pytest.mark.parametrize( + "node_metatype, layer_attributes, port_id, reference_value", + ( + (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": False}), 0, -1), + (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": True}), 0, -2), + (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": False}), 1, -2), + (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": True}), 1, -1), + (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": False}), 2, RuntimeError), + (PTModuleConv2dMetatype, OVLayerAttributes({}, inputs_attributes={}), 0, 1), + ), + ) + def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value): + return super().test_get_activation_channel_axis(node_metatype, layer_attributes, port_id, reference_value) + + @pytest.mark.parametrize( + "node_metatype, layer_attributes, port_id, reference_value", + ( + (PTModuleLinearMetatype, OVLayerAttributes({1: {"transpose": False}}), 1, -2), + (PTModuleLinearMetatype, OVLayerAttributes({1: {"transpose": True}}), 1, -1), + (PTModuleLinearMetatype, OVLayerAttributes({0: {"transpose": False}}), 0, -1), + (PTModuleLinearMetatype, OVLayerAttributes({0: {"transpose": True}}), 0, -2), + (PTModuleLinearMetatype, OVLayerAttributes({1: {"transpose": False}}), 2, RuntimeError), + (PTModuleConv2dMetatype, OVLayerAttributes({1: {}}), 1, 1), + ), + ) + def 
test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value): + return super().test_get_weight_channel_axis(node_metatype, layer_attributes, port_id, reference_value) + + @staticmethod + def get_matmul_metatype(): + return PTModuleLinearMetatype diff --git a/tests/torch/test_model_transformer.py b/tests/torch/test_model_transformer.py index f31753e18ac..51fdba11f50 100644 --- a/tests/torch/test_model_transformer.py +++ b/tests/torch/test_model_transformer.py @@ -543,3 +543,56 @@ def test_quantizer_insertion_transformations(target_type, node_name, input_port_ assert hasattr(external_quantizers, ref_name) op = getattr(external_quantizers, ref_name) assert isinstance(op, BaseOp) + + +@pytest.mark.parametrize( + "target_type, node_name, input_port_id", + ( + ( + TargetType.OPERATOR_POST_HOOK, + "/nncf_model_input_0", + None, + ), + ( + TargetType.OPERATOR_PRE_HOOK, + "InsertionPointTestModel/linear_0", + 0, + ), + ( + TargetType.OPERATION_WITH_WEIGHTS, + "InsertionPointTestModel/NNCFConv2d[conv1]/conv2d_0", + None, + ), + ), +) +def test_successive_insertion_transformation(target_type, node_name, input_port_id): + model = NNCFNetwork(InsertionPointTestModel(), FillerInputInfo([FillerInputElement([1, 1, 10, 10])])) + + target_point = PTTargetPoint(target_type, node_name, input_port_id=input_port_id) + transformed_model = model + ops = [BaseOp(lambda x: x), BaseOp(lambda x: x)] + for op in ops: + command = PTInsertionCommand(target_point, op) + + model_transformer = PTModelTransformer(transformed_model) + transformation_layout = PTTransformationLayout() + transformation_layout.register(command) + transformed_model = model_transformer.transform(transformation_layout) + transformed_model.nncf.rebuild_graph() + + if target_type == TargetType.OPERATION_WITH_WEIGHTS: + pre_ops = transformed_model.conv1.pre_ops + assert len(pre_ops) == 2 + for module, op_ref in zip(pre_ops._modules.values(), ops): + assert isinstance(module, UpdateWeight) + 
assert module.op is op_ref + else: + if target_type == TargetType.OPERATOR_POST_HOOK: + hooks = transformed_model.nncf._compressed_context._post_hooks + else: + hooks = transformed_model.nncf._compressed_context._pre_hooks + assert len(hooks) == 1 + _, hook_ops = hooks.popitem() + assert len(hook_ops) == 2 + for hook_op, op in zip(hook_ops, ops): + assert hook_op is op diff --git a/tests/torch/test_statistics_aggregator.py b/tests/torch/test_statistics_aggregator.py index 60bfb99015a..f3a6892ffb4 100644 --- a/tests/torch/test_statistics_aggregator.py +++ b/tests/torch/test_statistics_aggregator.py @@ -17,11 +17,20 @@ from torch import nn from nncf import Dataset +from nncf.common import factory +from nncf.common.factory import NNCFGraphFactory from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.quantization.structs import QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.experimental.common.tensor_statistics.collectors import TensorReducerBase from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend +from nncf.quantization.range_estimator import RangeEstimatorParametersSet from nncf.torch.graph.graph import PTTargetPoint +from nncf.torch.model_transformer import PTInsertionCommand +from nncf.torch.module_operations import UpdateWeight from nncf.torch.statistics.aggregator import PTStatisticsAggregator from tests.common.test_statistics_aggregator import TemplateTestStatisticsAggregator from tests.torch.ptq.helpers import get_nncf_network @@ -45,6 +54,9 @@ def get_nncf_network(self): return get_nncf_network(self, INPUT_SHAPE) +MinMaxTestParameters = TemplateTestStatisticsAggregator.MinMaxTestParameters + + 
class TestStatisticsAggregator(TemplateTestStatisticsAggregator): def get_min_max_algo_backend_cls(self) -> Type[PTMinMaxAlgoBackend]: return PTMinMaxAlgoBackend @@ -121,3 +133,136 @@ def test_statistic_merging(self, dataset_samples, inplace_statistics): @pytest.mark.skip("Merging is not implemented yet") def test_same_collectors_different_attrs_dont_merge(self, statistics_type, test_params, dataset_samples): pass + + @pytest.mark.parametrize( + "test_parameters, ", + ( + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATOR_PRE_HOOK, + QuantizationMode.SYMMETRIC, + False, + 256, + -256, + ), + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATION_WITH_WEIGHTS, + QuantizationMode.SYMMETRIC, + False, + 256, + -256, + ), + MinMaxTestParameters( + RangeEstimatorParametersSet.MINMAX, + TargetType.OPERATOR_POST_HOOK, + QuantizationMode.SYMMETRIC, + False, + 256, + -256, + ), + ), + ) + def test_successive_statistics_aggregation( + self, + test_parameters: MinMaxTestParameters, + dataset_samples, + is_stat_in_shape_of_scale, + inplace_statistics, + is_backend_support_custom_estimators, + ): + model = self.get_backend_model(dataset_samples) + quantizer_config = QuantizerConfig( + mode=test_parameters.quantization_mode, per_channel=test_parameters.per_channel + ) + + is_standard_estimator = test_parameters.range_estimator_params in [ + RangeEstimatorParametersSet.MINMAX, + RangeEstimatorParametersSet.MEAN_MINMAX, + ] + if not is_standard_estimator and not is_backend_support_custom_estimators: + pytest.skip("Custom estimators are not supported for this backend yet") + + ### Register operations before statistic collection + def fn(*args, **kwargs): + return args[0] * 2 + + layout = TransformationLayout() + for target_point in [test_parameters.target_type]: + target_point = self.get_target_point(target_point) + command = PTInsertionCommand(target_point, fn) + layout.register(command) + model_transformer = 
factory.ModelTransformerFactory.create(model) + model = model_transformer.transform(layout) + model.nncf.rebuild_graph() + + ### Register and collect statistics after inserted operations + statistics_points = StatisticPointsContainer() + for target_point in [test_parameters.target_type]: + target_point = self.get_target_point(target_point) + algorithm_name = "TestAlgo" + statistic_point = self.create_statistics_point( + model, + quantizer_config, + target_point, + len(dataset_samples), + algorithm_name, + inplace_statistics, + test_parameters.range_estimator_params, + ) + statistics_points.add_statistic_point(statistic_point) + + dataset = self.get_dataset(dataset_samples) + statistics_aggregator = self.get_statistics_aggregator(dataset) + statistics_aggregator.register_statistic_points(statistics_points) + graph = NNCFGraphFactory.create(model) + statistics_aggregator.collect_statistics(model, graph) + + def filter_func(point): + return ( + algorithm_name in point.algorithm_to_tensor_collectors and point.target_point.type == target_point.type + ) + + tensor_collectors = list( + statistics_points.get_algo_statistics_for_node(target_point.target_node_name, filter_func, algorithm_name) + ) + + ### Check values are changed because of the inserted operation + assert len(tensor_collectors) == 1 + for tensor_collector in tensor_collectors: + stat = tensor_collector.get_statistics() + # Torch and Openvino backends tensor collectors return values in shape of scale + # in comparison to ONNX backends. 
+ ref_min_val, ref_max_val = test_parameters.ref_min_val, test_parameters.ref_max_val + if isinstance(ref_min_val, np.ndarray) and is_stat_in_shape_of_scale: + shape = (1, 3, 1, 1) + if test_parameters.target_type == TargetType.OPERATION_WITH_WEIGHTS: + shape = (3, 1, 1, 1) + ref_min_val, ref_max_val = map(lambda x: np.reshape(x, shape), (ref_min_val, ref_max_val)) + + assert np.allclose(stat.min_values, ref_min_val) + assert np.allclose(stat.max_values, ref_max_val) + if isinstance(ref_min_val, np.ndarray): + assert stat.min_values.shape == ref_min_val.shape + assert stat.max_values.shape == ref_max_val.shape + else: + ref_shape = (1, 1, 1, 1) if is_stat_in_shape_of_scale else () + assert stat.min_values.shape == ref_shape + assert stat.max_values.shape == ref_shape + + ### Check the inserted operation is inside the model + if test_parameters.target_type == TargetType.OPERATION_WITH_WEIGHTS: + pre_ops = model.conv.pre_ops + assert len(pre_ops) == 1 + for module in pre_ops.values(): + assert isinstance(module, UpdateWeight) + assert module.op is fn + else: + if test_parameters.target_type == TargetType.OPERATOR_POST_HOOK: + hooks = model.nncf._compressed_context._post_hooks + else: + hooks = model.nncf._compressed_context._pre_hooks + assert len(hooks) == 1 + _, hook_ops = hooks.popitem() + assert len(hook_ops) == 1 + assert hook_ops[0] is fn