Skip to content

Commit

Permalink
Optimize NNCFGraph.get_node_by_name() method (#2190)
Browse files Browse the repository at this point in the history
### Changes

Linear search was removed inside the `NNCFGraph.get_node_by_name()`
method.

### Reason for changes

The implementation of the `NNCFGraph.get_node_by_name()` method is very
slow. The total time spent on its execution (several calls during
quantization) is 216.887 seconds for the databricks/dolly-v2-3b model.
This method currently operates at O(N) time complexity, where N is the
number of nodes in the `NNCFGraph`. I expected it to have O(1)
complexity.

### Related tickets

Ref: 119299

### Tests

pre-commit scope
  • Loading branch information
andrey-churkin authored Oct 16, 2023
1 parent 9b75974 commit d982620
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions nncf/common/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,9 @@ def __init__(self):
self._nodes: Dict[str, NNCFNode] = {}
self._input_nncf_nodes: Dict[int, NNCFNode] = {}
self._output_nncf_nodes: Dict[int, NNCFNode] = {}

self._node_ids_vs_layer_names: Dict[int, LayerName] = {}
self._layer_name_vs_shared_nodes: Dict[LayerName, List[NNCFNode]] = defaultdict(list)
self._node_name_to_node_id_map: Dict[str, List[int]] = {}

@property
def nodes(self) -> Dict[str, NNCFNode]:
Expand Down Expand Up @@ -458,6 +458,9 @@ def add_nncf_node(
if node_id in self._node_id_to_key_dict:
raise ValueError(f"NNCF node with id {node_id} is already in the NNCFGraph")

node_ids = self._node_name_to_node_id_map.setdefault(node_name, [])
node_ids.append(node_id)

node_key = f"{node_id} {node_name}"

self._node_id_to_key_dict[node_id] = node_key
Expand Down Expand Up @@ -635,16 +638,14 @@ def _get_graph_for_visualization(self) -> nx.DiGraph:
return out_graph

def get_node_by_name(self, name: NNCFNodeName) -> NNCFNode:
matches = [node for node in self.get_all_nodes() if node.node_name == name]
if not matches:
node_ids = self._node_name_to_node_id_map.get(name, None)
if node_ids is None:
raise RuntimeError("Could not find a node {} in NNCFGraph!".format(name))
if len(matches) > 1:
raise RuntimeError(
"More than one node in NNCFGraph matches name {}:\n{}".format(
name, "\t\n".join([str(n.node_id) for n in matches])
)
)
return next(iter(matches))
if len(node_ids) > 1:
raise RuntimeError(f"More than one node in NNCFGraph matches name {name}")

node_key = f"{node_ids[0]} {name}"
return self._nodes[node_key]

def __eq__(self, other: "NNCFGraph"):
nm = iso.categorical_node_match(
Expand Down

0 comments on commit d982620

Please sign in to comment.