Skip to content

Commit

Permalink
get_inputs -> get_start_nodes_for_activation_path_tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Nov 22, 2023
1 parent 52a2e1a commit 8dc09e8
Show file tree
Hide file tree
Showing 12 changed files with 28 additions and 23 deletions.
8 changes: 4 additions & 4 deletions nncf/quantization/algorithms/accuracy_control/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def get_quantizable_metatypes() -> List[OperatorMetatype]:

@staticmethod
@abstractmethod
def get_graph_inputs(nncf_graph: NNCFGraph) -> List[NNCFNode]:
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[NNCFNode]:
"""
Returns a list of NNCFNodes that are identified as an inputs.
Returns a list of NNCFNodes to use as start nodes for activation path tracing.
:param nncf_graph: The NNCF graph.
:return: List of NNCFNodes that are identified as an inputs.
:param nncf_graph: NNCFGraph to get the start nodes.
:return: List of NNCFNodes to use as start nodes for activation path tracing.
"""

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_shapeof_metatypes() -> List[OVOpMetatype]:
return SHAPEOF_OPERATIONS

@staticmethod
def get_graph_inputs(nncf_graph: NNCFGraph) -> List[NNCFNode]:
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[NNCFNode]:
return get_input_nodes(nncf_graph)

# Manipulations with bias value and weights
Expand Down
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/accuracy_control/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) ->
]

quantized_model_graph_without_shapeof = remove_shapeof_subgraphs(
deepcopy(quantized_model_graph), self._algo_backend.get_graph_inputs(quantized_model_graph)
deepcopy(quantized_model_graph),
self._algo_backend.get_start_nodes_for_activation_path_tracing(quantized_model_graph),
)

for quantizer_node in reversed(quantizers):
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def _get_quantization_target_points(

inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
self._backend_entity.get_input_nodes(nncf_graph),
self._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
self._backend_entity.shapeof_metatypes,
self._backend_entity.dropout_metatypes,
)
Expand Down
8 changes: 4 additions & 4 deletions nncf/quantization/algorithms/min_max/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,12 @@ def create_quantizer_insertion_command(

@staticmethod
@abstractmethod
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
"""
Returns a list of NNCFNodes that are identified as an inputs.
Returns a list of NNCFNodes to use as start nodes for activation path tracing.
:param nncf_graph: NNCFGraph to get input nodes from.
:return: List of NNCFNodes that are identified as an inputs.
:param nncf_graph: NNCFGraph to get the start nodes.
:return: List of NNCFNodes to use as start nodes for activation path tracing.
"""

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/onnx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
return DEFAULT_ONNX_QUANT_TRAIT_TO_OP_DICT

@staticmethod
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
return nncf_graph.get_input_nodes()

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
return DEFAULT_OV_QUANT_TRAIT_TO_OP_DICT

@staticmethod
def get_input_nodes(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
def get_start_nodes_for_activation_path_tracing(nncf_graph: NNCFGraph) -> List[OperatorMetatype]:
return get_input_nodes(nncf_graph)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion nncf/quantization/algorithms/min_max/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT

@staticmethod
def get_input_nodes(nncf_graph: PTNNCFGraph) -> List[OperatorMetatype]:
def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[OperatorMetatype]:
return nncf_graph.get_disconnected_nodes() + nncf_graph.get_input_nodes()

@staticmethod
Expand Down
3 changes: 2 additions & 1 deletion tests/post_training/test_templates/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(
conv_layer_attrs=None,
input_layer_attrs=None,
output_layer_attrs=None,
nncf_graph_cls=NNCFGraph,
):
# Original graph
# Input_1
Expand All @@ -65,7 +66,7 @@ def __init__(
]
node_edges = [("Input_1", "Conv_1"), ("Conv_1", "Output_1")]
original_mock_graph = create_mock_graph(nodes, node_edges)
self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph)
self.nncf_graph = get_nncf_graph_from_mock_nx_graph(original_mock_graph, nncf_graph_cls)


class NNCFGraphToTestSumAggregation:
Expand Down
11 changes: 7 additions & 4 deletions tests/post_training/test_templates/test_ptq_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_quantize_outputs(self, test_params, quantize_outputs):
ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"]
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_input_nodes(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
)
Expand All @@ -186,7 +186,7 @@ def test_ignored_scopes(self, test_params, ignored_scopes_data):
ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"]
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_input_nodes(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
)
Expand All @@ -211,7 +211,7 @@ def test_model_type_pass(self, test_params, model_type):
ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"]
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_input_nodes(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
)
Expand Down Expand Up @@ -277,7 +277,10 @@ def test_quantization_points_overflow_fix(self, overflow_fix, affected_target_po
def test_validate_scope(self, test_params, validate_scopes):
nncf_graph = test_params["test_model_type_pass"]["nncf_graph"]
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph), self.get_algo_backend().get_input_nodes(nncf_graph), [], []
deepcopy(nncf_graph),
self.get_algo_backend().get_start_nodes_for_activation_path_tracing(nncf_graph),
[],
[],
)
ignored_patterns = test_params["test_model_type_pass"]["ignored_patterns"]
algo = MinMaxQuantization(
Expand Down
6 changes: 3 additions & 3 deletions tests/post_training/test_templates/test_quantizer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_default_quantizer_config(self, single_conv_nncf_graph):
nncf_graph = single_conv_nncf_graph.nncf_graph
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_input_nodes(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
)
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_quantizer_config_from_ptq_params_for_CPU(
nncf_graph = single_conv_nncf_graph.nncf_graph
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_input_nodes(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
)
Expand Down Expand Up @@ -185,7 +185,7 @@ def test_depthwise_conv_default_quantizer_config(self, depthwise_conv_nncf_graph
nncf_graph = depthwise_conv_nncf_graph.nncf_graph
inference_nncf_graph = transform_to_inference_graph(
deepcopy(nncf_graph),
min_max_algo._backend_entity.get_input_nodes(nncf_graph),
min_max_algo._backend_entity.get_start_nodes_for_activation_path_tracing(nncf_graph),
min_max_algo._backend_entity.shapeof_metatypes,
min_max_algo._backend_entity.dropout_metatypes,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/ptq/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_depthwise_conv_nncf_graph() -> NNCFGraphToTestDepthwiseConv:
transpose=False,
padding_values=(1, 1),
)
return NNCFGraphToTestDepthwiseConv(PTDepthwiseConv2dSubtype, conv_layer_attrs)
return NNCFGraphToTestDepthwiseConv(PTDepthwiseConv2dSubtype, conv_layer_attrs, nncf_graph_cls=PTNNCFGraph)


def get_single_no_weight_matmul_nncf_graph() -> NNCFGraphToTest:
Expand Down

0 comments on commit 8dc09e8

Please sign in to comment.