Commit 27308ca: WIP
daniil-lyakhov committed May 31, 2024 (1 parent: ceb9c4b)
Showing 6 changed files with 86 additions and 95 deletions.
7 changes: 0 additions & 7 deletions nncf/common/hardware/configs/cpu.json
@@ -64,13 +64,6 @@
"weights": ["q8_w_sym", "q8_w_asym"]
}
},
{
"type": "Add",
"quantization": {
"activations": "q8_a",
"weights": ["q8_w_sym", "q8_w_asym"]
}
},
{
"type": "Multiply",
"quantization": {
Expand Down
17 changes: 7 additions & 10 deletions nncf/experimental/torch_fx/model_transformer.py
@@ -147,24 +147,21 @@ def _get_grah_node_by_name(graph, name):
         return node

     @staticmethod
-    def _get_target_node_and_ctx(graph: torch.fx.Graph, target_point: PTTargetPoint):
+    def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
         target_type = target_point.target_type
         target_node = FXModelTransformer._get_grah_node_by_name(graph, target_point.target_node_name)
-        if target_type == TargetType.OPERATOR_PRE_HOOK:
-            ctx = graph.inserting_before(target_node)
-        elif target_type == TargetType.OPERATOR_POST_HOOK:
-            ctx = graph.inserting_after(target_node)
-        elif target_type == TargetType.OPERATION_WITH_WEIGHTS:
+        if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
             target_node = target_node.all_input_nodes[target_point.input_port_id]
-            ctx = graph.inserting_after(target_node)
+        elif target_type == TargetType.OPERATOR_POST_HOOK:
+            pass
         else:
             raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
-        return target_node, ctx
+        return target_node

     @staticmethod
     def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str):
-        target_node, ctx = FXModelTransformer._get_target_node_and_ctx(graph, target_point)
-        with ctx:
+        target_node = FXModelTransformer._get_target_node(graph, target_point)
+        with graph.inserting_after(target_node):
             graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")

     @staticmethod
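For reference, a minimal self-contained sketch of the torch.fx insertion-context pattern the methods above rely on: nodes created inside graph.inserting_after(anchor) land right after the anchor in topological order, and consumers are rewired explicitly. The toy module and node names below are hypothetical, not from this repository.

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = torch.fx.symbolic_trace(Toy())
relu_node = next(n for n in gm.graph.nodes if n.target is torch.relu)
with gm.graph.inserting_after(relu_node):
    # The new node is created directly after the anchor node.
    new_node = gm.graph.call_function(torch.neg, (relu_node,))
# Rewire every consumer of the anchor to the new node, except the new node itself.
for user in list(relu_node.users):
    if user is not new_node:
        user.replace_input_with(relu_node, new_node)
gm.recompile()
print(gm.code)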
5 changes: 2 additions & 3 deletions nncf/experimental/torch_fx/quantization/default_quantization.py
@@ -1,4 +1,3 @@
-
 # Copyright (c) 2024 Intel Corporation
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +17,7 @@

 # If a metatype is not in this list, then it is considered to be QuantizationTrait.NON_QUANTIZABLE.

-DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT: Dict[QuantizationTrait, List[PTOperatorMetatype]] = {
+DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT: Dict[QuantizationTrait, List[PTOperatorMetatype]] = {
     QuantizationTrait.INPUTS_QUANTIZABLE: [
         operator_metatypes.PTConv2dMetatype,
         operator_metatypes.PTModuleConv2dMetatype,
@@ -36,7 +35,7 @@
         operator_metatypes.PTModuleLinearMetatype,
         operator_metatypes.PTLayerNormMetatype,
         operator_metatypes.PTModuleLayerNormMetatype,
-        #operator_metatypes.PTAddMetatype,
+        # operator_metatypes.PTAddMetatype,
         operator_metatypes.PTMulMetatype,
         operator_metatypes.PTDivMetatype,
         operator_metatypes.PTMatMulMetatype,
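As an illustration only, a mapping of this shape is typically consumed by inverting it per metatype. The helper below is hypothetical (not part of this commit) and assumes the QuantizationTrait import path used elsewhere in NNCF.

from nncf.common.quantization.quantizer_propagation.structs import QuantizationTrait

# Hypothetical lookup helper: return the trait of a single metatype,
# falling back to NON_QUANTIZABLE as the comment above specifies.
def trait_of(metatype, trait_to_ops):
    for trait, metatypes in trait_to_ops.items():
        if metatype in metatypes:
            return trait
    return QuantizationTrait.NON_QUANTIZABLE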
137 changes: 68 additions & 69 deletions nncf/experimental/torch_fx/transformations.py
@@ -17,9 +17,7 @@
 from torch.quantization.fake_quantize import FakeQuantize

 from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
-from nncf.quantization.fake_quantize import FakeQuantizeParameters
 from nncf.torch.graph.transformations.commands import PTTargetPoint
-from nncf.torch.quantization.layers import PTQuantizerSpec


 def stat_collectorts_insertion_tranformation_builder():
@@ -34,8 +32,8 @@ def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
         module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points)
         graph = model.graph
         for target_point in target_points:
-            target_node, ctx = FXModelTransformer._get_target_node_and_ctx(model.graph, target_point)
-            with ctx:
+            target_node = FXModelTransformer._get_target_node(model.graph, target_point)
+            with graph.inserting_after(target_node):
                 fq_node = graph.create_node(
                     "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer"
                 )
@@ -66,70 +64,71 @@ def _set_module_to_the_graph_module(
     return module_name_in_model


-def qdq_insertion_tranformation_builder(qspec: PTQuantizerSpec, fq_params: FakeQuantizeParameters, axis: int, eps=1e-5):
-    # signed = bool(torch.any(fq_params.input_low.data < 0))
-    # Subtract eps from the scale to make the quantizer parameters equal to
-    # the original parameters on the forward call.
-    scale = (fq_params.input_high.data - eps).reshape(qspec.scale_shape)
-
-    def qdq_insertion_tranformation(model: torch.fx.GraphModule, node: torch.fx.Node):
-        # 1. Extract the information for inserting a q/dq node pair from activation_post_process.
-        node_type = "call_function"
-        quantize_op: Optional[Callable] = None
-        # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
-        if qspec.per_channel:
-            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
-            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
-        else:
-            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
-            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-        # TODO: map FakeQuantizeParameters to qparams for quantize/dequantize
-        qparams = {
-            "_scale_": scale,
-            "_zero_point_": 0,
-            "_axis_": axis,
-            "_quant_min_": 0,
-            "_quant_max_": 2**qspec.num_bits - 1,
-            "_dtype_": torch.int8,
-        }
-        # 2. Replace the activation_post_process node with quantize and dequantize.
-        graph = model.graph
-        # TODO: use the metatype to get the correct input_port_id
-        # Do not quantize already quantized nodes.
-        # inserting_before only affects the order in the generated graph code,
-        # so insert the quantize-dequantize pair and all constant nodes before the usage of the nodes.
-        with graph.inserting_before(node):
-            quantize_op_inputs = [node]
-            for key, value_or_node in qparams.items():
-                # TODO: we could record whether a value needs to be
-                # registered as an attribute in the qparams dict itself
-                if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
-                    # For scale and zero_point values we register them as buffers in the root module.
-                    # However, note that when the values are not tensors, as in the case of
-                    # per_tensor quantization, they will be treated as literals.
-                    # Registering them as a node seems to cause an issue with dynamo
-                    # tracing, where it may pick the tensor overload as opposed to the default one.
-                    # The extra check that scale and zero_point are scalars makes
-                    # sure that the default overload can be used.
-                    # TODO: a more complex attr name may be needed here
-                    qparam_node = create_getattr_from_value(model, graph, node.name + key, value_or_node)
-                    quantize_op_inputs.append(qparam_node)
-                else:
-                    # qparams that are not scale/zero_point (like axis and dtype) are stored
-                    # as literals in the graph.
-                    quantize_op_inputs.append(value_or_node)
-        with graph.inserting_after(node):
-            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
-        # Reuse the same qparams from the quantize op.
-        dq_inputs = [quantized_node] + quantize_op_inputs[1:]
-        user_dq_nodes = []
-        with graph.inserting_after(quantized_node):
-            for user in node.users:
-                if user is quantized_node:
-                    continue
-                user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))
-
-        for user, dq_node in user_dq_nodes:
-            user.replace_input_with(node, dq_node)
+def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
+    def qdq_insertion_tranformation(model: torch.fx.GraphModule):
+        for target_point in target_points:
+            target_node = FXModelTransformer._get_target_node(model.graph, target_point)
+            insert_one_qdq(model, target_node, quantizer)
+
+    return qdq_insertion_tranformation
+
+
+def insert_one_qdq(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize):
+    # 1. Extract the information for inserting a q/dq node pair from activation_post_process.
+    node_type = "call_function"
+    quantize_op: Optional[Callable] = None
+    # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
+    if quantizer.is_per_channel:
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
+        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
+    else:
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+    # TODO: map FakeQuantizeParameters to qparams for quantize/dequantize
+    qparams = {
+        "_scale_": quantizer.scale,
+        "_zero_point_": quantizer.zero_point,
+        "_axis_": quantizer.ch_axis,
+        "_quant_min_": quantizer.quant_min,
+        "_quant_max_": quantizer.quant_max,
+        "_dtype_": torch.int8,
+    }
+    # 2. Replace the activation_post_process node with quantize and dequantize.
+    graph = model.graph
+    # TODO: use the metatype to get the correct input_port_id
+    # Do not quantize already quantized nodes.
+    # inserting_before only affects the order in the generated graph code,
+    # so insert the quantize-dequantize pair and all constant nodes before the usage of the nodes.
+    with graph.inserting_before(target_node):
+        quantize_op_inputs = [target_node]
+        for key, value_or_node in qparams.items():
+            # TODO: we could record whether a value needs to be
+            # registered as an attribute in the qparams dict itself
+            if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
+                # For scale and zero_point values we register them as buffers in the root module.
+                # However, note that when the values are not tensors, as in the case of
+                # per_tensor quantization, they will be treated as literals.
+                # Registering them as a node seems to cause an issue with dynamo
+                # tracing, where it may pick the tensor overload as opposed to the default one.
+                # The extra check that scale and zero_point are scalars makes
+                # sure that the default overload can be used.
+                # TODO: a more complex attr name may be needed here
+                qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node)
+                quantize_op_inputs.append(qparam_node)
+            else:
+                # qparams that are not scale/zero_point (like axis and dtype) are stored
+                # as literals in the graph.
+                quantize_op_inputs.append(value_or_node)
+    with graph.inserting_after(target_node):
+        quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+    # Reuse the same qparams from the quantize op.
+    dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+    user_dq_nodes = []
+    with graph.inserting_after(quantized_node):
+        for user in target_node.users:
+            if user is quantized_node:
+                continue
+            user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))
+
+    for user, dq_node in user_dq_nodes:
+        user.replace_input_with(target_node, dq_node)
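A minimal round-trip sketch (assuming PyTorch 2.x, where the quantized_decomposed ops are available; the scale, zero_point, and range values are illustrative) of the per-tensor quantize/dequantize pair that insert_one_qdq wires into the graph, with the dequantize call reusing the quantize call's qparams exactly as dq_inputs does above:

import torch

x = torch.randn(4)
scale, zero_point = 0.1, 0
quant_min, quant_max, dtype = -128, 127, torch.int8

q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, quant_min, quant_max, dtype
)
# Dequantize with the identical qparams, mirroring dq_inputs above.
dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    q, scale, zero_point, quant_min, quant_max, dtype
)
# Quantization error is bounded by ~scale/2 inside the clamp range.
print((x - dq).abs().max())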
13 changes: 8 additions & 5 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -27,8 +27,9 @@
 from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
-from nncf.experimental.torch_fx.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
-from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder
+from nncf.experimental.torch_fx.quantization.default_quantization import DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT
+from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder  # noqa
+from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import StatisticsType
@@ -117,7 +118,7 @@ def hw_config(self) -> HWConfig:

     @property
     def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
-        return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
+        return DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT

     @staticmethod
     def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[NNCFNode]:
@@ -297,7 +298,8 @@ def create_quantizer_insertion_command(
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
-        transformation = fake_quantize_insertion_tranformation_builder(quantizer, [target_point])
+        # transformation = fake_quantize_insertion_tranformation_builder(quantizer, [target_point])
+        transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
         return FXApplyTransformationCommand(transformation)

     @staticmethod
@@ -315,7 +317,8 @@ def create_unified_scales_quantizers_insertion_commands(
             quantizer_config, scale_shape, parameters, target_points[0].target_type
         )

-        transformation = fake_quantize_insertion_tranformation_builder(quantizer, target_points)
+        # transformation = fake_quantize_insertion_tranformation_builder(quantizer, target_points)
+        transformation = qdq_insertion_tranformation_builder(quantizer, target_points)
         return [FXApplyTransformationCommand(transformation)]

     @staticmethod
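For context, a short sketch (with illustrative observer settings, not this backend's exact construction path) showing that a torch FakeQuantize module exposes exactly the attributes insert_one_qdq reads when building qparams:

import torch
from torch.quantization.fake_quantize import FakeQuantize
from torch.quantization.observer import MinMaxObserver

fq = FakeQuantize(observer=MinMaxObserver, quant_min=0, quant_max=255)
fq(torch.randn(2, 3))  # let the observer collect statistics and update qparams
print(fq.scale, fq.zero_point, fq.ch_axis, fq.quant_min, fq.quant_max, fq.is_per_channel)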
2 changes: 1 addition & 1 deletion torch_compile_ex_release.py
@@ -44,7 +44,7 @@ def get_exported_model_from_nn_module(module, example_inputs):
     return capture_pre_autograd_graph(module, example_inputs)


-NNCF_IMPL = True
+NNCF_IMPL = False


 def get_qsetup(exported_model, example_inputs):
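For reference, a minimal export sketch with a toy module (not from this repository); capture_pre_autograd_graph is the torch._export entry point in the PyTorch versions this script targets, and it has been moved and renamed in later releases:

import torch
from torch._export import capture_pre_autograd_graph

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
example_inputs = (torch.randn(1, 8),)
# The result is a torch.fx.GraphModule that the downstream transformations mutate.
exported = capture_pre_autograd_graph(model, example_inputs)
print(type(exported))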
