From 6cffa744edf11b2ac37ffec0a3075756f539f334 Mon Sep 17 00:00:00 2001
From: dlyakhov
Date: Fri, 24 Nov 2023 20:08:00 +0100
Subject: [PATCH] SQ test is completed

---
 .../algorithms/smooth_quant/algorithm.py      |  19 ++--
 .../algorithms/smooth_quant/backend.py        |  15 +--
 .../smooth_quant/openvino_backend.py          |  38 +++----
 .../algorithms/smooth_quant/torch_backend.py  |  14 +--
 .../layer_attributes_handlers.py              |   2 -
 .../graph/transformations/command_creation.py |  17 +--
 nncf/torch/tensor.py                          |   2 +-
 tests/openvino/native/test_smooth_quant.py    |  72 +++++++++---
 tests/post_training/test_templates/helpers.py |  31 ++++-
 .../test_templates/test_smooth_quant.py       | 106 ++++++++++++------
 tests/torch/ptq/test_smooth_quant.py          |  75 +++++++++----
 11 files changed, 254 insertions(+), 137 deletions(-)

diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py
index ed2e0564db8..8fb422a23d0 100644
--- a/nncf/quantization/algorithms/smooth_quant/algorithm.py
+++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py
@@ -132,7 +132,7 @@ def apply(
             activations_value = self._backend_entity.clip_statistics(activations_value[0])
 
             weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
-            weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value, graph)
+            weight_statistics = self._process_weight_statistics(node_to_smooth, weight_value)
             weight_statistics = self._backend_entity.clip_statistics(weight_statistics)
 
             alpha = alpha_map[node_to_smooth.metatype]
@@ -159,7 +159,7 @@ def apply(
 
             for node_to_smooth in nodes:
                 weight_value = self._backend_entity.get_weight_value(node_to_smooth, model)
-                weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value, graph)
+                weights_scale = self._calculate_weight_scale(best_scale, node_to_smooth, weight_value)
                 ### TODO: DO it as NNCFTensor op
                 scaled_weight = weight_value * weights_scale
                 ###
@@ -302,8 +302,7 @@ def _calculate_activation_scale(
         """
         activation_ports_map = {node: self._backend_entity.get_activations_port_id(node, nncf_graph) for node in nodes}
         channel_axes = [
-            self._backend_entity.get_activation_channel_axis(node, port, activations_shape)
-            for node, port in activation_ports_map.items()
+            self._backend_entity.get_activation_channel_axis(node, port) for node, port in activation_ports_map.items()
         ]
         channel_axis = channel_axes[0]
 
@@ -313,9 +312,7 @@ def _calculate_activation_scale(
         activations_size = len(activations_shape)
         return self._backend_entity.calculate_activation_scale(scale_value, activations_size, channel_axis)
 
-    def _calculate_weight_scale(
-        self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor, graph: NNCFGraph
-    ) -> TTensor:
+    def _calculate_weight_scale(self, scale_value: TTensor, node: NNCFNode, weights_value: TTensor) -> TTensor:
         """
         Calculates scale for weight tensor.
@@ -325,7 +322,7 @@ def _calculate_weight_scale(
         """
         weights_size = len(weights_value.shape)
        if weights_size > 1:
-            channel_axis = self._backend_entity.get_weight_channel_axis(node, graph)
+            channel_axis = self._backend_entity.get_weight_channel_axis(node)
             return self._backend_entity.calculate_weight_scale(scale_value, weights_size, channel_axis)
         return scale_value
 
@@ -341,11 +338,11 @@ def _calculate_input_reduction_axes(self, nncf_graph: NNCFGraph, node: NNCFNode,
         shape = nncf_graph.get_input_edges(node)[input_port].tensor_shape
         reduction_axes = tuple([])
         if len(shape) > 1:
-            channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port, shape)
+            channel_axis = self._backend_entity.get_activation_channel_axis(node, input_port)
             reduction_axes = self._backend_entity.get_channel_agnostic_reduction_axes(channel_axis, shape)
         return reduction_axes
 
-    def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, graph: NNCFGraph) -> TTensor:
+    def _process_weight_statistics(self, node: NNCFNode, weights: TTensor) -> TTensor:
         """
         Returns processed weight statistics for node.
 
@@ -356,7 +353,7 @@ def _process_weight_statistics(self, node: NNCFNode, weights: TTensor, graph: NN
         """
         channel_axis = 0
         if len(weights.shape) > 1:
-            channel_axis = self._backend_entity.get_weight_channel_axis(node, graph)
+            channel_axis = self._backend_entity.get_weight_channel_axis(node)
         reduction_shape = [i for i, _ in enumerate(weights.shape)]
         reduction_shape.pop(channel_axis)
         return self._backend_entity.process_weight_statistics(weights, tuple(reduction_shape))
diff --git a/nncf/quantization/algorithms/smooth_quant/backend.py b/nncf/quantization/algorithms/smooth_quant/backend.py
index d015eb73a45..fbd28ae890f 100644
--- a/nncf/quantization/algorithms/smooth_quant/backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/backend.py
@@ -222,7 +222,7 @@ def scale_insertion_command(
 
     @staticmethod
     @abstractmethod
-    def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
+    def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
         """
         Returns the axis number of the activation tensor which corresponds to its channel.
 
@@ -233,7 +233,7 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape:
 
     @staticmethod
     @abstractmethod
-    def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
+    def get_weight_channel_axis(node: NNCFNode) -> int:
         """
         Returns the axis number of the weight tensor which corresponds to its channel.
 
@@ -242,17 +242,6 @@ def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
         :return: Channel axis number.
         """
 
-    @staticmethod
-    @abstractmethod
-    def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
-        """
-        Returns port-based channel axis.
-
-        :param port_id: Specified input port id.
-        :param transpose: Transpose position.
-        :return: Channel axis.
-        """
-
     @staticmethod
     @abstractmethod
     def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
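Note on the simplified backend API above: the channel axis is now derived from the node alone, and callers such as `_calculate_input_reduction_axes` and `_process_weight_statistics` turn it into reduction axes themselves. A minimal standalone sketch of that step (the function name is illustrative, not the NNCF API):

    from typing import Tuple

    def reduction_axes_for(channel_axis: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
        # Every axis except the (possibly negative) channel axis is reduced,
        # mirroring the reduction_shape.pop(channel_axis) call in algorithm.py.
        axes = list(range(len(shape)))
        axes.pop(channel_axis)
        return tuple(axes)

    # For a (1, 3, 8) activation with channel_axis == -1, statistics are
    # collected over axes (0, 1):
    assert reduction_axes_for(-1, (1, 3, 8)) == (0, 1)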
diff --git a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py
index ca797899d6b..58db319ebee 100644
--- a/nncf/quantization/algorithms/smooth_quant/openvino_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/openvino_backend.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import numpy as np
 import openvino.runtime as ov
@@ -56,8 +56,15 @@ def target_point(target_node_name: str, port_id: int) -> OVTargetPoint:
     def is_node_with_weights(node: NNCFNode) -> bool:
         return node.layer_attributes and node.layer_attributes.constant_attributes
 
+    @staticmethod
+    def _get_weight_port_id(node: NNCFNode) -> int:
+        weight_ports = node.layer_attributes.get_const_port_ids()
+        if len(weight_ports) != 1:
+            raise RuntimeError(f"Unexpected number of weight ports for {node.node_name} node")
+        return weight_ports[0]
+
     @staticmethod
-    def _get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int]:
+    def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
         weight_ports = node.layer_attributes.get_const_port_ids()
         activation_ports = [
             e.input_port_id for e in nncf_graph.get_input_edges(node) if e.input_port_id not in weight_ports
@@ -65,12 +72,7 @@ def _get_input_ports_map(node: NNCFNode, nncf_graph: NNCFGraph) -> Dict[str, int
 
         if len(weight_ports) != 1 or len(activation_ports) != 1:
             raise RuntimeError(f"Too many weight or activation ports for {node.node_name} node")
-
-        return {"activation": activation_ports[0], "weight": weight_ports[0]}
-
-    @staticmethod
-    def get_activations_port_id(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
-        return OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["activation"]
+        return activation_ports[0]
 
     @staticmethod
     def get_channel_agnostic_reduction_axes(channel_axis: int, shape: Tuple[int]) -> Tuple[int]:
@@ -161,7 +163,7 @@ def scale_insertion_command(
         )
 
     @staticmethod
-    def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
+    def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
         channel_axis = 1
 
         if port_id > 1:
@@ -174,33 +176,30 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape:
             and "transpose" in node.layer_attributes.input_attributes
         ):
             transpose = node.layer_attributes.input_attributes["transpose"]
-            channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose)
+            channel_axis = OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis(port_id, transpose)
 
         return channel_axis
 
     @staticmethod
-    def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
-        port_id = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)["weight"]
-        channel_axis = 1 if node.metatype.const_channel_axis is None else node.metatype.const_channel_axis[0]
-
-        if port_id not in node.layer_attributes.constant_attributes:
-            raise RuntimeError(f"{node.node_name} should contain {port_id} in the attributes map.")
+    def get_weight_channel_axis(node: NNCFNode) -> int:
+        port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
+        channel_axis = 1
 
         if node.metatype == OVMatMulMetatype:
             if "transpose" in node.layer_attributes.constant_attributes[port_id]:
                 transpose = node.layer_attributes.constant_attributes[port_id]["transpose"]
-                channel_axis = OVSmoothQuantAlgoBackend.calculate_port_based_channel_axis(port_id, transpose)
+                channel_axis = OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis(port_id, transpose)
 
         return channel_axis
 
     @staticmethod
-    def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
+    def _calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
         return -2 + port_id if transpose else -1 - port_id
 
     @staticmethod
     def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
-        ports_map = OVSmoothQuantAlgoBackend._get_input_ports_map(node, nncf_graph)
-        weight_node = nncf_graph.get_input_edges(node)[ports_map["weight"]].from_node
+        weight_port_id = OVSmoothQuantAlgoBackend._get_weight_port_id(node)
+        weight_node = nncf_graph.get_input_edges(node)[weight_port_id].from_node
         # Skipping shared weights
         return len(nncf_graph.get_next_nodes(weight_node)) > 1
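The `_calculate_port_based_channel_axis` helper above encodes which MatMul dimension carries the channel to smooth: the contraction (inner) axis of each input. A plain-Python sketch of the values it produces (a standalone restatement of the same formula, for illustration only):

    def channel_axis(port_id: int, transpose: bool) -> int:
        # Same formula as OVSmoothQuantAlgoBackend._calculate_port_based_channel_axis.
        return -2 + port_id if transpose else -1 - port_id

    # For C = A @ B, the contraction axis is the last axis of A (port 0)
    # and the second-to-last axis of B (port 1); a transpose flag swaps them.
    assert channel_axis(0, False) == -1
    assert channel_axis(1, False) == -2
    assert channel_axis(0, True) == -2
    assert channel_axis(1, True) == -1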
diff --git a/nncf/quantization/algorithms/smooth_quant/torch_backend.py b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
index 57e81c72e72..b4ca0596b6c 100644
--- a/nncf/quantization/algorithms/smooth_quant/torch_backend.py
+++ b/nncf/quantization/algorithms/smooth_quant/torch_backend.py
@@ -160,16 +160,16 @@ def scale_insertion_command(
         return multiply_insertion_command(nodes, scale_value, scale_node_name, input_port_id)
 
     @staticmethod
-    def get_activation_channel_axis(node: NNCFNode, port_id: int, activations_shape: Tuple[int, ...]) -> int:
-        return len(activations_shape) - 1
-
-    @staticmethod
-    def get_weight_channel_axis(node: NNCFNode, nncf_graph: NNCFGraph) -> int:
+    def get_activation_channel_axis(node: NNCFNode, port_id: int) -> int:
+        if node.metatype == om.PTModuleLinearMetatype:
+            return -1
+        # TODO: Add activation axis calculation when MatMul will be supported
         return 1
 
     @staticmethod
-    def calculate_port_based_channel_axis(port_id: int, transpose: bool) -> int:
-        return -2 + port_id if transpose else -1 - port_id
+    def get_weight_channel_axis(node: NNCFNode) -> int:
+        # TODO: Add weight axis calculation when MatMul will be supported
+        return 1
 
     @staticmethod
     def is_node_with_shared_weight(node: NNCFNode, nncf_graph: NNCFGraph):
diff --git a/nncf/torch/dynamic_graph/layer_attributes_handlers.py b/nncf/torch/dynamic_graph/layer_attributes_handlers.py
index 0b839b8f18c..151645181b0 100644
--- a/nncf/torch/dynamic_graph/layer_attributes_handlers.py
+++ b/nncf/torch/dynamic_graph/layer_attributes_handlers.py
@@ -31,13 +31,11 @@
 from nncf.common.graph.layer_attributes import PermuteLayerAttributes
 from nncf.common.graph.layer_attributes import ReshapeLayerAttributes
 from nncf.common.graph.layer_attributes import TransposeLayerAttributes
-from nncf.common.graph.layer_attributes import MatMulLayerAttributes
 from nncf.common.graph.utils import get_concat_axis
 from nncf.common.graph.utils import get_split_axis
 from nncf.torch.graph.operator_metatypes import PTCatMetatype
 from nncf.torch.graph.operator_metatypes import PTGroupNormMetatype
 from nncf.torch.graph.operator_metatypes import PTPadMetatype
-from nncf.torch.graph.operator_metatypes import PTMatMulMetatype
 from nncf.torch.graph.operator_metatypes import PTReshapeMetatype
 from nncf.torch.graph.operator_metatypes import PTSplitMetatype
 from nncf.torch.graph.operator_metatypes import PTSqueezeMetatype
diff --git a/nncf/torch/graph/transformations/command_creation.py b/nncf/torch/graph/transformations/command_creation.py
index 42fc5d0cb41..3273a35c4a5 100644
--- a/nncf/torch/graph/transformations/command_creation.py
+++ b/nncf/torch/graph/transformations/command_creation.py
@@ -48,6 +48,15 @@ def create_command_to_update_weight(node: NNCFNode, weight_value: Tensor) -> PTW
     return PTWeightUpdateCommand(target_point, weight_value)
 
 
+class SQMultiply(torch.nn.Module):
+    def __init__(self, scale_value):
+        super().__init__()
+        self._scale_value = scale_value
+
+    def forward(self, x):
+        return torch.mul(x, self._scale_value)
+
+
 def multiply_insertion_command(
     target_nodes: List[NNCFNode], scale_value: Tensor, scale_node_name: str, input_port_id: int
 ) -> PTInsertionCommand:
@@ -56,12 +65,4 @@ def multiply_insertion_command(
         target_point = PTTargetPoint(TargetType.OPERATOR_PRE_HOOK, target_node.node_name, input_port_id=input_port_id)
         commands.append(PTInsertionCommand(target_point, None, priority=TransformationPriority.OP_INSERTION_PRIORITY))
 
-    class SQMultiply(torch.nn.Module):
-        def __init__(self, scale_value):
-            super().__init__()
-            self._scale_value = scale_value
-
-        def forward(self, x):
-            return torch.mul(x, self._scale_value)
-
     return PTSharedFnInsertionCommand(commands, SQMultiply(scale_value), scale_node_name)
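Hoisting `SQMultiply` to module scope in the diff above turns a closure-local class into a single importable type, so tests can assert `isinstance(module, SQMultiply)` on the inserted modules, as the Torch test later in this patch does. A small usage sketch under the patch's import path:

    import torch

    from nncf.torch.graph.transformations.command_creation import SQMultiply

    # SQMultiply simply multiplies its input by a stored (per-channel) scale.
    sq = SQMultiply(torch.tensor([0.5, 2.0]))
    out = sq(torch.ones(3, 2))
    assert torch.equal(out, torch.tensor([0.5, 2.0]).expand(3, 2))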
diff --git a/nncf/torch/tensor.py b/nncf/torch/tensor.py
index b7977ca2818..908e482f889 100644
--- a/nncf/torch/tensor.py
+++ b/nncf/torch/tensor.py
@@ -37,4 +37,4 @@ def device(self) -> torch.device:
         return self._tensor.device
 
     def is_empty(self) -> bool:
-        return self.tensor.size == 0
+        return self.tensor.numel() == 0
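The `is_empty` fix above matters because `torch.Tensor.size` is a method, so the old comparison put a bound method next to an integer and was always `False`; `numel()` returns the actual element count. A standalone illustration in plain torch:

    import torch

    t = torch.nonzero(torch.zeros(4))  # nothing is nonzero -> shape (0, 1)
    assert t.numel() == 0              # correct emptiness check
    assert (t.size == 0) is False      # bound method never equals 0, so the old check always returned False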
diff --git a/tests/openvino/native/test_smooth_quant.py b/tests/openvino/native/test_smooth_quant.py
index 39fe6af4dca..9b1ad20113f 100644
--- a/tests/openvino/native/test_smooth_quant.py
+++ b/tests/openvino/native/test_smooth_quant.py
@@ -18,18 +18,61 @@
 import torch
 from openvino.tools.mo import convert_model
 
+from nncf.common.graph.transformations.commands import TransformationCommand
 from nncf.openvino.graph.layer_attributes import OVLayerAttributes
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVConvolutionMetatype
 from nncf.openvino.graph.metatypes.openvino_metatypes import OVMatMulMetatype
 from nncf.quantization.algorithms.smooth_quant.openvino_backend import OVSmoothQuantAlgoBackend
 from tests.post_training.test_templates.test_smooth_quant import TemplateTestSQAlgorithm
 
+OV_LINEAR_MODEL_MM_OP_MAP = {
+    "MatMul1": "/MatMul",
+    "MatMul2": "/MatMul_1",
+    "MatMul3": "/MatMul_2",
+    "MatMul4": "/MatMul_4",
+    "MatMul5": "32",
+    "MatMul6": "37",
+    "MatMul7": "54",
+    "MatMul8": "68",
+    "Linear1": "/linear_2/MatMul",
+    "Linear2": "/linear_1/MatMul",
+    "Linear3": "/linear_3/MatMul",
+    "Linear4": "/linear_4/MatMul",
+}
+
+
+OV_LINEAR_MODEL_SQ_OP_MAP = {
+    "MatMul1": "/Reshape_0_0/nncf_smooth_quant",
+    "MatMul2": "/Reshape_0_0/nncf_smooth_quant",
+    "MatMul3": "/Reshape_1_0_0/nncf_smooth_quant",
+    "MatMul4": "/Reshape_1_0_1/nncf_smooth_quant",
+    "MatMul5": "/Reshape_2_0_0/nncf_smooth_quant",
+    "MatMul6": "/ReduceMax_0_0/nncf_smooth_quant",
+    "MatMul7": "/Reshape_3_0_0/nncf_smooth_quant",
+    "MatMul8": "/Reshape_4_0_0/nncf_smooth_quant",
+    "Linear1": "/Split_1_0/nncf_smooth_quant",
+    "Linear2": "/Split_0_0/nncf_smooth_quant",
+    "Linear3": "/Add_0_0/nncf_smooth_quant",
+    "Linear4": "/Add_0_0/nncf_smooth_quant",
+}
+
 
 class TestOVSQAlgorithm(TemplateTestSQAlgorithm):
     @staticmethod
     def fn_to_type(tensor) -> np.ndarray:
         return np.array(tensor)
 
+    @pytest.fixture(params=[False, True], ids=["out_of_place", "inplace"])
+    def inplace_statistics(self, request) -> bool:
+        return request.param
+
+    def get_node_name_map(self) -> Dict[str, str]:
+        return OV_LINEAR_MODEL_MM_OP_MAP
+
+    @staticmethod
+    def get_target_node_name(command: TransformationCommand):
+        return command.target_point.target_node_name
+
     @staticmethod
     def get_transform_fn() -> Callable:
         def transform_fn(data_item):
@@ -53,10 +96,14 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model:
 
     @staticmethod
     def check_scales(model: ov.Model, reference_values: Dict[str, np.ndarray]) -> None:
         ops_list = {op.get_friendly_name(): op for op in model.get_ops()}
-        for ref_name, ref_value in reference_values.items():
-            node = ops_list[ref_name]
-            const_node = node.input(1).get_source_output().get_node()
-
+        for ref_names, ref_value in reference_values.items():
+            const_nodes = []
+            for ref_name in ref_names:
+                node = ops_list[OV_LINEAR_MODEL_SQ_OP_MAP[ref_name]]
+                const_nodes.append(node.input(1).get_source_output().get_node())
+            # Check that a unified group actually shares one constant
+            assert all(node is const_nodes[0] for node in const_nodes[1:])
+            const_node = const_nodes[0]
             assert const_node.get_type_name() == "Constant"
 
             value = const_node.data
@@ -79,18 +126,17 @@ def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port
         return super().test_get_activation_channel_axis(node_metatype, layer_attributes, port_id, reference_value)
 
     @pytest.mark.parametrize(
-        "node_metatype, layer_attributes, port_id, reference_value",
+        "node_metatype, layer_attributes, reference_value",
         (
-            (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), 1, -2),
-            (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": True}}), 1, -1),
-            (OVMatMulMetatype, OVLayerAttributes({0: {"transpose": False}}), 0, -1),
-            (OVMatMulMetatype, OVLayerAttributes({0: {"transpose": True}}), 0, -2),
-            (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), 2, RuntimeError),
-            (OVConvolutionMetatype, OVLayerAttributes({1: {}}), 1, 1),
+            (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": False}}), -2),
+            (OVMatMulMetatype, OVLayerAttributes({1: {"transpose": True}}), -1),
+            (OVMatMulMetatype, OVLayerAttributes({0: {"transpose": False}}), -1),
+            (OVMatMulMetatype, OVLayerAttributes({0: {"transpose": True}}), -2),
+            (OVConvolutionMetatype, OVLayerAttributes({1: {}}), 1),
         ),
     )
-    def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value):
-        return super().test_get_weight_channel_axis(node_metatype, layer_attributes, port_id, reference_value)
+    def test_get_weight_channel_axis(self, node_metatype, layer_attributes, reference_value):
+        return super().test_get_weight_channel_axis(node_metatype, layer_attributes, reference_value)
 
     @staticmethod
     def get_matmul_metatype():
diff --git a/tests/post_training/test_templates/helpers.py b/tests/post_training/test_templates/helpers.py
index 8f60f61ee9f..bc58cfd132b 100644
--- a/tests/post_training/test_templates/helpers.py
+++ b/tests/post_training/test_templates/helpers.py
@@ -157,6 +157,14 @@ def __init__(self) -> None:
         self.matmul_7_data = torch.randn((6, 6), dtype=torch.float32)
         self.matmul_8_data = torch.randn((10, 6), dtype=torch.float32)
 
+        self.linear_3 = nn.Linear(4, 4)
+        self.linear_3.weight.data = torch.randn((4, 4), dtype=torch.float32)
+        self.linear_3.bias.data = torch.randn((1, 4), dtype=torch.float32)
+
+        self.linear_4 = nn.Linear(4, 4)
+        self.linear_4.weight.data = torch.randn((4, 4), dtype=torch.float32)
+        self.linear_4.bias.data = torch.randn((1, 4), dtype=torch.float32)
+
     def forward(self, x):
         x = torch.reshape(x, (1, 3, 2, 4))
 
@@ -164,6 +172,15 @@ def forward(self, x):
         x_2 = torch.matmul(x, self.matmul_2_data)
 
         x = torch.add(x_1, x_2)
+
+        x_3 = self.linear_3(x)
+        x_4 = self.linear_4(x)
+
+        x_ = torch.add(x_3, x_4)
+
+        x = torch.add(x, x_)
+        x = torch.sub(x, x_)
+
         x_1 = torch.reshape(x, (1, 3, 8))
 
         x_1_1 = torch.matmul(x_1, self.matmul_3_data)
@@ -189,13 +206,23 @@ def forward(self, x):
 class NonZeroLinearModel(nn.Module):
     INPUT_SIZE = [10]
 
+    def __init__(self):
+        super().__init__()
+        self.linear = nn.Linear(1, 5)
+        self.linear.weight.data = torch.ones((5, 1))
+        self.linear.bias.data = torch.zeros((1, 1))
+
+        self.linear1 = nn.Linear(10, 10)
+        self.linear1.weight.data = torch.ones((10, 10))
+        self.linear1.bias.data = torch.zeros((1, 1))
+
     def forward(self, x):
         zeros = (x > torch.inf).float()
         empty = torch.nonzero(zeros).reshape((-1, 1, 1)).float()
-        y = torch.matmul(empty, torch.ones((1, 5)))
+        y = self.linear(empty)
         y += 5
         y = torch.cat((torch.ones((1, 10)), y.reshape(1, -1)), dim=1)
-        y = torch.matmul(y, torch.ones(10, 10))
+        y = self.linear1(y)
         y += 5
         return y
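`NonZeroLinearModel` above manufactures a zero-element activation on purpose: `x > torch.inf` is `False` everywhere, so `torch.nonzero` yields an empty tensor that flows into the first `Linear`, exercising the empty-statistics branch that `test_empty_stats` checks. A standalone illustration:

    import torch

    x = torch.randn(10)
    zeros = (x > torch.inf).float()  # all zeros: nothing exceeds +inf
    empty = torch.nonzero(zeros).reshape((-1, 1, 1)).float()
    assert empty.shape == (0, 1, 1)  # zero-element batch reaches the first Linear
    assert empty.numel() == 0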
diff --git a/tests/post_training/test_templates/test_smooth_quant.py b/tests/post_training/test_templates/test_smooth_quant.py
index 7a14e335393..903b98871ea 100644
--- a/tests/post_training/test_templates/test_smooth_quant.py
+++ b/tests/post_training/test_templates/test_smooth_quant.py
@@ -17,6 +17,7 @@
 from nncf.common.factory import NNCFGraphFactory
 from nncf.common.factory import StatisticsAggregatorFactory
 from nncf.common.graph.graph import NNCFNode
+from nncf.common.graph.transformations.commands import TransformationCommand
 from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer
 from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator
 from nncf.parameters import ModelType
@@ -39,6 +40,27 @@ class TemplateTestSQAlgorithm:
     def fn_to_type(tensor) -> TTensor:
         return tensor
 
+    @pytest.fixture
+    @abstractmethod
+    def inplace_statistics(self) -> bool:
+        """
+        Returns all possible values for the `inplace` parameter.
+        """
+
+    @abstractmethod
+    def get_node_name_map(self) -> Dict[str, str]:
+        """
+        Returns a backend-specific map from LinearMultiShapeModel labels
+        to NNCFGraph node names.
+        """
+
+    @staticmethod
+    @abstractmethod
+    def get_target_node_name(command: TransformationCommand):
+        """
+        Returns the target node name of a transformation command.
+        """
+
     @staticmethod
     @abstractmethod
     def get_transform_fn() -> Callable:
@@ -92,10 +114,8 @@ def get_quantization_algorithm():
         (
             LinearMultiShapeModel,
             {
-                "/Reshape_0_0/nncf_smooth_quant": [[[[1.0594617, 1.1019668, 1.2208323, 1.1003988]]]],
-                "/Split_1_0/nncf_smooth_quant": [[[[1.1276343, 0.7605822]]]],
-                "/Split_0_0/nncf_smooth_quant": [[[[0.32575992, 0.33121374]]]],
-                "/Reshape_1_0_0/nncf_smooth_quant": [
+                ("MatMul1", "MatMul2"): [[[[1.0594617, 1.1019668, 1.2208323, 1.1003988]]]],
+                ("MatMul3",): [
                     [
                         [
                             0.3251956,
                         ]
                     ]
                 ],
-                "/Reshape_1_0_1/nncf_smooth_quant": [[[0.4699388], [0.3369332], [0.3674589]]],
-                "/Reshape_2_0_0/nncf_smooth_quant": [[0.1242606]],
-                "/ReduceMax_0_0/nncf_smooth_quant": [
+                ("MatMul4",): [[[0.4699388], [0.3369332], [0.3674589]]],
+                ("MatMul5",): [[0.1242606]],
+                ("MatMul6",): [
                     [0.08709318, 0.08033343, 0.67289335, 0.33452678, 0.14223875, 0.19858328, 0.46314085, 0.68816555]
                 ],
+                ("MatMul7",): [0.25238913, 0.38786113, 0.15471783, 0.27681994, 0.53814197, 0.18316744],
+                ("MatMul8",): [1.562704, 1.1183096, 2.3738348, 2.382382, 0.9243705, 1.8179475],
+                ("Linear1",): [[[[1.1276343, 0.7605822]]]],
+                ("Linear2",): [[[[0.32575992, 0.33121374]]]],
+                ("Linear3", "Linear4"): [[[[0.33630377, 0.3288621, 0.9898262, 0.7217065]]]],
             },
         ),
     ),
@@ -128,26 +153,27 @@ def test_smooth_quant_algo(self, model_cls, reference_values, tmpdir):
 
         self.check_scales(quantized_model, reference_values)
 
-    def test_get_abs_max_channel_collector(self):
+    def test_get_abs_max_channel_collector(self, inplace_statistics: bool):
         backend = self.get_backend()
         reduction_axes = (3, 2, 1)
         samples = 1
 
-        for inplace_type in [False, True]:
-            backend_tensor_collector = backend.get_abs_max_channel_collector(
-                num_samples=samples,
-                stats_reduction_axes=reduction_axes,
-                inplace=inplace_type,
-                branch_key="test_branch",
-            )
+        backend_tensor_collector = backend.get_abs_max_channel_collector(
+            num_samples=samples,
+            stats_reduction_axes=reduction_axes,
+            inplace=inplace_statistics,
+            branch_key="test_branch",
+        )
 
-            for aggregator in backend_tensor_collector.aggregators.values():
-                assert isinstance(aggregator, MaxAggregator)
+        assert len(backend_tensor_collector.aggregators) == 1
+        for aggregator in backend_tensor_collector.aggregators.values():
+            assert isinstance(aggregator, MaxAggregator)
 
-            for reducer in backend_tensor_collector.reducers:
-                assert isinstance(reducer, AbsMaxReducer)
-                assert reducer.inplace == inplace_type
-                assert reducer._reduction_axes == reduction_axes
+        assert len(backend_tensor_collector.reducers) == 1
+        for reducer in backend_tensor_collector.reducers:
+            assert isinstance(reducer, AbsMaxReducer)
+            assert reducer.inplace == inplace_statistics
+            assert reducer._reduction_axes == reduction_axes
 
     @pytest.mark.parametrize(
         "model_cls, references",
         (
             (
                 LinearMultiShapeModel,
                 [
-                    ("/MatMul_1", 0),
-                    ("/MatMul", 0),
-                    ("/linear_2/MatMul", 0),
-                    ("/linear_1/MatMul", 0),
-                    ("/MatMul_2", 0),
-                    ("/MatMul_4", 1),
-                    ("55", 1),
-                    ("41", 0),
-                    ("19", 1),
-                    ("24", 0),
+                    ("MatMul1", 0),
+                    ("MatMul2", 0),
+                    ("MatMul3", 0),
+                    ("MatMul4", 1),
+                    ("MatMul5", 1),
+                    ("MatMul6", 0),
+                    ("MatMul7", 0),
+                    ("MatMul8", 1),
+                    ("Linear1", 0),
+                    ("Linear2", 0),
+                    ("Linear3", 0),
+                    ("Linear4", 0),
                 ],
             ),
         ),
@@ -179,9 +207,15 @@ def test__get_nodes_to_smooth_data(self, model_cls, references, tmpdir):
         smooth_data = algo._get_nodes_to_smooth_data(nncf_graph, alpha_map.keys())
         smooth_data = {d["node_to_smooth"].node_name: d["input_act_port"] for d in smooth_data}
 
+        name_map = self.get_node_name_map()
+        assert len(name_map) == len(smooth_data)
+        matched = 0
         for ref_node_name, ref_port_id in references:
-            assert ref_node_name in smooth_data
-            assert smooth_data[ref_node_name] == ref_port_id
+            if ref_node_name not in name_map:
+                continue
+            matched += 1
+            assert smooth_data[name_map[ref_node_name]] == ref_port_id
+        assert matched == len(smooth_data)
 
     def test_empty_stats(self, mocker, tmpdir):
         model_cls = NonZeroLinearModel
@@ -206,7 +240,7 @@ def test_empty_stats(self, mocker, tmpdir):
         mm_metatype = self.get_matmul_metatype()
         matmuls = [node for node in graph.topological_sort() if node.metatype == mm_metatype]
         for transformation in arg.transformations:
-            assert transformation.target_point.target_node_name != matmuls[0].node_name
+            assert self.get_target_node_name(transformation) != matmuls[0].node_name
 
     def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value):
         backend = self.get_backend()
@@ -227,7 +261,7 @@ def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port
 
         assert activation_channel_axis == reference_value
 
-    def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value):
+    def test_get_weight_channel_axis(self, node_metatype, layer_attributes, reference_value):
         backend = self.get_backend()
 
         attributes = {
@@ -239,7 +273,7 @@ def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id,
         node = NNCFNode(attributes)
 
         try:
-            activation_channel_axis = backend.get_weight_channel_axis(node, port_id)
+            activation_channel_axis = backend.get_weight_channel_axis(node)
         except RuntimeError as e:
             if isinstance(e, reference_value):
                 pytest.xfail("Expected exception")
"LinearMultiShapeModel/NNCFLinear[linear_4]/linear_0", +} + class TestTorchSQAlgorithm(TemplateTestSQAlgorithm): @staticmethod def fn_to_type(tensor) -> torch.Tensor: return torch.tensor(tensor) + @pytest.fixture(params=[False], ids=["out_of_palce"]) + def inplace_statistics(self, request) -> bool: + return request.param + + def get_node_name_map(self) -> Dict[str, str]: + return PT_LINEAR_MODEL_MM_MAP + + @staticmethod + def get_target_node_name(command: TransformationCommand): + if isinstance(command, PTSharedFnInsertionCommand): + return command.target_commands[0].target_point.target_node_name + return command.target_point.target_node_name + @staticmethod def get_transform_fn() -> Callable: def transform_fn(data_item): @@ -46,45 +78,40 @@ def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> ov.Model: @staticmethod def check_scales(model: torch.nn.Module, reference_values: Dict[str, np.ndarray]) -> None: - ops_list = {op.get_friendly_name(): op for op in model.get_ops()} - for ref_name, ref_value in reference_values.items(): - node = ops_list[ref_name] - const_node = node.input(1).get_source_output().get_node() + modules = model.nncf.get_compression_modules_by_type(ExtraCompressionModuleType.EXTERNAL_OP) + for ref_names, ref_value in reference_values.items(): + if not all(name.startswith("Linear") for name in ref_names): + # Pytorch SQ algorithm supports only linear modules by far, + # so other multiplies are skipped + continue + sq_node = modules[PT_LINEAR_MODEL_SQ_MAP[ref_names]] - assert const_node.get_type_name() == "Constant" + assert isinstance(sq_node, SQMultiply) - value = const_node.data - ref_value = np.array(ref_value) + value = sq_node._scale_value + ref_value = torch.tensor(ref_value) assert value.shape == ref_value.shape - assert np.all(np.isclose(value, ref_value, atol=0.0001)), f"{value} != {ref_value}" + assert torch.all(torch.isclose(value, ref_value, rtol=1e-4)) @pytest.mark.parametrize( "node_metatype, layer_attributes, port_id, reference_value", ( - (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": False}), 0, -1), - (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": True}), 0, -2), - (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": False}), 1, -2), - (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": True}), 1, -1), - (PTModuleLinearMetatype, OVLayerAttributes({}, inputs_attributes={"transpose": False}), 2, RuntimeError), - (PTModuleConv2dMetatype, OVLayerAttributes({}, inputs_attributes={}), 0, 1), + (PTModuleLinearMetatype, None, 0, -1), + (PTModuleConv2dMetatype, None, 0, 1), ), ) def test_get_activation_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value): return super().test_get_activation_channel_axis(node_metatype, layer_attributes, port_id, reference_value) @pytest.mark.parametrize( - "node_metatype, layer_attributes, port_id, reference_value", + "node_metatype, layer_attributes, reference_value", ( - (PTModuleLinearMetatype, OVLayerAttributes({1: {"transpose": False}}), 1, -2), - (PTModuleLinearMetatype, OVLayerAttributes({1: {"transpose": True}}), 1, -1), - (PTModuleLinearMetatype, OVLayerAttributes({0: {"transpose": False}}), 0, -1), - (PTModuleLinearMetatype, OVLayerAttributes({0: {"transpose": True}}), 0, -2), - (PTModuleLinearMetatype, OVLayerAttributes({1: {"transpose": False}}), 2, RuntimeError), - (PTModuleConv2dMetatype, OVLayerAttributes({1: {}}), 1, 1), + (PTModuleLinearMetatype, None, 1), + 
(PTModuleConv2dMetatype, None, 1), ), ) - def test_get_weight_channel_axis(self, node_metatype, layer_attributes, port_id, reference_value): - return super().test_get_weight_channel_axis(node_metatype, layer_attributes, port_id, reference_value) + def test_get_weight_channel_axis(self, node_metatype, layer_attributes, reference_value): + return super().test_get_weight_channel_axis(node_metatype, layer_attributes, reference_value) @staticmethod def get_matmul_metatype():
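Across both backends the reference scales are now keyed by backend-agnostic labels, with tuple keys marking unified groups that must share a single smoothing scale. A minimal sketch of that convention (values taken from the OV maps above; the loop body is illustrative, not the test code itself):

    # Every label inside one tuple key resolves to the same smooth-quant multiply.
    sq_op_map = {"Linear3": "/Add_0_0/nncf_smooth_quant", "Linear4": "/Add_0_0/nncf_smooth_quant"}
    reference_values = {("Linear3", "Linear4"): [0.33630377, 0.3288621, 0.9898262, 0.7217065]}

    for ref_names, ref_value in reference_values.items():
        resolved = {sq_op_map[name] for name in ref_names}
        assert len(resolved) == 1  # a unified group shares one inserted scale node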