diff --git a/examples/torch/common/export.py b/examples/torch/common/export.py
index 8d3a4c266e8..7d98bcbf70d 100644
--- a/examples/torch/common/export.py
+++ b/examples/torch/common/export.py
@@ -11,7 +11,9 @@
 import torch
 
 from nncf.api.compression import CompressionAlgorithmController
+from nncf.torch.exporter import count_tensors
 from nncf.torch.exporter import generate_input_names_list
+from nncf.torch.exporter import get_export_args
 
 
 def export_model(ctrl: CompressionAlgorithmController, save_path: str, no_strip_on_export: bool) -> None:
@@ -26,11 +28,9 @@ def export_model(ctrl: CompressionAlgorithmController, save_path: str, no_strip_
     model = ctrl.model if no_strip_on_export else ctrl.strip()
     model = model.eval().cpu()
-    input_names = generate_input_names_list(len(model.nncf.input_infos))
-    input_tensor_list = []
-    for info in model.nncf.input_infos:
-        input_shape = tuple([1] + list(info.shape)[1:])
-        input_tensor_list.append(torch.rand(input_shape))
+
+    export_args = get_export_args(model)
+    input_names = generate_input_names_list(count_tensors(export_args))
     with torch.no_grad():
-        torch.onnx.export(model, tuple(input_tensor_list), save_path, input_names=input_names)
+        torch.onnx.export(model, export_args, save_path, input_names=input_names)
diff --git a/examples/torch/object_detection/main.py b/examples/torch/object_detection/main.py
index 7feffdf1802..85c88a36412 100644
--- a/examples/torch/object_detection/main.py
+++ b/examples/torch/object_detection/main.py
@@ -364,7 +364,7 @@ def create_train_data_loader(batch_size):
 
 def create_model(config: SampleConfig):
     input_info = FillerInputInfo.from_nncf_config(config.nncf_config)
-    image_size = input_info[0].shape[-1]
+    image_size = input_info.elements[0].shape[-1]
     ssd_net = build_ssd(config.model, config.ssd_params, image_size, config.num_classes, config)
     weights = config.get("weights")
     if weights:
diff --git a/nncf/torch/exporter.py b/nncf/torch/exporter.py
index a7ddc5b9683..0a9a8d45bb4 100644
--- a/nncf/torch/exporter.py
+++ b/nncf/torch/exporter.py
@@ -9,7 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Any, Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from torch.onnx import OperatorExportTypes
@@ -20,6 +20,7 @@
 from nncf.telemetry.events import NNCF_PT_CATEGORY
 from nncf.torch.dynamic_graph.graph_tracer import create_dummy_forward_fn
 from nncf.torch.nested_objects_traversal import objwalk
+from nncf.torch.nncf_network import NNCFNetwork
 from nncf.torch.utils import get_model_device
 from nncf.torch.utils import is_tensor
 
@@ -44,6 +45,21 @@ def counter_fn(x: torch.Tensor) -> torch.Tensor:
     return count
 
 
+def get_export_args(model: NNCFNetwork, model_args: Optional[Tuple[Any, ...]] = None) -> Tuple:
+    args, kwargs = model.nncf.input_infos.get_forward_inputs()
+
+    if model_args is not None:
+        args = tuple(list(args) + list(model_args[:-1]))
+        kwargs.update(**model_args[-1])
+
+    def to_single_batch_tensors(obj: torch.Tensor):
+        return obj[0:1]
+
+    args = objwalk(args, is_tensor, to_single_batch_tensors)
+    kwargs = objwalk(kwargs, is_tensor, to_single_batch_tensors)
+    return *args, kwargs  # according to a variant of passing kwargs in torch.onnx.export doc
+
+
 class PTExportFormat:
     ONNX = "onnx"
 
@@ -126,18 +142,7 @@ def _export_to_onnx(self, save_path: str, opset_version: int) -> None:
         original_device = get_model_device(self._model)
         model = self._model.eval().cpu()
 
-        args, kwargs = self._model.nncf.input_infos.get_forward_inputs()
-
-        if self._model_args is not None:
-            args = tuple(list(args) + list(self._model_args[:-1]))
-            kwargs.update(**self._model_args[-1])
-
-        def to_single_batch_tensors(obj: torch.Tensor):
-            return obj[0:1]
-
-        args = objwalk(args, is_tensor, to_single_batch_tensors)
-        kwargs = objwalk(kwargs, is_tensor, to_single_batch_tensors)
-        export_args = (*args, kwargs)  # according to a variant of passing kwargs in torch.onnx.export doc
+        export_args = get_export_args(self._model, model_args=self._model_args)
 
         if self._input_names is not None:
             input_names = self._input_names
diff --git a/tests/torch/quantization/test_sanity_sample.py b/tests/torch/quantization/test_sanity_sample.py
index e42832f1186..03635e73461 100644
--- a/tests/torch/quantization/test_sanity_sample.py
+++ b/tests/torch/quantization/test_sanity_sample.py
@@ -14,6 +14,7 @@
 from typing import Dict
 
 import pytest
+import torch
 from torch import nn
 
 from nncf import NNCFConfig
@@ -314,6 +315,7 @@ def setup_spy(self, mocker):
         ctrl_mock = mocker.MagicMock(spec=QuantizationController)
         model_mock = mocker.MagicMock(spec=nn.Module)
+        mocker.patch("examples.torch.common.export.get_export_args", return_value=((torch.Tensor([1, 1]),), {}))
 
         create_model_location = sample_location + ".create_compressed_model"
         create_model_patch = mocker.patch(create_model_location)
diff --git a/tests/torch/requirements.txt b/tests/torch/requirements.txt
index 7c504fc5214..584c016f894 100644
--- a/tests/torch/requirements.txt
+++ b/tests/torch/requirements.txt
@@ -13,7 +13,7 @@ pyparsing<3.0
 transformers[torch]~=4.30.0
 
 # Required for movement_sparsity tests
-datasets~=2.12.0
+datasets~=2.14.0
 evaluate==0.3.0
 timm==0.9.2
 openvino-dev==2023.1
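
For reference, below is a minimal sketch of how the refactored helpers are expected to fit together on the caller side, mirroring the updated examples/torch/common/export.py in this patch. The export_compressed_model wrapper name and its model/save_path arguments are illustrative assumptions, not part of the change.

import torch

from nncf.torch.exporter import count_tensors
from nncf.torch.exporter import generate_input_names_list
from nncf.torch.exporter import get_export_args


def export_compressed_model(model, save_path: str) -> None:
    # The model is assumed to be an NNCFNetwork, e.g. obtained from
    # create_compressed_model(); export is done on CPU in eval mode.
    model = model.eval().cpu()

    # get_export_args() rebuilds the dummy forward inputs from the model's
    # input_infos, trims every tensor to batch size 1, and returns them in the
    # (*args, kwargs_dict) layout accepted by torch.onnx.export.
    export_args = get_export_args(model)

    # count_tensors() walks the nested export args so that the number of
    # generated input names matches the number of actual input tensors.
    input_names = generate_input_names_list(count_tensors(export_args))

    with torch.no_grad():
        torch.onnx.export(model, export_args, save_path, input_names=input_names)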