diff --git a/nncf/common/hardware/configs/cpu.json b/nncf/common/hardware/configs/cpu.json
index 0cd5290ae6d..4b39be807d9 100644
--- a/nncf/common/hardware/configs/cpu.json
+++ b/nncf/common/hardware/configs/cpu.json
@@ -64,13 +64,6 @@
                 "weights": ["q8_w_sym", "q8_w_asym"]
             }
         },
-        {
-            "type": "Add",
-            "quantization": {
-                "activations": "q8_a",
-                "weights": ["q8_w_sym", "q8_w_asym"]
-            }
-        },
         {
             "type": "Multiply",
             "quantization": {
diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch_fx/model_transformer.py
index a3ac2caacb8..49f54d7b534 100644
--- a/nncf/experimental/torch_fx/model_transformer.py
+++ b/nncf/experimental/torch_fx/model_transformer.py
@@ -147,24 +147,21 @@ def _get_grah_node_by_name(graph, name):
         return node
 
     @staticmethod
-    def _get_target_node_and_ctx(graph: torch.fx.Graph, target_point: PTTargetPoint):
+    def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint):
         target_type = target_point.target_type
         target_node = FXModelTransformer._get_grah_node_by_name(graph, target_point.target_node_name)
-        if target_type == TargetType.OPERATOR_PRE_HOOK:
-            ctx = graph.inserting_before(target_node)
-        elif target_type == TargetType.OPERATOR_POST_HOOK:
-            ctx = graph.inserting_after(target_node)
-        elif target_type == TargetType.OPERATION_WITH_WEIGHTS:
+        if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]:
             target_node = target_node.all_input_nodes[target_point.input_port_id]
-            ctx = graph.inserting_after(target_node)
+        elif target_type == TargetType.OPERATOR_POST_HOOK:
+            pass
         else:
             raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}")
-        return target_node, ctx
+        return target_node
 
     @staticmethod
     def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str):
-        target_node, ctx = FXModelTransformer._get_target_node_and_ctx(graph, target_point)
-        with ctx:
+        target_node = FXModelTransformer._get_target_node(graph, target_point)
+        with graph.inserting_after(target_node):
             graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node")
 
     @staticmethod
diff --git a/nncf/experimental/torch_fx/quantization/default_quantization.py b/nncf/experimental/torch_fx/quantization/default_quantization.py
index c824f9d6493..978bbee88c6 100644
--- a/nncf/experimental/torch_fx/quantization/default_quantization.py
+++ b/nncf/experimental/torch_fx/quantization/default_quantization.py
@@ -1,4 +1,3 @@
-
 # Copyright (c) 2024 Intel Corporation
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,7 +17,7 @@
 
 
 # If a metatype is not in this list, then it is considered to be QuantizationTrait.NON_QUANTIZABLE.
-DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT: Dict[QuantizationTrait, List[PTOperatorMetatype]] = {
+DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT: Dict[QuantizationTrait, List[PTOperatorMetatype]] = {
     QuantizationTrait.INPUTS_QUANTIZABLE: [
         operator_metatypes.PTConv2dMetatype,
         operator_metatypes.PTModuleConv2dMetatype,
@@ -36,7 +35,7 @@
         operator_metatypes.PTModuleLinearMetatype,
         operator_metatypes.PTLayerNormMetatype,
         operator_metatypes.PTModuleLayerNormMetatype,
-        #operator_metatypes.PTAddMetatype,
+        # operator_metatypes.PTAddMetatype,
         operator_metatypes.PTMulMetatype,
         operator_metatypes.PTDivMetatype,
         operator_metatypes.PTMatMulMetatype,
diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch_fx/transformations.py
index ad53b89ca18..88a77c4b1bd 100644
--- a/nncf/experimental/torch_fx/transformations.py
+++ b/nncf/experimental/torch_fx/transformations.py
@@ -17,9 +17,7 @@
 from torch.quantization.fake_quantize import FakeQuantize
 
 from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
-from nncf.quantization.fake_quantize import FakeQuantizeParameters
 from nncf.torch.graph.transformations.commands import PTTargetPoint
-from nncf.torch.quantization.layers import PTQuantizerSpec
 
 
 def stat_collectorts_insertion_tranformation_builder():
@@ -34,8 +32,8 @@ def fake_quantize_insertion_transformation(model: torch.fx.GraphModule):
         module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points)
         graph = model.graph
         for target_point in target_points:
-            target_node, ctx = FXModelTransformer._get_target_node_and_ctx(model.graph, target_point)
-            with ctx:
+            target_node = FXModelTransformer._get_target_node(model.graph, target_point)
+            with graph.inserting_after(target_node):
                 fq_node = graph.create_node(
                     "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer"
                 )
@@ -66,70 +64,71 @@ def _set_module_to_the_graph_module(
     return module_name_in_model
 
 
-def qdq_insertion_tranformation_builder(qspec: PTQuantizerSpec, fq_params: FakeQuantizeParameters, axis: int, eps=1e-5):
-    # signed = bool(torch.any(fq_params.input_low.data < 0))
-    # Subtract eps from the scale to make quantizer parameters equal to
-    # original parameters on the forward call.
-    scale = (fq_params.input_high.data - eps).reshape(qspec.scale_shape)
-
-    def qdq_insertion_tranformation(model: torch.fx.GraphModule, node: torch.fx.Node):
-        # 1. extract information for inserting q/dq node from activation_post_process
-        node_type = "call_function"
-        quantize_op: Optional[Callable] = None
-        # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
-        if qspec.per_channel:
-            quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
-            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
-        else:
-            quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
-            dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
-        # TODO: map FakeQuantizePramaeters to qparams for quantize/dequantize
-        qparams = {
-            "_scale_": scale,
-            "_zero_point_": 0,
-            "_axis_": axis,
-            "_quant_min_": 0,
-            "_quant_max_": 2**qspec.num_bits - 1,
-            "_dtype_": torch.int8,
-        }
-        # 2. replace activation_post_process node with quantize and dequantize
-        graph = model.graph
-        # TODO: use metatype to get correct input_port_id
-        # Do not quantize already quantized nodes
-        # inserting_before handle only order in the graph generated code.
-        # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes
-        with graph.inserting_before(node):
-            quantize_op_inputs = [node]
-            for key, value_or_node in qparams.items():
-                # TODO: we can add the information of whether a value needs to
-                # be registered as an attribute in qparams dict itself
-                if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
-                    # For scale and zero_point values we register them as buffers in the root module.
-                    # However, note that when the values are not tensors, as in the case of
-                    # per_tensor quantization, they will be treated as literals.
-                    # However, registering them as a node seems to cause issue with dynamo
-                    # tracing where it may consider tensor overload as opposed to default.
-                    # With extra check of scale and zero_point being scalar, it makes
-                    # sure that the default overload can be used.
-                    # TODO: maybe need more complex attr name here
-                    qparam_node = create_getattr_from_value(model, graph, node.name + key, value_or_node)
-                    quantize_op_inputs.append(qparam_node)
-                else:
-                    # for qparams that are not scale/zero_point (like axis, dtype) we store
-                    # them as literals in the graph.
-                    quantize_op_inputs.append(value_or_node)
-        with graph.inserting_after(node):
-            quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
-        # use the same qparams from quantize op
-        dq_inputs = [quantized_node] + quantize_op_inputs[1:]
-        user_dq_nodes = []
-        with graph.inserting_after(quantized_node):
-            for user in node.users:
-                if user is quantized_node:
-                    continue
-                user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))
-
-        for user, dq_node in user_dq_nodes:
-            user.replace_input_with(node, dq_node)
+def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]):
+    def qdq_insertion_tranformation(model: torch.fx.GraphModule):
+        for target_point in target_points:
+            target_node = FXModelTransformer._get_target_node(model.graph, target_point)
+            insert_one_qdq(model, target_node, quantizer)
 
     return qdq_insertion_tranformation
+
+
+def insert_one_qdq(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize):
+    # 1. extract information for inserting q/dq node from activation_post_process
+    node_type = "call_function"
+    quantize_op: Optional[Callable] = None
+    # scale, zero_point = activation_post_process.calculate_qparams()  # type: ignore[attr-defined, operator]
+    if quantizer.is_per_channel:
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default
+        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default
+    else:
+        quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+        dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+    # TODO: map FakeQuantizePramaeters to qparams for quantize/dequantize
+    qparams = {
+        "_scale_": quantizer.scale,
+        "_zero_point_": quantizer.zero_point,
+        "_axis_": quantizer.ch_axis,
+        "_quant_min_": quantizer.quant_min,
+        "_quant_max_": quantizer.quant_max,
+        "_dtype_": torch.int8,
+    }
+    # 2. replace activation_post_process node with quantize and dequantize
+    graph = model.graph
+    # TODO: use metatype to get correct input_port_id
+    # Do not quantize already quantized nodes
+    # inserting_before handle only order in the graph generated code.
+    # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes
+    with graph.inserting_before(target_node):
+        quantize_op_inputs = [target_node]
+        for key, value_or_node in qparams.items():
+            # TODO: we can add the information of whether a value needs to
+            # be registered as an attribute in qparams dict itself
+            if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))):
+                # For scale and zero_point values we register them as buffers in the root module.
+                # However, note that when the values are not tensors, as in the case of
+                # per_tensor quantization, they will be treated as literals.
+                # However, registering them as a node seems to cause issue with dynamo
+                # tracing where it may consider tensor overload as opposed to default.
+                # With extra check of scale and zero_point being scalar, it makes
+                # sure that the default overload can be used.
+                # TODO: maybe need more complex attr name here
+                qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node)
+                quantize_op_inputs.append(qparam_node)
+            else:
+                # for qparams that are not scale/zero_point (like axis, dtype) we store
+                # them as literals in the graph.
+                quantize_op_inputs.append(value_or_node)
+    with graph.inserting_after(target_node):
+        quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {})
+    # use the same qparams from quantize op
+    dq_inputs = [quantized_node] + quantize_op_inputs[1:]
+    user_dq_nodes = []
+    with graph.inserting_after(quantized_node):
+        for user in target_node.users:
+            if user is quantized_node:
+                continue
+            user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {})))
+
+    for user, dq_node in user_dq_nodes:
+        user.replace_input_with(target_node, dq_node)
diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
index 5e1bcd5692c..c3c19ade894 100644
--- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -27,8 +27,9 @@
 from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
-from nncf.experimental.torch_fx.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
-from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder
+from nncf.experimental.torch_fx.quantization.default_quantization import DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT
+from nncf.experimental.torch_fx.transformations import fake_quantize_insertion_tranformation_builder  # noqa
+from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import StatisticsType
@@ -117,7 +118,7 @@ def hw_config(self) -> HWConfig:
 
     @property
     def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]:
-        return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT
+        return DEFAULT_FX_QUANT_TRAIT_TO_OP_DICT
 
     @staticmethod
     def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[NNCFNode]:
@@ -297,7 +298,8 @@ def create_quantizer_insertion_command(
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
-        transformation = fake_quantize_insertion_tranformation_builder(quantizer, [target_point])
+        # transformation = fake_quantize_insertion_tranformation_builder(quantizer, [target_point])
+        transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
         return FXApplyTransformationCommand(transformation)
 
     @staticmethod
@@ -315,7 +317,8 @@ def create_unified_scales_quantizers_insertion_commands(
             quantizer_config, scale_shape, parameters, target_points[0].target_type
         )
 
-        transformation = fake_quantize_insertion_tranformation_builder(quantizer, target_points)
+        # transformation = fake_quantize_insertion_tranformation_builder(quantizer, target_points)
+        transformation = qdq_insertion_tranformation_builder(quantizer, target_points)
         return [FXApplyTransformationCommand(transformation)]
 
     @staticmethod
diff --git a/torch_compile_ex_release.py b/torch_compile_ex_release.py
index 85f9ef74738..b97be3b6479 100644
--- a/torch_compile_ex_release.py
+++ b/torch_compile_ex_release.py
@@ -44,7 +44,7 @@ def get_exported_model_from_nn_module(module, example_inputs):
     return capture_pre_autograd_graph(module, example_inputs)
 
 
-NNCF_IMPL = True
+NNCF_IMPL = False
 
 
 def get_qsetup(exported_model, example_inputs):
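Usage note (not part of the patch): a minimal sketch of how the new qdq_insertion_tranformation_builder entry point is wired into a transformation command, mirroring the create_quantizer_insertion_command change above. The target node name and quantizer settings below are hypothetical; only the builder/command call pattern comes from the diff.

    from torch.quantization.fake_quantize import FakeQuantize

    from nncf.common.graph.transformations.commands import TargetType
    from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand
    from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder
    from nncf.torch.graph.transformations.commands import PTTargetPoint

    # Hypothetical post-hook target point on a node of the captured torch.fx graph.
    target_point = PTTargetPoint(TargetType.OPERATOR_POST_HOOK, target_node_name="conv2d")

    # Default per-tensor FakeQuantize; insert_one_qdq reads scale, zero_point, ch_axis,
    # quant_min/quant_max and is_per_channel from it to build the q/dq pair.
    quantizer = FakeQuantize()

    # The builder closes over the quantizer and target points; the returned callable takes
    # only the GraphModule and inserts one quantize/dequantize pair per target point.
    transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
    command = FXApplyTransformationCommand(transformation)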