diff --git a/nncf/openvino/graph/model_builder.py b/nncf/openvino/graph/model_builder.py new file mode 100644 index 00000000000..cf314650718 --- /dev/null +++ b/nncf/openvino/graph/model_builder.py @@ -0,0 +1,222 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections import deque +from typing import Dict, List, Tuple + +import openvino.runtime as ov +from openvino.runtime import opset13 as opset +from openvino.runtime.utils.node_factory import NodeFactory + +from nncf.openvino.graph.model_transformer import OVModelTransformer +from nncf.openvino.graph.node_utils import get_parameter_node_name +from nncf.openvino.graph.node_utils import get_result_node_name + + +class ModelBuilder: + """ + The purpose of the ModelBuilder is to build a new OpenVINO model from input and output points. + This Builder was created to reduce the number of model cloning that is required for ModelTransformer to work. + """ + + def __init__(self): + self._node_factory = NodeFactory() + + @staticmethod + def _create_parameter(node_name: str, node_input: ov.Input) -> ov.Node: + """ + A method that contains steps to create a Parameter for a new model using a specific template. + """ + port_id = node_input.get_index() + parameter_name = get_parameter_node_name(node_name, port_id) + return opset.parameter( + shape=node_input.get_partial_shape(), + dtype=node_input.get_element_type(), + name=parameter_name, + ) + + @staticmethod + def _create_result(node_name: str, node_output: ov.Input) -> ov.Node: + """ + A method that contains steps to create a Result for a new model using a specific template. + """ + port_id = node_output.get_index() + result_name = get_result_node_name(node_name, port_id=port_id) + result = opset.result(node_output, name=result_name) + result.get_output_tensor(0).set_names({result_name}) + return result + + def _collect_graph_nodes( + self, + input_ids: List[Tuple[str, int]], + output_ids: List[Tuple[str, int]], + node_mapping: Dict[str, ov.Node], + ) -> List[ov.Node]: + """ + A method for aggregating layers to be further cloned. + Aggregation is designed in such a way that layers are listed from right to left, + as they pass from bottom to top. This is done in order to find all constants in the model and + to start graph creation from them (as well as Parameter layers), because + OpenVINO graph is created from top-down and cannot be created otherwise. + + Legend: w - weigths, c - convert, il/lh - input low/high, ol/oh - output low/high + (w) + | + (c) (il) (ih) (ol) (oh) + \ | | / / + (fake quantize) (parameter) + \ / + (convolution) + | + (result) + Based on the above graph, the return value would look like this: + [convolution, parameter, fake quantize, oh, ol, ih, il, c, w] + + :param input_ids: List of the ids specified in algorithm. + :param output_ids: List of the ids specified in algorithm. + :param node_mapping: Original nodes mapping. + :return: List of the ov.Nodes to clone. + """ + # Creating a list as a deque for FIFO layer acquisition and retrieval + lookup_nodes = deque(node_mapping[n] for n, _ in output_ids) + graph_nodes = [] + + while lookup_nodes: + lookup_node = lookup_nodes.popleft() + lookup_name = lookup_node.get_friendly_name() + node_inputs = lookup_node.inputs() + graph_nodes.append(lookup_node) + # Reversing to lookup nodes from right to left + for node_input in reversed(node_inputs): + port_id = node_input.get_index() + if (lookup_name, port_id) in input_ids: + # We create Parameters here to avoid double creation in the future since it is not an original node, + # but we need to have it as input for next node. + parameter = self._create_parameter(lookup_name, node_input) + lookup_nodes.append(parameter) + continue + parent_node = node_input.get_source_output().get_node() + lookup_nodes.append(parent_node) + + return graph_nodes + + def build( + self, + input_ids: List[Tuple[str, int]], + output_ids: List[Tuple[str, int]], + node_mapping: Dict[str, ov.Node], + ) -> ov.Model: + """ + The basic method of the algorithm. This method uses an aggregated list of layers to be recreated. + Let us take a graph of this kind as an example: + + Legend: w - weigths, c - convert, il/lh - input low/high, ol/oh - output low/high + (w) + | + (c) (il) (ih) (ol) (oh) + \ | | / / + (fake quantize) (parameter) + \ / + (convolution) + | + (result) + + The externally collected list of layers will look like this: + [convolution, parameter, fake quantize, oh, ol, ih, il, c, w] + + Next, this list will be circled from right to left. At the same time, the list of already created layers + will be filled from left to right, which will be used in the traversal step also, from left to right, + in order to keep the order of the original layer inputs. + For example: + + graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il, c, w] + clone_nodes = [] + + *creating w - weight node.* + graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il, c] + clone_nodes = [w] + + *creating c - convert node. + Based on the .inputs() output, we'll use the already created w-weight node to fill in the convert input. + As the result, weight node would be removed from the clone_nodes list and convert node would be placed here.* + graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il] + clone_nodes = [c] + + *creating il/lh - input low/high, ol/oh - output low/high nodes. + Since these nodes are constants and do not require any nodes as inputs, cloned nodes will not be used.* + graph_nodes = [convolution, parameter, fake quantize, oh, ol, ih, il] + clone_nodes = [c, il, ih, ol, oh] + + *creating fake quantize node. + This node requires to have input values in a specific order. + All previous nodes will be connected/used for fake quantize, from left to right.* + graph_nodes = [convolution, parameter] + clone_nodes = [f] + + *creating parameter node. + In this step, the list of parameters will also be filled out with the new node.* + graph_nodes = [convolution] + clone_nodes = [f, parameter] + + *creating convolution node. + This node also requires to have inputs in a specific order. + All previous nodes will be connected/used for convolution, from left to right. Also, + the outputs verification step will show here that one of the convolution outputs is in the output_ids list. + This means that the Result node would be created and placed into the results list.* + graph_nodes = [] + clone_nodes = [convolution] + + The last step is to create a subgraph model based on the parameters & results lists. + + :param input_ids: List of the ids specified in algorithm. + :param output_ids: List of the ids specified in algorithm. + :param node_mapping: Original nodes mapping. + :return: Builded ov.Model based on parameters. + """ + + parameters, results = [], [] + clone_nodes = deque() + + # Collecting nodes that declares the graph. + graph_nodes = self._collect_graph_nodes(input_ids, output_ids, node_mapping) + + while graph_nodes: + graph_node = graph_nodes.pop() + node_type = graph_node.get_type_name() + node_name = graph_node.get_friendly_name() + + # To create the new OpenVINO nodes, we need to provide all possible layer attributes. + attrs = graph_node.get_attributes() + attrs["name"] = node_name + + if node_type == "Constant": + # Constants creation is apart due to specific behavior. + clone_node = OVModelTransformer._create_constant( + graph_node.get_data(), dtype=graph_node.get_element_type(), name=attrs["name"] + ) + elif node_type == "Parameter": + # We've created Parameter nodes on the previous step. + clone_node = graph_node + parameters.append(clone_node) + else: + # We have to have args as the inputs since all of them are nodes and are required to be as input. + args = [clone_nodes.popleft() for _ in graph_node.inputs()] + + clone_node = self._node_factory.create(node_type, args, attrs) + + for node_output in clone_node.outputs(): + port_id = node_output.get_index() + if (node_name, port_id) in output_ids: + result = self._create_result(node_name, node_output) + results.append(result) + + clone_nodes.append(clone_node) + + return ov.Model(results, parameters) diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py index 7496187adb1..e73fdb14026 100644 --- a/nncf/openvino/graph/node_utils.py +++ b/nncf/openvino/graph/node_utils.py @@ -121,18 +121,22 @@ def get_const_value(const_node: ov.Node) -> np.ndarray: return const_node.data -def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray: +def get_bias_value( + node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model, node_mapping: Dict[str, ov.Node] = None +) -> np.ndarray: """ Returns the bias tensor for the biased node. :param node_with_bias: The node that corresponds to the operation with bias. :param nncf_graph: NNCFGraph instance. :param model: The model that contains this operation. + :param node_mapping: Original nodes mapping cache. :return: The bias value that is applied to the output tensor of the node's operation. """ - ops_dict = {op.get_friendly_name(): op for op in model.get_ops()} + if node_mapping is None: + node_mapping = {op.get_friendly_name(): op for op in model.get_ops()} bias_constant = get_node_with_bias_value(get_add_bias_node(node_with_bias, nncf_graph), nncf_graph) - ov_bias_constant = ops_dict[bias_constant.node_name] + ov_bias_constant = node_mapping[bias_constant.node_name] return get_const_value(ov_bias_constant) diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 3d104cad3c9..35f057f8a66 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -17,7 +17,6 @@ from nncf.common.factory import EngineFactory from nncf.common.factory import ModelTransformerFactory from nncf.common.graph.graph import NNCFGraph -from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout @@ -111,7 +110,7 @@ def _set_backend_entity(self, model: TModel) -> None: OVFastBiasCorrectionAlgoBackend, ) - self._backend_entity = OVFastBiasCorrectionAlgoBackend() + self._backend_entity = OVFastBiasCorrectionAlgoBackend(model) elif model_backend == BackendType.TORCH: from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend @@ -167,7 +166,7 @@ def apply( # Outputs of the subgraphs for the FastBiasCorrection are the same across the backends. output_id = (out_node_name, 0) - extracted_model = self._extract_submodel(model_transformer, input_id, output_id) + extracted_model = self._backend_entity.extract_submodel(model_transformer, input_id, output_id) if extracted_model is None: nncf_logger.debug(f"Skipping node {node_name} because cant extract submodel") continue @@ -287,23 +286,6 @@ def output_filter_func(point): output_fp.extend(tensor_collector.get_statistics().mean_values) return output_fp - def _extract_submodel( - self, model_transformer: ModelTransformer, input_id: Tuple[str, int], output_id: Tuple[str, int] - ) -> TModel: - """ - Extracts sub-model using backend-specific ModelTransformer. - - :param model_transformer: Backend-specific ModelTransformer. - :param input_id: Input ID. - :param output_id: Output ID. - :return: Backend-specific sub-model. - """ - model_extraction_command = self._backend_entity.model_extraction_command([input_id], [output_id]) - me_transformation_layout = TransformationLayout() - me_transformation_layout.register(model_extraction_command) - extracted_model = model_transformer.transform(me_transformation_layout) - return extracted_model - def _add_statistic_point(self, container: StatisticPointsContainer, point: TargetPoint, axis: int) -> None: """ Adds specific statistic point. diff --git a/nncf/quantization/algorithms/fast_bias_correction/backend.py b/nncf/quantization/algorithms/fast_bias_correction/backend.py index 110e05161cd..7c76b3857fe 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/backend.py @@ -15,9 +15,11 @@ from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode +from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.tensor_statistics.collectors import TensorStatisticCollectorBase from nncf.tensor import Tensor @@ -194,3 +196,20 @@ def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple :param input_shape: Shape of the input. :return: Channel axis number. """ + + def extract_submodel( + self, model_transformer: ModelTransformer, input_id: Tuple[str, int], output_id: Tuple[str, int] + ) -> TModel: + """ + Extracts sub-model using backend-specific ModelTransformer. + + :param model_transformer: Backend-specific ModelTransformer. + :param input_id: Input ID. + :param output_id: Output ID. + :return: Backend-specific sub-model. + """ + model_extraction_command = self.model_extraction_command([input_id], [output_id]) + me_transformation_layout = TransformationLayout() + me_transformation_layout.register(model_extraction_command) + extracted_model = model_transformer.transform(me_transformation_layout) + return extracted_model diff --git a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py index 1f92559eeb8..b79e165228e 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/openvino_backend.py @@ -20,6 +20,7 @@ from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.openvino.graph.metatypes.groups import FAKE_QUANTIZE_OPERATIONS from nncf.openvino.graph.metatypes.groups import OPERATIONS_WITH_BIAS_REDUCED +from nncf.openvino.graph.model_builder import ModelBuilder from nncf.openvino.graph.node_utils import get_activation_channel_axis from nncf.openvino.graph.node_utils import get_bias_value from nncf.openvino.graph.node_utils import is_node_with_bias @@ -33,6 +34,12 @@ class OVFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): + + def __init__(self, model): + # Node mapping caching to reduce time for calculations + self._node_mapping = {op.get_friendly_name(): op for op in model.get_ops()} + self._model_builder = ModelBuilder() + @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: return OVTargetPoint(target_type, target_node_name, port_id) @@ -73,9 +80,8 @@ def create_input_data( input_data = {input_name: blob} return input_data - @staticmethod - def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor: - return Tensor(get_bias_value(node, nncf_graph, model)) + def get_bias_value(self, node: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> Tensor: + return Tensor(get_bias_value(node, nncf_graph, model, node_mapping=self._node_mapping)) @staticmethod def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: @@ -113,3 +119,11 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG @staticmethod def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: return get_activation_channel_axis(node, port_id, input_shape) + + def extract_submodel(self, model_transformer, input_id, output_id): + + return self._model_builder.build( + input_ids=[input_id], + output_ids=[output_id], + node_mapping=self._node_mapping, + )