Skip to content

Commit

Permalink
[TorchFX][Conformance] Move all models to export_for_training (openvi…
Browse files Browse the repository at this point in the history
…notoolkit#3078)

### Changes

All `capture_pre_autograd_graph` calls in the conformance test were
replaced by `torch.export.export_for_training`.

### Reason for changes

To remove deprecated `capture_pre_autograd_graph` from the conformance
test.

### Related tickets

openvinotoolkit#2766 

### Tests

post_training_quantization/555/ have finished succesfully
  • Loading branch information
daniil-lyakhov committed Dec 2, 2024
1 parent 1bce93a commit ed97feb
Showing 1 changed file with 7 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import onnx
import openvino as ov
import torch
from torch._export import capture_pre_autograd_graph
from torchvision import models

from nncf.torch import disable_patching
Expand All @@ -25,11 +24,11 @@
from tests.post_training.pipelines.image_classification_base import ImageClassificationBase


def _capture_pre_autograd_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
return capture_pre_autograd_graph(model, args)
def _torch_export_for_training(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
return torch.export.export_for_training(model, args).module()


def _export_graph_module(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
def _torch_export(model: torch.nn.Module, args: Tuple[Any, ...]) -> torch.fx.GraphModule:
return torch.export.export(model, args).module()


Expand All @@ -44,15 +43,15 @@ class ImageClassificationTorchvision(ImageClassificationBase):
"""Pipeline for Image Classification model from torchvision repository"""

models_vs_model_params = {
models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _capture_pre_autograd_module),
models.resnet18: VisionModelParams(models.ResNet18_Weights.DEFAULT, _torch_export_for_training),
models.mobilenet_v3_small: VisionModelParams(
models.MobileNet_V3_Small_Weights.DEFAULT, _capture_pre_autograd_module
models.MobileNet_V3_Small_Weights.DEFAULT, _torch_export_for_training
),
models.vit_b_16: VisionModelParams(
models.ViT_B_16_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
models.ViT_B_16_Weights.DEFAULT, _torch_export_for_training, export_torch_before_ov_convert=True
),
models.swin_v2_s: VisionModelParams(
models.Swin_V2_S_Weights.DEFAULT, _export_graph_module, export_torch_before_ov_convert=True
models.Swin_V2_S_Weights.DEFAULT, _torch_export, export_torch_before_ov_convert=True
),
}

Expand Down

0 comments on commit ed97feb

Please sign in to comment.