Skip to content

Commit

Permalink
[Relay][Frontend][QNN] fix access param_debug_name_map to node outp…
Browse files Browse the repository at this point in the history
…ut name in fx-quantized graph node replacement (#16217)

* update qnn_torch.py

* remove unused function
  • Loading branch information
PineApple777 authored Dec 27, 2023
1 parent 1c45389 commit 506eff2
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from tvm.relay.frontend.common import infer_shape

from .common import logger
from .pytorch_utils import is_version_greater_than, getattr_attr_name
from .pytorch_utils import is_version_greater_than


class QNNParam(object):
Expand Down Expand Up @@ -540,18 +540,11 @@ def inline_input_quant_params_for_fx(graph, params, param_debug_name_map):
# pylint: disable=c-extension-no-member
import torch

def get_full_attr_name(current):
current_attr = getattr_attr_name(current)
inputs = list(current.inputs())
if len(inputs) == 1 and inputs[0].node().kind() == "prim::GetAttr":
return get_full_attr_name(inputs[0].node()) + "." + current_attr
return current_attr

for node in graph.findAllNodes("prim::GetAttr", recurse=True):
out_name = node.output().debugName()

if "_scale" in out_name or "_zero_point" in out_name:
full_attr = param_debug_name_map[get_full_attr_name(node)]
full_attr = param_debug_name_map[out_name]
assert full_attr in params, f"{full_attr} not found in param dict."
param_np = params[full_attr].numpy()
new_const_node = graph.create("prim::Constant")
Expand Down

0 comments on commit 506eff2

Please sign in to comment.