From d982620cc138a5f68f40a770f572d9894bf6a903 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Mon, 16 Oct 2023 15:28:50 +0100 Subject: [PATCH] Optimize NNCFGraph.get_node_by_name() method (#2190) ### Changes Linear search was removed inside the `NNCFGraph.get_node_by_name()` method. ### Reason for changes The implementation of the `NNCFGraph.get_node_by_name()` method is very slow. The total time spent on its execution (several calls during quantization) is 216.887 seconds for the databricks/dolly-v2-3b model. This method currently operates at O(N) time complexity, where N is the number of nodes in the `NNCFGraph`. I expected it to have O(1) complexity. ### Related tickets Ref: 119299 ### Tests pre-commit scope --- nncf/common/graph/graph.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/nncf/common/graph/graph.py b/nncf/common/graph/graph.py index 16b073d0e47..08a1fd587bd 100644 --- a/nncf/common/graph/graph.py +++ b/nncf/common/graph/graph.py @@ -196,9 +196,9 @@ def __init__(self): self._nodes: Dict[str, NNCFNode] = {} self._input_nncf_nodes: Dict[int, NNCFNode] = {} self._output_nncf_nodes: Dict[int, NNCFNode] = {} - self._node_ids_vs_layer_names: Dict[int, LayerName] = {} self._layer_name_vs_shared_nodes: Dict[LayerName, List[NNCFNode]] = defaultdict(list) + self._node_name_to_node_id_map: Dict[str, List[int]] = {} @property def nodes(self) -> Dict[str, NNCFNode]: @@ -458,6 +458,9 @@ def add_nncf_node( if node_id in self._node_id_to_key_dict: raise ValueError(f"NNCF node with id {node_id} is already in the NNCFGraph") + node_ids = self._node_name_to_node_id_map.setdefault(node_name, []) + node_ids.append(node_id) + node_key = f"{node_id} {node_name}" self._node_id_to_key_dict[node_id] = node_key @@ -635,16 +638,14 @@ def _get_graph_for_visualization(self) -> nx.DiGraph: return out_graph def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode: - matches = [node for node in self.get_all_nodes() if node.node_name == name] - if not matches: + node_ids = self._node_name_to_node_id_map.get(name, None) + if node_ids is None: raise RuntimeError("Could not find a node {} in NNCFGraph!".format(name)) - if len(matches) > 1: - raise RuntimeError( - "More than one node in NNCFGraph matches name {}:\n{}".format( - name, "\t\n".join([str(n.node_id) for n in matches]) - ) - ) - return next(iter(matches)) + if len(node_ids) > 1: + raise RuntimeError(f"More than one node in NNCFGraph matches name {name}") + + node_key = f"{node_ids[0]} {name}" + return self._nodes[node_key] def __eq__(self, other: "NNCFGraph"): nm = iso.categorical_node_match(