Skip to content

Commit

Permalink
Problems with compression subgraph matching
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Dec 17, 2024
1 parent cf228d9 commit 3b255c3
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions nncf/experimental/torch/sparsify_activations/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def insert_sparsifiers(

@staticmethod
def get_activation_port_id(matmul_node: NNCFNode, nncf_graph: NNCFGraph) -> int:
# return 0
return 0
n_inputs = len(nncf_graph.get_input_edges(matmul_node))
if n_inputs != 2:
raise RuntimeError(f"Expected node to have two inputs, but found {n_inputs} for node {matmul_node}.")
Expand All @@ -78,24 +78,26 @@ def get_activation_port_id(matmul_node: NNCFNode, nncf_graph: NNCFGraph) -> int:
nncf_graph.get_input_edges(matmul_node)[i].from_node.node_type == "Constant" for i in range(2)
]
if is_const_node_on_port[0] != is_const_node_on_port[1]:
assert not is_const_node_on_port[0], matmul_node.node_name
return 1 if is_const_node_on_port[0] else 0

# Try to match compressed constant subgraph
for i in range(2):
node = nncf_graph.get_input_edges(matmul_node)[i].from_node
if node.node_type == "Convert":
node = nncf_graph.get_input_edges(node)[0].from_node
if node.node_type == "Multiply":
node = nncf_graph.get_input_edges(node)[0].from_node
else:
continue
if node.node_type == "Subtract":
if node.node_type == "Reshape":
node = nncf_graph.get_input_edges(node)[0].from_node
else:
continue
if node.node_type == "Convert":
if node.node_type == "Multiply":
node = nncf_graph.get_input_edges(node)[0].from_node
if node.node_type == "Subtract":
node = nncf_graph.get_input_edges(node)[0].from_node
if node.node_type == "Convert":
node = nncf_graph.get_input_edges(node)[0].from_node
else:
continue
if node.node_type == "Constant":
assert i == 1, matmul_node.node_name
return int(i == 0)

raise RuntimeError(f"Could not find activation port id for node {matmul_node}.")

0 comments on commit 3b255c3

Please sign in to comment.