diff --git a/nncf/onnx/graph/nncf_graph_builder.py b/nncf/onnx/graph/nncf_graph_builder.py index 6c0c7291572..594e2444a6f 100644 --- a/nncf/onnx/graph/nncf_graph_builder.py +++ b/nncf/onnx/graph/nncf_graph_builder.py @@ -119,7 +119,14 @@ def get_bias_tensor_port_id(metatype: ONNXOpWithWeightsMetatype) -> Optional[int return None -def _get_common_layer_attributes(node, metatype: ONNXOpMetatype): +def _get_common_layer_attributes(node, metatype: ONNXOpMetatype) -> Optional[BaseLayerAttributes]: + """ + Returns layer-specific layer attributes for the given node. + + :param node: Target Node to get layer attributes for. + :param metatype: Target node metatype. + :return: Target node layer attributes or None. + """ if metatype == ONNXConcatMetatype: axis = [attr.i for attr in node.attribute if attr.name == "axis"][0] num_inputs = len(node.input) diff --git a/nncf/openvino/graph/model_utils.py b/nncf/openvino/graph/model_utils.py index 3a224ba6043..22de60ed0e3 100644 --- a/nncf/openvino/graph/model_utils.py +++ b/nncf/openvino/graph/model_utils.py @@ -53,4 +53,10 @@ def remove_fq_from_inputs(model: ov.Model, graph: NNCFGraph) -> ov.Model: def get_input_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + """ + Get all nodes from given nncf_graph that are identified as a input nodes. + + :param nncf_graph: NNCFGraph to work with. + :return: Target NNCFGraph input nodes. + """ return list(set(nncf_graph.get_input_nodes()).upadte(nncf_graph.get_nodes_by_metatypes([OVReadValueMetatype]))) diff --git a/nncf/openvino/graph/nncf_graph_builder.py b/nncf/openvino/graph/nncf_graph_builder.py index 6bf45598cb0..ea8cb9b0e37 100644 --- a/nncf/openvino/graph/nncf_graph_builder.py +++ b/nncf/openvino/graph/nncf_graph_builder.py @@ -189,14 +189,31 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph: GraphConverter._add_edges_to_nncf_graph(model, nncf_graph) return nncf_graph - def _set_non_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph): + def _set_non_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph) -> None: + """ + Sets layer attributes for a non weighted node. + + :param node: Target node. + :param metatype: Target node metatype. + :param nncf_graph: NNCFGraph to work with. + """ if metatype == OVConcatMetatype: nncf_node = nncf_graph.get_node_by_name(node.get_friendly_name()) nncf_node.layer_attributes = OVLayerAttributes( {}, MultipleInputLayerAttributes(axis=node.get_axis(), num_inputs=len(node.inputs())) ) - def _set_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph, visited: Set[str]): + def _set_weighted_layer_attributes( + node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph, visited: Set[str] + ) -> None: + """ + Sets layer attributes for a weighted node. + + :param node: Target node. + :param metatype: Target node metatype. + :param nncf_graph: NNCFGraph to work with. + :param visited: Set with node names that were already processed by the GraphConverter. + """ const_attrs, act_attrs = {}, {} for inp in GraphConverter._filter_weight_input_ports(node.inputs(), metatype): inp_name = inp.get_source_output().get_node().get_friendly_name() diff --git a/nncf/quantization/algorithms/accuracy_control/backend.py b/nncf/quantization/algorithms/accuracy_control/backend.py index b03bc1d4f70..649d4620dc5 100644 --- a/nncf/quantization/algorithms/accuracy_control/backend.py +++ b/nncf/quantization/algorithms/accuracy_control/backend.py @@ -54,7 +54,12 @@ def get_quantizable_metatypes() -> List[OperatorMetatype]: @staticmethod @abstractmethod def get_graph_inputs(nncf_graph: NNCFGraph) -> List[NNCFNode]: - pass + """ + Returns a list of NNCFNodes that are identified as an inputs. + + :param nncf_graph: The NNCF graph. + :return: List of NNCFNodes that are identified as an inputs. + """ @staticmethod @abstractmethod diff --git a/nncf/quantization/algorithms/min_max/backend.py b/nncf/quantization/algorithms/min_max/backend.py index 8d940cd561e..3e2e92d9e32 100644 --- a/nncf/quantization/algorithms/min_max/backend.py +++ b/nncf/quantization/algorithms/min_max/backend.py @@ -67,10 +67,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]: Property for the backend-specific Dropout metatypes. """ - @abstractmethod - def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]: - pass - @property @abstractmethod def overflow_fix_metatypes(self) -> List[OperatorMetatype]: @@ -143,6 +139,16 @@ def create_quantizer_insertion_command( :return: Backend-specific TransformationCommand for the quantizer insertion operation. """ + @staticmethod + @abstractmethod + def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]: + """ + Returns a list of NNCFNodes that are identified as an inputs. + + :param nncf_graph: NNCFGraph to get input nodes from. + :return: List of NNCFNodes that are identified as an inputs. + """ + @staticmethod @abstractmethod def unify_statistics(statistics: List[MinMaxTensorStatistic]) -> MinMaxTensorStatistic: diff --git a/nncf/quantization/algorithms/min_max/onnx_backend.py b/nncf/quantization/algorithms/min_max/onnx_backend.py index 47c4845fba2..9e0f21758dc 100644 --- a/nncf/quantization/algorithms/min_max/onnx_backend.py +++ b/nncf/quantization/algorithms/min_max/onnx_backend.py @@ -64,9 +64,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]: def add_metatypes(self) -> List[OperatorMetatype]: return [om.ONNXAddLayerMetatype] - def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]: - return nncf_graph.get_input_nodes() - @property def group_conv_metatypes(self) -> List[OperatorMetatype]: return self.conv_metatypes @@ -95,6 +92,10 @@ def hw_config(self) -> HWConfig: def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: return DEFAULT_ONNX_QUANT_TRAIT_TO_OP_DICT + @staticmethod + def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]: + return nncf_graph.get_input_nodes() + @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint: return ONNXTargetPoint(target_type, target_node_name, port_id) diff --git a/nncf/quantization/algorithms/min_max/openvino_backend.py b/nncf/quantization/algorithms/min_max/openvino_backend.py index abb771af685..5cfa0ee4081 100644 --- a/nncf/quantization/algorithms/min_max/openvino_backend.py +++ b/nncf/quantization/algorithms/min_max/openvino_backend.py @@ -87,9 +87,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]: def read_variable_metatypes(self) -> List[OperatorMetatype]: return [om.OVReadValueMetatype] - def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]: - return get_input_nodes(nncf_graph) - @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: return {om.OVConcatMetatype: self.overflow_fix_metatypes} @@ -102,6 +99,10 @@ def hw_config(self) -> HWConfig: def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: return DEFAULT_OV_QUANT_TRAIT_TO_OP_DICT + @staticmethod + def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]: + return get_input_nodes(nncf_graph) + @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> OVTargetPoint: return OVTargetPoint(target_type, target_node_name, port_id) diff --git a/nncf/quantization/algorithms/min_max/torch_backend.py b/nncf/quantization/algorithms/min_max/torch_backend.py index c7bb2936f61..44d74bee238 100644 --- a/nncf/quantization/algorithms/min_max/torch_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_backend.py @@ -74,9 +74,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]: def read_variable_metatypes(self) -> List[OperatorMetatype]: return [] - def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]: - return get_inputs_for_graph_with_several_connected_components(nncf_graph) - @property def conv_metatypes(self) -> List[OperatorMetatype]: return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype] @@ -113,6 +110,10 @@ def hw_config(self) -> HWConfig: def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT + @staticmethod + def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]: + return get_inputs_for_graph_with_several_connected_components(nncf_graph) + @staticmethod def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: diff --git a/nncf/quantization/passes.py b/nncf/quantization/passes.py index 4b398b7da52..9465fa95dec 100644 --- a/nncf/quantization/passes.py +++ b/nncf/quantization/passes.py @@ -29,11 +29,9 @@ def transform_to_inference_graph( This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows. :param nncf_graph: NNCFGraph instance for the transformation. + :param input_nodes: List of input nodes for the given NNCFGraph. :param shapeof_metatypes: List of backend-specific ShapeOf metatypes. :param dropout_metatypes: List of backend-specific Dropout metatypes. - :param read_variable_metatypes: List of backend-specific metatypes - that also can be interpreted as inputs (ReadValue). - :param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not. :return: NNCFGraph in the inference style. """ remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes) @@ -53,8 +51,7 @@ def remove_shapeof_subgraphs( :param nncf_graph: NNCFGraph instance for the transformation. :param shapeof_metatypes: List of backend-specific ShapeOf metatypes. - :param read_variable_metatypes: List of backend-specific metatypes - that also can be interpreted as inputs (ReadValue). + :param input_nodes: List of input nodes for the given NNCFGraph. :return: NNCFGraph without ShapeOf subgraphs. """ nodes_to_drop = set() @@ -149,8 +146,7 @@ def filter_constant_nodes( The traversing starts from the input nodes and nodes with weights. :param nncf_graph: NNCFGraph instance for the transformation. - :param read_variable_metatypes: List of backend-specific metatypes - that also can be interpreted as inputs (ReadValue). + :param input_nodes: List of input nodes for the given NNCFGraph. :return: NNCFGraph without Constant nodes. """ if not input_nodes: diff --git a/nncf/torch/graph/graph.py b/nncf/torch/graph/graph.py index 352cdd92436..b6f5240a4e9 100644 --- a/nncf/torch/graph/graph.py +++ b/nncf/torch/graph/graph.py @@ -71,7 +71,15 @@ def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope: return matches[0] -def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGraph): +def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGraph) -> List[NNCFNode]: + """ + Returns a list of NNCFNodes that are identified as an inputs. Requires MultipleInputLayerAttributes + for nodes with several inputs and right `input_edges_num_expected` parameter setted for + nncf nodes metatypes. + + :param nncf_graph: NNCFGraph to get input nodes from. + :return: List of NNCFNodes that are identified as an inputs. + """ input_nodes = set() for node in nncf_graph.get_all_nodes(): input_edges_num_expected = None @@ -84,6 +92,8 @@ def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGra if input_edges_num_expected: input_edges = nncf_graph.get_input_edges(node) if len(input_edges) < input_edges_num_expected: + # If node has missed input edges we assume this node is an input node + # that was disconected from an activation input. input_nodes.add(node) input_nodes.update(nncf_graph.get_input_nodes()) return list(input_nodes) diff --git a/tests/common/quantization/test_quantizer_removal.py b/tests/common/quantization/test_quantizer_removal.py index 91331b81cf0..33a45eeae3f 100644 --- a/tests/common/quantization/test_quantizer_removal.py +++ b/tests/common/quantization/test_quantizer_removal.py @@ -24,6 +24,7 @@ from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES from tests.common.quantization.metatypes import QUANTIZE_AGNOSTIC_METATYPES from tests.common.quantization.metatypes import QUANTIZER_METATYPES +from tests.common.quantization.metatypes import ShapeOfTestMetatype @dataclass @@ -228,7 +229,7 @@ def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: TestCase) # As test graphs are fully connected and does not have readvariable metatyep, # this should work input_nodes = nncf_graph.get_input_nodes() - nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), input_nodes) + nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), ShapeOfTestMetatype, input_nodes) nodes, ops = find_quantizer_nodes_to_cut( nncf_graph_without_shapeof, quantizer_node,