From 6ba8a34b524058012a861581e34f6bf2921ee7a5 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 4 Jul 2024 11:51:20 +0200 Subject: [PATCH] Comments --- Makefile | 1 - nncf/experimental/torch/fx/model_transformer.py | 8 +++----- .../torch/fx/quantization/quantize_model.py | 14 ++++++++++++-- nncf/experimental/torch/fx/transformations.py | 2 +- tests/torch/fx/requirements.txt | 7 ------- tests/torch/requirements.txt | 5 +++++ 6 files changed, 21 insertions(+), 16 deletions(-) delete mode 100644 tests/torch/fx/requirements.txt diff --git a/Makefile b/Makefile index 42adac4551c..da65aa89e8b 100644 --- a/Makefile +++ b/Makefile @@ -126,7 +126,6 @@ install-torch-test: pip install -e . pip install "git+https://github.com/openvinotoolkit/open_model_zoo.git@37f60eb#egg=accuracy_checker&subdirectory=tools/accuracy_checker" pip install -r tests/torch/requirements.txt - pip install -r tests/torch/fx/requirements.txt pip install -r tests/cross_fw/install/requirements.txt pip install -r tests/cross_fw/examples/requirements.txt diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py index d68ff7f18a8..b4db5ed4fa7 100644 --- a/nncf/experimental/torch/fx/model_transformer.py +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -10,8 +10,6 @@ # limitations under the License. from collections import defaultdict - -# from functools import partial from typing import Callable, List, Union import torch @@ -63,9 +61,9 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G if transformations: model = transformation_fn(model, transformations) - # Do not eliminate dead code as - # the dead code is computing statistics :) - # model.graph.eliminate_dead_code() + # Do not use model.graph.eliminate_dead_code() + # because the computational statistics code + # is interpolated as dead code. model.recompile() return model diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index 486b3987980..08bb73ee854 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -94,11 +94,21 @@ def quantize_impl( advanced_parameters=advanced_parameters, ) + # BatchNorm operations have 3 output ports, + # to make it easier for alorithms to work + # with the target graph BatchNorm operations + # are being fused _fuse_conv_bn_(copied_model) - # BN fuses to conv bias, conv+bias joined op - # needs to be splited for nncf + + # To make it easier for bias correction algorithms, + # biases are being separated by the followng calls. separate_linear_and_bias(copied_model) separate_conv_and_bias(copied_model) + + # View requires at least one dimension spans + # across two contiguous subspaces and reshape is not. + # To prevent error during statistics collection + # all view operation are translated to reshape. view_to_reshape(copied_model) nncf_graph = NNCFGraphFactory.create(copied_model) diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 304584ed929..dc8a2122adb 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -140,7 +140,7 @@ def insert_one_qdq_before_node(model: torch.fx.GraphModule, target_node: torch.f # 1. extract information for inserting q/dq node from activation_post_process node_type = "call_function" quantize_op: Optional[Callable] = None - # scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8 if quantizer.is_per_channel: qparams = { diff --git a/tests/torch/fx/requirements.txt b/tests/torch/fx/requirements.txt deleted file mode 100644 index cc2e4494d68..00000000000 --- a/tests/torch/fx/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ --c ../../../constraints.txt -pytest -pytest-cov -openvino -torch -torchvision -fastdownload==0.0.7 \ No newline at end of file diff --git a/tests/torch/requirements.txt b/tests/torch/requirements.txt index bbd3a45c57e..be82652d65f 100644 --- a/tests/torch/requirements.txt +++ b/tests/torch/requirements.txt @@ -19,3 +19,8 @@ datasets==2.14.7 evaluate==0.3.0 openvino timm==0.9.2 + + +# Required for torch/fx tests +torchvision +fastdownload==0.0.7