Skip to content

Commit

Permalink
[TorchFX] Pre-hook insertion support (#2861)
Browse files Browse the repository at this point in the history
### Changes

Torch FX pre-hook insertion support

### Reason for changes

To enable vit_b_16 quantization

### Related tickets

#2766 

### Tests

test_quantized_models is extended with the vit_b_16 and swin_v2_s models
  • Loading branch information
daniil-lyakhov authored Aug 9, 2024
1 parent 8634037 commit 27296b4
Show file tree
Hide file tree
Showing 7 changed files with 9,267 additions and 640 deletions.
108 changes: 77 additions & 31 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
# Insert call_module nodes to the model
graph = model.graph
for target_point in target_points:
target_node = _get_target_node(graph, target_point)
_insert_call_module(graph, target_node, module_attr_name)
_insert_call_module(graph, target_point, module_attr_name)

return leaf_module_insertion_transformation

Expand Down Expand Up @@ -100,13 +99,12 @@ def qdq_insertion_tranformation(model: torch.fx.GraphModule):
" Please use non shared qdq pairs for the weights quantization."
)
for target_point in target_points:
target_node = _get_target_node(model.graph, target_point)
insert_one_qdq_after_node(model, target_node, quantizer)
insert_one_qdq(model, target_point, quantizer)

return qdq_insertion_tranformation


def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize):
def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
"""
Inserts quantize-dequantize after the target node to the target model.
Expand Down Expand Up @@ -146,6 +144,7 @@ def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx

# 2. replace activation_post_process node with quantize and dequantize
graph = model.graph
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
# TODO(dlyakhov): use metatype to get correct input_port_id
# Do not quantize already quantized nodes
# inserting_before handle only order in the graph generated code.
Expand All @@ -170,51 +169,98 @@ def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx
# for qparams that are not scale/zero_point (like axis, dtype) we store
# them as literals in the graph.
quantize_op_inputs.append(value_or_node)
with graph.inserting_after(target_node):

input_node = get_input_node(target_point, target_node)
quantize_op_inputs[0] = input_node

ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
user_dq_nodes = []
with graph.inserting_after(quantized_node):
for user in target_node.users:
if user is quantized_node:
continue
user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))

for user, dq_node in user_dq_nodes:
user.replace_input_with(target_node, dq_node)
# use the same qparams from quantize op
dq_inputs = [quantized_node] + quantize_op_inputs[1:]
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
user_dq_nodes = []
with graph.inserting_after(quantized_node):
for user in target_node.users:
if user is quantized_node:
continue
user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))

for user, dq_node in user_dq_nodes:
user.replace_input_with(target_node, dq_node)
elif target_point.target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
with graph.inserting_after(quantized_node):
dq_node = graph.call_function(dequantize_op, tuple(dq_inputs), {})

args = list(target_node.args)
args[target_point.input_port_id] = dq_node
target_node.args = tuple(args)
else:
raise nncf.InternalError(f"Unexpected target type: {target_point.target_type}")


def _insert_call_module(graph: torch.fx.Graph, target_node: torch.fx.Node, module_attr_name: str):
def _insert_call_module(graph: torch.fx.Graph, target_point: PTTargetPoint, module_attr_name: str):
"""
Inserts module call node to the graph after the target node.
:param graph: Graph to insert module call node.
:param target_node: Target node, module call node is being inserted just after the target node.
:param module_attr_name: The name of the graph attribute which keeps the target module.
"""
with graph.inserting_after(target_node):
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
input_node = get_input_node(target_point, target_node)
ctx_manager = get_ctx_manager(graph, target_point)
with ctx_manager(target_node):
return graph.create_node(
"call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_graph_node"
"call_module",
module_attr_name,
(input_node,),
{},
name=f"{module_attr_name}_{str(target_point.target_type)}_graph_node",
)


def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torch.fx.Node:
def get_input_node(target_point: PTTargetPoint, target_node: torch.fx.Node) -> torch.fx.Node:
"""
Returns TorchFX graph node correspondent to the target point.
Returns an input node according to the given target point.
:param graph: Target torch.fx.Graph.
:param target_point: A target point to find the target node.
:return: TorchFX graph node correspondent to the target point.
:param target_point: Given target point.
:param target_node: The target node of the given target point.
:return: An input node according to the given target point.
"""
# TODO(dlyakhov): Support node insertion on a specific input port id.
target_type = target_point.target_type
target_node = get_graph_node_by_name(graph, target_point.target_node_name)
if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
target_node = target_node.all_input_nodes[target_point.input_port_id]
elif target_type != TargetType.OPERATOR_POST_HOOK:
raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
return target_node
if target_type not in [
TargetType.OPERATOR_PRE_HOOK,
TargetType.OPERATOR_POST_HOOK,
TargetType.OPERATION_WITH_WEIGHTS,
]:
raise nncf.InternalError(f"Unexpected target type: {target_type}")
if target_type == TargetType.OPERATOR_POST_HOOK:
return target_node
return target_node.args[target_point.input_port_id]


def get_ctx_manager(graph: torch.fx.Graph, target_point: PTTargetPoint) -> Callable:
    """
    Picks the insertion context manager of the graph that matches the target point type.

    The insertion context manager defines the position at which create_node and
    companion methods put new nodes into the torch.fx.Graph.

    :param graph: torch.fx.Graph instance.
    :param target_point: Target point whose target type selects the insertion direction.
    :return: `graph.inserting_after` for an operator post hook; `graph.inserting_before`
        for an operator pre hook or an operation with weights.
    :raises nncf.InternalError: If the target type is none of the three supported types.
    """
    hook_type = target_point.target_type

    # Post hooks are placed after the target node.
    if hook_type == TargetType.OPERATOR_POST_HOOK:
        return graph.inserting_after

    # Pre hooks and weight hooks are placed before the target node.
    if hook_type in (TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS):
        return graph.inserting_before

    raise nncf.InternalError(f"Unexpected target type: {hook_type}")


def _set_module_to_the_graph_module(
Expand Down
Loading

0 comments on commit 27296b4

Please sign in to comment.