diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py
index d097f318c37..ed767d9bbb7 100644
--- a/nncf/experimental/torch/fx/transformations.py
+++ b/nncf/experimental/torch/fx/transformations.py
@@ -21,6 +21,7 @@
 import nncf
 from nncf.common.graph.graph import NNCFNode
 from nncf.common.graph.transformations.commands import TargetType
+from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.experimental.torch.fx.node_utils import get_tensor_constant_from_node
 from nncf.torch.graph.transformations.commands import PTTargetPoint
@@ -537,6 +538,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
     # to make it easier for algorithms to work
     # with the target graph BatchNorm operations
     # are being fused
+    constant_fold(model)
     fuse_conv_bn(model)
     separate_conv_and_bias(model)
     separate_linear_and_bias(model)
diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py
index 82ce5b921a3..b675bd6e814 100644
--- a/tests/post_training/pipelines/image_classification_torchvision.py
+++ b/tests/post_training/pipelines/image_classification_torchvision.py
@@ -19,7 +19,6 @@
 from torch._export import capture_pre_autograd_graph
 from torchvision import models
 
-from nncf.experimental.torch.fx.constant_folding import constant_fold
 from nncf.torch import disable_patching
 from tests.post_training.pipelines.base import PT_BACKENDS
 from tests.post_training.pipelines.base import BackendType
@@ -80,7 +79,6 @@ def prepare_model(self) -> None:
         if self.backend == BackendType.FX_TORCH:
             with torch.no_grad():
                 with disable_patching():
                     self.model = self.model_params.export_fn(model, (self.dummy_tensor,))
-                    constant_fold(self.model)
         elif self.backend in PT_BACKENDS:
             self.model = model
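
Note: with this change, constant folding runs inside apply_quantization_transformations itself (before BatchNorm fusion), so callers such as the torchvision test pipeline no longer need a separate constant_fold step after export. A minimal sketch of the resulting call order, assuming a resnet18 model and an input shape chosen purely for illustration:

    import torch
    from torch._export import capture_pre_autograd_graph
    from torchvision import models

    from nncf.experimental.torch.fx.transformations import apply_quantization_transformations

    # Capture the model as a torch.fx.GraphModule via pre-autograd export.
    # resnet18 and the 224x224 input are illustrative choices, not part of this PR.
    model = models.resnet18(weights=None).eval()
    example_input = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        fx_model = capture_pre_autograd_graph(model, (example_input,))

    # No explicit constant_fold(fx_model) call is needed here anymore: folding
    # now happens as the first step of apply_quantization_transformations.
    apply_quantization_transformations(fx_model)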