From aea0bdf7b12684eebf2a40c6de0dcc941e0c8d74 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 2 Dec 2024 17:24:01 +0100 Subject: [PATCH] Correct use of transform_for_annotation --- .../torch/fx/quantization/quantize_pt2e.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py index 78ba278826e..ed41166b9ca 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_pt2e.py +++ b/nncf/experimental/torch/fx/quantization/quantize_pt2e.py @@ -28,6 +28,7 @@ from nncf.experimental.common.quantization.algorithms.post_training.algorithm import ( ExperimentalPostTrainingQuantization, ) +from nncf.experimental.common.quantization.algorithms.quantizer.base_quantizer import NNCFQuantizer from nncf.experimental.common.quantization.algorithms.quantizer.fx_quantizer import NNCFFXQuantizer from nncf.experimental.torch.fx.constant_folding import constant_fold from nncf.experimental.torch.fx.transformations import QUANTIZE_NODE_TARGETS @@ -63,8 +64,18 @@ def quantize_pt2e( copied_model = deepcopy(model) + # To make it easier for bias correction algorithms, + # biases are being separated by the followng calls. + fuse_conv_bn(copied_model) + # Call ao quantizer transform_for_annotation + # before the NNCFGraph creation + quantizer.transform_for_annotation(copied_model) + + if not isinstance(quantizer, NNCFQuantizer): + quantizer = NNCFFXQuantizer(quantizer) + quantization_algorithm = ExperimentalPostTrainingQuantization( - quantizer=NNCFFXQuantizer(quantizer), + quantizer=quantizer, subset_size=subset_size, fast_bias_correction=fast_bias_correction, smooth_quant=smooth_quant, @@ -74,10 +85,6 @@ def quantize_pt2e( weights_range_estimator_params=weights_range_estimator_params, ) - # To make it easier for bias correction algorithms, - # biases are being separated by the followng calls. - fuse_conv_bn(copied_model) - nncf_graph = NNCFGraphFactory.create(copied_model) quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset)