From ed97feb3246ca58918931f83917a1023a17f0631 Mon Sep 17 00:00:00 2001 From: Daniil Lyakhov Date: Tue, 26 Nov 2024 00:58:23 -0800 Subject: [PATCH] [TorchFX][Conformance] Move all models to export_for_training (#3078) ### Changes All `capture_pre_autograd_graph` calls in the conformance test were replaced by `torch.export.export_for_training`. ### Reason for changes To remove deprecated `capture_pre_autograd_graph` from the conformance test. ### Related tickets #2766 ### Tests post_training_quantization/555/ have finished succesfully --- .../pipelines/image_classification_torchvision.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py index b675bd6e814..eb2ffeb96a0 100644 --- a/tests/post_training/pipelines/image_classification_torchvision.py +++ b/tests/post_training/pipelines/image_classification_torchvision.py @@ -16,7 +16,6 @@ import onnx import openvino as ov import torch -from torch._export import capture_pre_autograd_graph from torchvision import models from nncf.torch import disable_patching @@ -25,11 +24,11 @@ from tests.post_training.pipelines.image_classification_base import ImageClassificationBase -def _capture_pre_autograd_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule: - return capture_pre_autograd_graph(model, args) +def _torch_export_for_training(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule: + return torch.export.export_for_training(model, args).module() -def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule: +def _torch_export(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule: return torch.export.export(model, args).module() @@ -44,15 +43,15 @@ class ImageClassificationTorchvision(ImageClassificationBase): """Pipeline for Image Classification model from torchvision repository""" models_vs_model_params = { - models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _capture_pre_autograd_module), + models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _torch_export_for_training), models.mobilenet_v3_small: VisionModelParams( - models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module + models.MobileNet_V3_Small_Weights.DEFAULT, _torch_export_for_training ), models.vit_b_16: VisionModelParams( - models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True + models.ViT_B_16_Weights.DEFAULT, _torch_export_for_training, export_torch_before_ov_convert=True ), models.swin_v2_s: VisionModelParams( - models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True + models.Swin_V2_S_Weights.DEFAULT, _torch_export, export_torch_before_ov_convert=True ), }