Skip to content

Commit

Permalink
Docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 16, 2023
1 parent ed67ed8 commit b1fc25a
Show file tree
Hide file tree
Showing 10 changed files with 73 additions and 23 deletions.
9 changes: 8 additions & 1 deletion nncf/onnx/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,14 @@ def get_bias_tensor_port_id(metatype: ONNXOpWithWeightsMetatype) -> Optional[int
return None


def _get_common_layer_attributes(node, metatype: ONNXOpMetatype):
def _get_common_layer_attributes(node, metatype: ONNXOpMetatype) -> Optional[BaseLayerAttributes]:
"""
Returns layer-specific layer attributes for the given node.
:param node: Target Node to get layer attributes for.
:param metatype: Target node metatype.
:return: Target node layer attributes or None.
"""
if metatype == ONNXConcatMetatype:
axis = [attr.i for attr in node.attribute if attr.name == "axis"][0]
num_inputs = len(node.input)
Expand Down
6 changes: 6 additions & 0 deletions nncf/openvino/graph/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,10 @@ def remove_fq_from_inputs(model: ov.Model, graph: NNCFGraph) -> ov.Model:


def get_input_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]:
"""
Get all nodes from given nncf_graph that are identified as a input nodes.
:param nncf_graph: NNCFGraph to work with.
:return: Target NNCFGraph input nodes.
"""
return list(set(nncf_graph.get_input_nodes()).upadte(nncf_graph.get_nodes_by_metatypes([OVReadValueMetatype])))
21 changes: 19 additions & 2 deletions nncf/openvino/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,31 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph:
GraphConverter._add_edges_to_nncf_graph(model, nncf_graph)
return nncf_graph

def _set_non_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph):
def _set_non_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph) -> None:
"""
Sets layer attributes for a non weighted node.
:param node: Target node.
:param metatype: Target node metatype.
:param nncf_graph: NNCFGraph to work with.
"""
if metatype == OVConcatMetatype:
nncf_node = nncf_graph.get_node_by_name(node.get_friendly_name())
nncf_node.layer_attributes = OVLayerAttributes(
{}, MultipleInputLayerAttributes(axis=node.get_axis(), num_inputs=len(node.inputs()))
)

def _set_weighted_layer_attributes(node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph, visited: Set[str]):
def _set_weighted_layer_attributes(
node: ov.Node, metatype: OVOpMetatype, nncf_graph: NNCFGraph, visited: Set[str]
) -> None:
"""
Sets layer attributes for a weighted node.
:param node: Target node.
:param metatype: Target node metatype.
:param nncf_graph: NNCFGraph to work with.
:param visited: Set with node names that were already processed by the GraphConverter.
"""
const_attrs, act_attrs = {}, {}
for inp in GraphConverter._filter_weight_input_ports(node.inputs(), metatype):
inp_name = inp.get_source_output().get_node().get_friendly_name()
Expand Down
7 changes: 6 additions & 1 deletion nncf/quantization/algorithms/accuracy_control/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,12 @@ def get_quantizable_metatypes() -> List[OperatorMetatype]:
@staticmethod
@abstractmethod
def get_graph_inputs(nncf_graph: NNCFGraph) -> List[NNCFNode]:
pass
"""
Returns a list of NNCFNodes that are identified as an inputs.
:param nncf_graph: The NNCF graph.
:return: List of NNCFNodes that are identified as an inputs.
"""

@staticmethod
@abstractmethod
Expand Down
14 changes: 10 additions & 4 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
Property for the backend-specific Dropout metatypes.
"""

@abstractmethod
def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
pass

@property
@abstractmethod
def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
Expand Down Expand Up @@ -143,6 +139,16 @@ def create_quantizer_insertion_command(
:return: Backend-specific TransformationCommand for the quantizer insertion operation.
"""

@staticmethod
@abstractmethod
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
"""
Returns a list of NNCFNodes that are identified as an inputs.
:param nncf_graph: NNCFGraph to get input nodes from.
:return: List of NNCFNodes that are identified as an inputs.
"""

@staticmethod
@abstractmethod
def unify_statistics(statistics: List[MinMaxTensorStatistic]) -> MinMaxTensorStatistic:
Expand Down
7 changes: 4 additions & 3 deletions nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,6 @@ def overflow_fix_metatypes(self) -> List[OperatorMetatype]:
def add_metatypes(self) -> List[OperatorMetatype]:
return [om.ONNXAddLayerMetatype]

def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
return nncf_graph.get_input_nodes()

@property
def group_conv_metatypes(self) -> List[OperatorMetatype]:
return self.conv_metatypes
Expand Down Expand Up @@ -95,6 +92,10 @@ def hw_config(self) -> HWConfig:
def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
return DEFAULT_ONNX_QUANT_TRAIT_TO_OP_DICT

@staticmethod
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
return nncf_graph.get_input_nodes()

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> ONNXTargetPoint:
return ONNXTargetPoint(target_type, target_node_name, port_id)
Expand Down
7 changes: 4 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,6 @@ def dropout_metatypes(self) -> List[OperatorMetatype]:
def read_variable_metatypes(self) -> List[OperatorMetatype]:
return []

def get_input_nodes(self, nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
return get_inputs_for_graph_with_several_connected_components(nncf_graph)

@property
def conv_metatypes(self) -> List[OperatorMetatype]:
return [om.PTModuleConv1dMetatype, om.PTModuleConv2dMetatype, om.PTModuleConv3dMetatype]
Expand Down Expand Up @@ -113,6 +110,10 @@ def hw_config(self) -> HWConfig:
def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT

@staticmethod
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
return get_inputs_for_graph_with_several_connected_components(nncf_graph)

@staticmethod
def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint:
if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION:
Expand Down
10 changes: 3 additions & 7 deletions nncf/quantization/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@ def transform_to_inference_graph(
This method contains inplace pipeline of the passes that uses to provide inference graph without constant flows.
:param nncf_graph: NNCFGraph instance for the transformation.
:param input_nodes: List of input nodes for the given NNCFGraph.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
:param dropout_metatypes: List of backend-specific Dropout metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param nncf_graph_contains_constants: Whether NNCFGraph contains constant nodes or not.
:return: NNCFGraph in the inference style.
"""
remove_shapeof_subgraphs(nncf_graph, shapeof_metatypes, input_nodes)
Expand All @@ -53,8 +51,7 @@ def remove_shapeof_subgraphs(
:param nncf_graph: NNCFGraph instance for the transformation.
:param shapeof_metatypes: List of backend-specific ShapeOf metatypes.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param input_nodes: List of input nodes for the given NNCFGraph.
:return: NNCFGraph without ShapeOf subgraphs.
"""
nodes_to_drop = set()
Expand Down Expand Up @@ -149,8 +146,7 @@ def filter_constant_nodes(
The traversing starts from the input nodes and nodes with weights.
:param nncf_graph: NNCFGraph instance for the transformation.
:param read_variable_metatypes: List of backend-specific metatypes
that also can be interpreted as inputs (ReadValue).
:param input_nodes: List of input nodes for the given NNCFGraph.
:return: NNCFGraph without Constant nodes.
"""
if not input_nodes:
Expand Down
12 changes: 11 additions & 1 deletion nncf/torch/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,15 @@ def get_scope_by_node_name(self, node_name: NNCFNodeName) -> Scope:
return matches[0]


def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGraph):
def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGraph) -> List[NNCFNode]:
"""
Returns a list of NNCFNodes that are identified as an inputs. Requires MultipleInputLayerAttributes
for nodes with several inputs and right `input_edges_num_expected` parameter setted for
nncf nodes metatypes.
:param nncf_graph: NNCFGraph to get input nodes from.
:return: List of NNCFNodes that are identified as an inputs.
"""
input_nodes = set()
for node in nncf_graph.get_all_nodes():
input_edges_num_expected = None
Expand All @@ -84,6 +92,8 @@ def get_inputs_for_graph_with_several_connected_components(nncf_graph: PTNNCFGra
if input_edges_num_expected:
input_edges = nncf_graph.get_input_edges(node)
if len(input_edges) < input_edges_num_expected:
# If node has missed input edges we assume this node is an input node
# that was disconected from an activation input.
input_nodes.add(node)
input_nodes.update(nncf_graph.get_input_nodes())
return list(input_nodes)
3 changes: 2 additions & 1 deletion tests/common/quantization/test_quantizer_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tests.common.quantization.metatypes import QUANTIZABLE_METATYPES
from tests.common.quantization.metatypes import QUANTIZE_AGNOSTIC_METATYPES
from tests.common.quantization.metatypes import QUANTIZER_METATYPES
from tests.common.quantization.metatypes import ShapeOfTestMetatype


@dataclass
Expand Down Expand Up @@ -228,7 +229,7 @@ def test_find_quantizer_nodes_to_cut(nncf_graph: NNCFGraph, test_case: TestCase)
# As test graphs are fully connected and does not have readvariable metatyep,
# this should work
input_nodes = nncf_graph.get_input_nodes()
nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), input_nodes)
nncf_graph_without_shapeof = remove_shapeof_subgraphs(deepcopy(nncf_graph), ShapeOfTestMetatype, input_nodes)
nodes, ops = find_quantizer_nodes_to_cut(
nncf_graph_without_shapeof,
quantizer_node,
Expand Down

0 comments on commit b1fc25a

Please sign in to comment.