diff --git a/nncf/common/factory.py b/nncf/common/factory.py index c5a921c8068..6f17dfc7fbd 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -9,7 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TypeVar +import os +from typing import Any, Dict, Optional, Tuple, TypeVar import nncf from nncf.common.engine import Engine @@ -26,13 +27,20 @@ class NNCFGraphFactory: @staticmethod - def create(model: TModel) -> NNCFGraph: + def create( + model: TModel, input_args: Optional[Tuple[Any, ...]] = None, input_kwargs: Optional[Dict[str, Any]] = None + ) -> NNCFGraph: """ Factory method to create backend-specific NNCFGraph instance based on the input model. :param model: backend-specific model instance :return: backend-specific NNCFGraph instance """ + if input_args is None: + input_args = () + if input_kwargs is None: + input_kwargs = {} + model_backend = get_backend(model) if model_backend == BackendType.ONNX: from nncf.onnx.graph.nncf_graph_builder import GraphConverter @@ -47,7 +55,13 @@ def create(model: TModel) -> NNCFGraph: return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: - return model.nncf.get_graph() + if os.getenv("NNCF_EXPERIMENTAL_TORCH_TRACING") is None: + return model.nncf.get_graph() + else: + from nncf.experimental.torch2.function_hook.nncf_graph.nncf_graph_builder import build_nncf_graph + + return build_nncf_graph(model, *input_args, **input_kwargs) + raise nncf.UnsupportedBackendError( "Cannot create backend-specific graph because {} is not supported!".format(model_backend.value) ) diff --git a/nncf/common/model.py b/nncf/common/model.py new file mode 100644 index 00000000000..b0685180376 --- /dev/null +++ b/nncf/common/model.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, TypeVar + +from nncf.common.factory import NNCFGraphFactory +from nncf.common.graph.graph import NNCFGraph +from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_backend + +TModel = TypeVar("TModel") + + +@dataclass +class ModelAttributes: + """ + A class to store model attributes. + + :param example_input_args: Example input arguments for the model. + :param example_input_kwargs: Example input keyword arguments for the model. + """ + + example_input_args: Optional[Tuple[Any]] = None + example_input_kwargs: Optional[Dict[str, Any]] = None + + +class ModelWrapper: + """ + A wrapper class for the original model. + + :param _model: The original model to be wrapped. + :param _graph: The graph representation of the model. + :param _attributes: The storage of the model attributes. + :param _backend: The backend of the model. + """ + + def __init__( + self, model: TModel, *, graph: Optional[NNCFGraph] = None, attributes: Optional[ModelAttributes] = None + ) -> None: + self._model = model + self._graph = graph + self._attributes = attributes or ModelAttributes() + self._backend = get_backend(model) + + @property + def model(self) -> TModel: + """ + Retrieves the original model. + """ + return self._model + + @property + def graph(self) -> NNCFGraph: + """ + Returns the NNCFGraph representation of the model. + + If the graph has not been created yet, it will be created using the model, + example input arguments, and example input keyword arguments stored in the state. + """ + if self._graph is None: + self._graph = NNCFGraphFactory.create( + self.model, self.attributes.example_input_args, self.attributes.example_input_kwargs + ) + return self._graph + + @property + def attributes(self) -> ModelAttributes: + """ + Retrieves the model attributes. + """ + return self._attributes + + @property + def backend(self) -> BackendType: + """ + Retrieves the model backend. + """ + return self._backend + + def unwrap(self) -> Tuple[TModel, NNCFGraph]: + """ + Retrieves the model and graph. + + :return: A tuple of the model and graph. + """ + return self.model, self.graph diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index 00b0be8bac1..db368561d34 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -10,10 +10,8 @@ # limitations under the License. from copy import deepcopy -from typing import Optional +from typing import Optional, cast -import torch -import torch.fx from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat @@ -22,8 +20,8 @@ from torch.fx.passes.infra.pass_manager import PassManager import nncf -from nncf.common.factory import NNCFGraphFactory from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed @@ -46,7 +44,7 @@ def quantize_impl( - model: torch.fx.GraphModule, + model: GraphModule, calibration_dataset: Dataset, mode: Optional[QuantizationMode] = None, preset: Optional[QuantizationPreset] = None, @@ -56,7 +54,7 @@ def quantize_impl( model_type: Optional[ModelType] = None, ignored_scope: Optional[IgnoredScope] = None, advanced_parameters: Optional[AdvancedQuantizationParameters] = None, -) -> torch.fx.GraphModule: +) -> GraphModule: """ Implementation of the `quantize()` method for the Torch FX backend. """ @@ -86,9 +84,9 @@ def quantize_impl( # To make it easier for bias correction algorithms. apply_quantization_transformations(copied_model) - - nncf_graph = NNCFGraphFactory.create(copied_model) - quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) + model_wrapper = ModelWrapper(copied_model) + quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset) + quantized_model = cast(GraphModule, quantized_model_wrapper.model) if is_weight_compression_needed(advanced_parameters): compress_post_quantize_transformation(quantized_model) @@ -116,7 +114,7 @@ def quantize_impl( def compress_weights_impl( - model: torch.fx.GraphModule, + model: GraphModule, dataset: Dataset, mode: CompressWeightsMode, ratio: float, @@ -131,7 +129,7 @@ def compress_weights_impl( lora_correction: bool, backup_mode: BackupMode, advanced_parameters: Optional[AdvancedCompressionParameters] = None, -) -> torch.fx.GraphModule: +) -> GraphModule: """ Implementation of the `compress_weights()` method for the Torch Fx backend. """ @@ -151,8 +149,9 @@ def compress_weights_impl( backup_mode, advanced_parameters, ) - graph = NNCFGraphFactory.create(model) - compressed_model = compression_algorithm.apply(model, graph, dataset=dataset) + model_wrapper = ModelWrapper(model) + compressed_model_wrapper = compression_algorithm.apply(model_wrapper, dataset=dataset) + compressed_model = compressed_model_wrapper.model compressed_model = GraphModule(compressed_model, compressed_model.graph) compressed_model = _disallow_eval_train(compressed_model) diff --git a/nncf/onnx/quantization/quantize_model.py b/nncf/onnx/quantization/quantize_model.py index 7a4665d1a0c..d82c8ed5b8a 100644 --- a/nncf/onnx/quantization/quantize_model.py +++ b/nncf/onnx/quantization/quantize_model.py @@ -9,16 +9,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union, cast import onnx import nncf from nncf.common.logging.logger import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.onnx.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS -from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.parameters import DropType from nncf.parameters import ModelType from nncf.parameters import QuantizationMode @@ -78,10 +78,12 @@ def quantize_impl( advanced_parameters=advanced_parameters, ) - graph = GraphConverter.create_nncf_graph(model) - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) - quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset) - + model_wrapper = ModelWrapper(model) + warning_model_no_batchwise_support( + model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS + ) + quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset) + quantized_model = cast(onnx.ModelProto, quantized_model_wrapper.model) return quantized_model diff --git a/nncf/openvino/quantization/quantize_ifmodel.py b/nncf/openvino/quantization/quantize_ifmodel.py index 07d22171a17..c559846ec29 100644 --- a/nncf/openvino/quantization/quantize_ifmodel.py +++ b/nncf/openvino/quantization/quantize_ifmodel.py @@ -10,7 +10,7 @@ # limitations under the License. from itertools import islice -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, cast import openvino.runtime as ov @@ -25,6 +25,7 @@ from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates @@ -155,7 +156,13 @@ def apply_algorithm_if_bodies( """ nncf_logger.info(f"Iteration [{current_model_num}/{len(graphs)}] ...") parent_graph = graphs[graph_id] - quantized_model = algorithm.apply(parent_model, parent_graph, parent_statistic_points, parent_dataset) + + model_wrapper = ModelWrapper(parent_model, graph=parent_graph) + quantized_model_wrapper = algorithm.apply( + model_wrapper, statistic_points=parent_statistic_points, dataset=parent_dataset + ) + quantized_model = cast(ov.Model, quantized_model_wrapper.model) + if get_number_if_op(parent_model) == 0: return quantized_model, current_model_num model_transformer_fp32 = factory.ModelTransformerFactory.create(parent_model) diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index 423b9d6d42e..30d689b7d1c 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -11,7 +11,7 @@ from copy import deepcopy from pathlib import Path -from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union, cast import openvino.runtime as ov from openvino._offline_transformations import compress_quantize_weights_transformation @@ -19,13 +19,13 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.openvino.graph.metatypes.groups import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS from nncf.openvino.graph.metatypes.openvino_metatypes import OVIfMetatype from nncf.openvino.graph.metatypes.openvino_metatypes import get_node_metatype from nncf.openvino.graph.model_utils import remove_friendly_name_duplicates -from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_number_if_op from nncf.openvino.quantization.backend_parameters import BackendParameters from nncf.openvino.quantization.backend_parameters import is_weight_compression_needed @@ -166,9 +166,12 @@ def native_quantize_impl( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) - graph = GraphConverter.create_nncf_graph(model) - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) - quantized_model = quantization_algorithm.apply(model, graph, dataset=calibration_dataset) + model_wrapper = ModelWrapper(model) + warning_model_no_batchwise_support( + model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS + ) + quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset) + quantized_model = cast(ov.Model, quantized_model_wrapper.model) if is_weight_compression_needed(advanced_parameters): compress_quantize_weights_transformation(quantized_model) @@ -383,7 +386,7 @@ def compress_weights_impl( Implementation of the `compress_weights()` method for the OpenVINO backend. """ model = remove_friendly_name_duplicates(model) - graph = NNCFGraphFactory.create(model) + model_wrapper = ModelWrapper(model) compression_algorithm = WeightCompression( mode, ratio, @@ -405,18 +408,20 @@ def compress_weights_impl( # If there is no such directory, then caches statistics statistics_path = Path(advanced_parameters.statistics_path) if not statistics_path.exists(): - cache_weight_compression_statistics(model, graph, dataset, subset_size, statistics_path) + cache_weight_compression_statistics( + model_wrapper.model, model_wrapper.graph, dataset, subset_size, advanced_parameters.statistics_path + ) statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) compression_algorithm.set_backend_entity(model) - _, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(graph) + _, matmul_input_to_output_nodes_map = compression_algorithm.get_compression_nodes_info(model_wrapper.graph) register_statistics_for_algorithm( statistics_aggregator, - model, - graph, + model_wrapper.model, + model_wrapper.graph, compression_algorithm, matmul_input_to_output_nodes_map, ) statistics_aggregator.load_statistics_from_dir(statistics_path) statistics_points = statistics_aggregator.statistic_points - return compression_algorithm.apply(model, graph, statistics_points, dataset) + return compression_algorithm.apply(model_wrapper, statistic_points=statistics_points, dataset=dataset).model diff --git a/nncf/quantization/algorithms/algorithm.py b/nncf/quantization/algorithms/algorithm.py index befe0a82f9d..86b21694120 100644 --- a/nncf/quantization/algorithms/algorithm.py +++ b/nncf/quantization/algorithms/algorithm.py @@ -14,7 +14,7 @@ from typing import List, Optional, TypeVar from nncf import Dataset -from nncf.common.graph.graph import NNCFGraph +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -38,27 +38,25 @@ def available_backends(self) -> List[BackendType]: @abstractmethod def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: """ Applies the algorithm to the model. - :param model: Model for applying algorithm. - :param graph: Model graph. + :param model_wrapper: A wrapper object containing the model to be applied. :param statistic_points: Statistic points with collected statistics values. :param dataset: A representative dataset for the calibration process. :return: A resulting model. """ @abstractmethod - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: """ Returns statistic points, for which StatisticsCollector should collect statistics. - :param model: Model for statistics collection. - :param graph: Model graph. + :param model_wrapper: A wrapper object containing the model for statistics collection. :return: Statistic points, for which StatisticsCollector should collect statistics. """ diff --git a/nncf/quantization/algorithms/bias_correction/algorithm.py b/nncf/quantization/algorithms/bias_correction/algorithm.py index 63db2ee0adf..d0bea42e41d 100644 --- a/nncf/quantization/algorithms/bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/bias_correction/algorithm.py @@ -25,11 +25,11 @@ from nncf.common.graph.transformations.commands import TransformationCommand from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import copy_model -from nncf.common.utils.backend import get_backend from nncf.experimental.common.tensor_statistics.statistical_functions import mean_per_channel from nncf.quantization.algorithms.algorithm import Algorithm from nncf.tensor import Tensor @@ -107,13 +107,12 @@ def __init__( def available_backends(self) -> List[BackendType]: return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH_FX] - def _set_backend_entity(self, model: TModel) -> None: + def _set_backend_entity(self, model_backend: BackendType) -> None: """ Creates a helper class with a backed-specific logic of the algorithm. - :param model: Backend-specific input model. + :param model_backend: Backend of a model. """ - model_backend = get_backend(model) if model_backend == BackendType.ONNX: from nncf.quantization.algorithms.bias_correction.onnx_backend import ONNXBiasCorrectionAlgoBackend @@ -133,12 +132,15 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: - self._set_backend_entity(model) + ) -> ModelWrapper: + self._set_backend_entity(model_wrapper.backend) + + model = model_wrapper.model + main_transformations_layout = TransformationLayout() main_model_transformer = ModelTransformerFactory.create(model) @@ -205,7 +207,8 @@ def apply( # to reduce memory usage during the algorithm's pipeline. self._remove_unnecessary_stats(position, subgraphs_data) - return main_model_transformer.transform(main_transformations_layout) + transformed_model = main_model_transformer.transform(main_transformations_layout) + return ModelWrapper(transformed_model, attributes=model_wrapper.attributes) def _is_node_correctable(self, node: NNCFNode, nncf_graph: NNCFGraph) -> bool: """ @@ -553,8 +556,10 @@ def output_filter_func(point): output_fp.extend(tensor_collector.get_statistics().mean_values) return output_fp - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + self._set_backend_entity(model_wrapper.backend) + model, graph = model_wrapper.unwrap() + statistic_container = StatisticPointsContainer() nodes_with_bias = [ diff --git a/nncf/quantization/algorithms/channel_alignment/algorithm.py b/nncf/quantization/algorithms/channel_alignment/algorithm.py index b30749b6d2c..d616e8e77e8 100644 --- a/nncf/quantization/algorithms/channel_alignment/algorithm.py +++ b/nncf/quantization/algorithms/channel_alignment/algorithm.py @@ -25,6 +25,7 @@ from nncf.common.graph.utils import get_reduction_axes from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -93,11 +94,12 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: + model, graph = model_wrapper.unwrap() self._set_backend_entity(model) model_transformer = ModelTransformerFactory.create(model) transformation_layout = TransformationLayout() @@ -127,7 +129,7 @@ def filter_func(point: StatisticPoint) -> bool: ): nncf_logger.debug( f"Skipping channel alignment for pairs {conv_in.node_name}, {conv_out.node_name} " - " because one of the node is 1D MatMul, 1D Matmuls are not supported by CA algortihm yet." + "because one of the node is 1D MatMul, 1D Matmuls are not supported by CA algorithm yet." ) continue @@ -170,7 +172,7 @@ def filter_func(point: StatisticPoint) -> bool: transformation_layout.register(command) transformed_model = model_transformer.transform(transformation_layout) - return transformed_model + return ModelWrapper(transformed_model, attributes=model_wrapper.attributes) @staticmethod def _align_means( @@ -381,8 +383,9 @@ def _get_target_point_and_node_in(self, conv_in, add_in) -> Tuple[TargetPoint, N node_in, ) - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + self._set_backend_entity(model_wrapper.model) + graph = model_wrapper.graph statistic_container = StatisticPointsContainer() for conv_in, add_in, _ in self._get_node_pairs(graph): diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index 3d104cad3c9..0f3f4142fbc 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -16,17 +16,16 @@ from nncf import Dataset from nncf.common.factory import EngineFactory from nncf.common.factory import ModelTransformerFactory -from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetPoint from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.layout import TransformationLayout from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType -from nncf.common.utils.backend import get_backend from nncf.experimental.common.tensor_statistics.statistical_functions import mean_per_channel from nncf.quantization.algorithms.algorithm import Algorithm from nncf.tensor import Tensor @@ -95,13 +94,12 @@ def __init__( def available_backends(self) -> List[BackendType]: return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] - def _set_backend_entity(self, model: TModel) -> None: + def _set_backend_entity(self, model_backend: BackendType) -> None: """ Creates a helper class with a backed-specific logic of the algorithm. - :param model: Backend-specific input model. + :param model_backend: Backend of a model. """ - model_backend = get_backend(model) if model_backend == BackendType.ONNX: from nncf.quantization.algorithms.fast_bias_correction.onnx_backend import ONNXFastBiasCorrectionAlgoBackend @@ -129,12 +127,13 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: - self._set_backend_entity(model) + ) -> ModelWrapper: + self._set_backend_entity(model_wrapper.backend) + model, graph = model_wrapper.unwrap() model_transformer = ModelTransformerFactory.create(model) @@ -207,7 +206,7 @@ def apply( transformation_layout.register(self._backend_entity.create_bias_correction_command(node, bias_value, graph)) transformed_model = model_transformer.transform(transformation_layout) - return transformed_model + return ModelWrapper(transformed_model, attributes=model_wrapper.attributes) @staticmethod def _get_bias_shift_magnitude(current_bias_value: Tensor, updated_bias_value: Tensor) -> Tensor: @@ -345,8 +344,9 @@ def _get_bias_shift( bias_shift = fns.stack(output_fp) - q_outputs return bias_shift - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + self._set_backend_entity(model_wrapper.backend) + graph = model_wrapper.graph nodes_with_bias = [ node for node in graph.get_all_nodes() if self._backend_entity.is_node_with_bias(node, graph) ] diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py index 7eda61ce64a..a40f0637a74 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_backend.py @@ -109,5 +109,5 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG return input_node_name, output_node_name @staticmethod - def get_activation_channel_axis(node: NNCFNode, pord_id: int, input_shape: Tuple[int]) -> int: + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: return node.metatype.output_channel_axis diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py index c6c538a52b4..cda0b3b30c6 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -98,5 +98,5 @@ def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFG return node.node_name, node.node_name @staticmethod - def get_activation_channel_axis(node: NNCFNode, pord_id: int, input_shape: Tuple[int]) -> int: + def get_activation_channel_axis(node: NNCFNode, port_id: int, input_shape: Tuple[int]) -> int: return node.metatype.output_channel_axis diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 4d35107b7a5..7d2d64a769a 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -31,6 +31,7 @@ from nncf.common.hardware.config import get_hw_config_type from nncf.common.insertion_point_graph import InsertionPointGraph from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.quantization.config_assignment import assign_qconfig_lists_to_modules from nncf.common.quantization.initialization.range import RangeInitCollectorParams from nncf.common.quantization.quantizer_propagation.solver import QuantizerPropagationRule @@ -954,14 +955,17 @@ def _get_quantization_points_overflow_fix( def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: transformation_layout = TransformationLayout() - model_transformer = ModelTransformerFactory.create(model) - quantization_target_points, unified_scale_groups = self._get_quantization_target_points(model, graph) + model_transformer = ModelTransformerFactory.create(model_wrapper.model) + graph = model_wrapper.graph + quantization_target_points, unified_scale_groups = self._get_quantization_target_points( + model_wrapper.model, graph + ) quantization_points_overflow_fix = self._get_quantization_points_overflow_fix( self._overflow_fix, quantization_target_points, graph ) @@ -1052,12 +1056,12 @@ def filter_func(point: StatisticPoint) -> bool: if not transformation_layout.transformations: nncf_logger.info("The model has no operations to apply quantization.") quantized_model = model_transformer.transform(transformation_layout) - return quantized_model + return ModelWrapper(quantized_model, attributes=model_wrapper.attributes) - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: - self._set_backend_entity(model) + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + self._set_backend_entity(model_wrapper.model) self._reset_cache() - quantization_target_points, _ = self._get_quantization_target_points(model, graph) + quantization_target_points, _ = self._get_quantization_target_points(model_wrapper.model, model_wrapper.graph) output = StatisticPointsContainer() for quantization_target_point, qconfig in quantization_target_points.items(): nncf_logger.debug( @@ -1065,7 +1069,7 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin f" with type {quantization_target_point.type} for statistics collection" ) stat_collector = self._get_stat_collector( - graph, quantization_target_point, qconfig, self._batchwise_statistics + model_wrapper.graph, quantization_target_point, qconfig, self._batchwise_statistics ) output.add_statistic_point( StatisticPoint( diff --git a/nncf/quantization/algorithms/pipeline.py b/nncf/quantization/algorithms/pipeline.py index cd615258553..0159955ae0c 100644 --- a/nncf/quantization/algorithms/pipeline.py +++ b/nncf/quantization/algorithms/pipeline.py @@ -11,10 +11,9 @@ from typing import Dict, List, Optional, TypeVar, Union -from nncf.common.factory import NNCFGraphFactory from nncf.common.factory import StatisticsAggregatorFactory -from nncf.common.graph.graph import NNCFGraph from nncf.common.logging import nncf_logger +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend @@ -27,26 +26,25 @@ def collect_statistics( containers: Union[StatisticPointsContainer, List[StatisticPointsContainer]], - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, dataset: Dataset, ) -> StatisticPointsContainer: """ Utility method for collecting statistics by model. :param statistic_points: Statistic points that need to be collected. - :param model: A model. - :param graph: A graph assosiated with a model. + :param model_wrapper: A wrapper object containing the model + :param graph: A graph associated with a model. :param dataset: A dataset. :return: Collected statistics. """ if not isinstance(containers, list): containers = [containers] - statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) + statistics_aggregator = StatisticsAggregatorFactory.create(model_wrapper.model, dataset) for container in containers: statistics_aggregator.register_statistic_points(container) - statistics_aggregator.collect_statistics(model, graph) + statistics_aggregator.collect_statistics(model_wrapper.model, model_wrapper.graph) return statistics_aggregator.statistic_points @@ -96,8 +94,7 @@ def run_step( self, step_index: int, step_statistics: StatisticPointsContainer, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, ) -> TModel: """ Executes a provided pipeline step on the provided model. @@ -105,36 +102,31 @@ def run_step( :param step_index: Zero-based index of the pipeline step that should be executed :param step_statistics: Statistics required to execute a pipeline step. :param model: A model to which a pipeline step will be applied. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :return: The updated model after executing the pipeline step. """ - current_model = model - current_graph = graph + current_model = model_wrapper - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) + pipeline_steps = self._remove_unsupported_algorithms(model_wrapper.backend) pipeline_step = pipeline_steps[step_index] - for algorithm in pipeline_step[:-1]: - current_model = algorithm.apply(current_model, current_graph, step_statistics) - current_graph = NNCFGraphFactory.create(current_model) - current_model = pipeline_step[-1].apply(current_model, current_graph, step_statistics) - + for algorithm in pipeline_step: + current_model = algorithm.apply(current_model, statistic_points=step_statistics) return current_model def run_from_step( self, - model: TModel, + model_wrapper: ModelWrapper, dataset: Dataset, - graph: Optional[NNCFGraph] = None, start_step_index: int = 0, step_index_to_statistics: Optional[Dict[int, StatisticPointsContainer]] = None, - ) -> TModel: + ) -> ModelWrapper: """ Executes the pipeline from the specified pipeline step to the end. :param model: This is the model after the (start_step_index - 1)-th pipeline step, or the initial model if start_step_index is 0. :param dataset: A dataset that holds the data items for pipeline steps. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :param start_step_index: Zero-based pipeline step index from which the pipeline should be executed. :param step_index_to_statistics: A mapping from pipeline step index to statistics @@ -142,47 +134,38 @@ def run_from_step( :return: The updated model after executing the pipeline from the specified pipeline step to the end. """ - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) + pipeline_steps = self._remove_unsupported_algorithms(model_wrapper.backend) if step_index_to_statistics is None: step_index_to_statistics = {} # The `step_model` and `step_graph` entities are required to execute `step_index`-th pipeline step - step_model = model - step_graph = graph + step_model_wrapper = model_wrapper for step_index in range(start_step_index, len(pipeline_steps)): - # Create graph required to run current pipeline step - if step_graph is None: - step_graph = NNCFGraphFactory.create(step_model) - # Collect statistics required to run current pipeline step step_statistics = step_index_to_statistics.get(step_index) if step_statistics is None: - statistic_points = self.get_statistic_points_for_step(step_index, step_model, step_graph) - step_statistics = collect_statistics(statistic_points, step_model, step_graph, dataset) + statistic_points = self.get_statistic_points_for_step(step_index, step_model_wrapper) + step_statistics = collect_statistics(statistic_points, step_model_wrapper, dataset) # Run current pipeline step - step_model = self.run_step(step_index, step_statistics, step_model, step_graph) - - step_graph = None # We should rebuild the graph for the next pipeline step + step_model_wrapper = self.run_step(step_index, step_statistics, step_model_wrapper) - return step_model + return step_model_wrapper - def get_statistic_points_for_step( - self, step_index: int, model: TModel, graph: NNCFGraph - ) -> StatisticPointsContainer: + def get_statistic_points_for_step(self, step_index: int, model_wrapper: ModelWrapper) -> StatisticPointsContainer: """ Returns statistics that should be collected to execute `step_index`-th pipeline step. :param step_index: Zero-based index of the pipeline step. :param model: A model. - :param graph: A graph assosiated with a model. + :param graph: A graph associated with a model. :return: Statistics that should be collected to execute `step_index`-th pipeline step. """ container = StatisticPointsContainer() - pipeline_steps = self._remove_unsupported_algorithms(get_backend(model)) + pipeline_steps = self._remove_unsupported_algorithms(get_backend(model_wrapper.model)) pipeline_step = pipeline_steps[step_index] for algorithm in pipeline_step: - for statistic_points in algorithm.get_statistic_points(model, graph).values(): + for statistic_points in algorithm.get_statistic_points(model_wrapper).values(): for statistic_point in statistic_points: container.add_statistic_point(statistic_point) diff --git a/nncf/quantization/algorithms/post_training/algorithm.py b/nncf/quantization/algorithms/post_training/algorithm.py index 862dc5d5037..9466688a0d7 100644 --- a/nncf/quantization/algorithms/post_training/algorithm.py +++ b/nncf/quantization/algorithms/post_training/algorithm.py @@ -10,10 +10,11 @@ # limitations under the License. import itertools -from typing import Callable, List, Optional, TypeVar +from typing import List, Optional, TypeVar from nncf import Dataset from nncf.common.graph.graph import NNCFGraph +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -26,7 +27,6 @@ from nncf.scopes import IgnoredScope TModel = TypeVar("TModel") -TPass = Callable[[TModel], TModel] class PostTrainingQuantization(Algorithm): @@ -55,7 +55,7 @@ def __init__( - `performance`: Symmetric quantization of weights and activations. - `mixed`: Symmetric quantization of weights and asymmetric quantization of activations. Default value is None. In this case, `mixed` preset is used for `transformer` - model type otherwise `performace`. + model type otherwise `performance`. :param target_device: A target device the specificity of which will be taken into account while compressing in order to obtain the best performance for this type of device. @@ -94,11 +94,11 @@ def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPoin def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: if dataset is None and len(self._pipeline.pipeline_steps) > 1: raise ValueError( "A dataset is required for the post-training quantization " @@ -109,4 +109,4 @@ def apply( if statistic_points: step_index_to_statistics = {0: statistic_points} - return self._pipeline.run_from_step(model, dataset, graph, 0, step_index_to_statistics) + return self._pipeline.run_from_step(model_wrapper, dataset, 0, step_index_to_statistics) diff --git a/nncf/quantization/algorithms/smooth_quant/algorithm.py b/nncf/quantization/algorithms/smooth_quant/algorithm.py index 83aefc6709a..a23ebf3b8ef 100644 --- a/nncf/quantization/algorithms/smooth_quant/algorithm.py +++ b/nncf/quantization/algorithms/smooth_quant/algorithm.py @@ -24,6 +24,7 @@ from nncf.common.graph.utils import get_reduction_axes from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.utils.backend import BackendType @@ -98,11 +99,13 @@ def _set_backend_entity(self, model: TModel) -> None: def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: + model = model_wrapper.model + graph = model_wrapper.graph self._set_backend_entity(model) alpha_map = self._get_alpha_map() @@ -176,7 +179,7 @@ def apply( transformation_layout.register(scale_insertion_command) transformed_model = model_transformer.transform(transformation_layout) - return transformed_model + return ModelWrapper(transformed_model, attributes=model_wrapper.attributes) @staticmethod def _calculate_scale_and_ratio( @@ -245,7 +248,10 @@ def _get_statistics_for_node( statistics_for_node.append(statistic) return statistics_for_node - def get_statistic_points(self, model: TModel, graph: NNCFGraph) -> StatisticPointsContainer: + def get_statistic_points(self, model_wrapper: ModelWrapper) -> StatisticPointsContainer: + model = model_wrapper.model + graph = model_wrapper.graph + statistic_container = StatisticPointsContainer() self._set_backend_entity(model) diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index c5a4e2d221c..6f016326fa8 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -24,6 +24,7 @@ from nncf.common.graph.transformations.commands import TargetType from nncf.common.logging import nncf_logger from nncf.common.logging.track_progress import track +from nncf.common.model import ModelWrapper from nncf.common.scopes import should_consider_scope from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer @@ -480,11 +481,12 @@ def _get_ignored_scope_weight_statistics(self, model: TModel, graph: NNCFGraph) def apply( self, - model: TModel, - graph: NNCFGraph, + model_wrapper: ModelWrapper, + *, statistic_points: Optional[StatisticPointsContainer] = None, dataset: Optional[Dataset] = None, - ) -> TModel: + ) -> ModelWrapper: + model, graph = model_wrapper.unwrap() self.set_backend_entity(model) nodes_to_compress = self.get_nodes_to_compress(graph) @@ -667,7 +669,7 @@ def apply( }, algo_name="weight_compression", ) - return transformed_model + return ModelWrapper(transformed_model, attributes=model_wrapper.attributes) def _get_activation_node_and_port(self, node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[NNCFNode, int]: """ diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index eb520bfcd1b..b039919aa5c 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -67,7 +67,7 @@ def warning_model_no_batchwise_support( :param graph: Model's NNCFGraph. :param advanced_quantization_parameters: AdvancedQuantizationParameters. :param model_type: Model type algorithm option. - :param no_batchwise_support_metatypes: Meatypes having no batchwise statistics support. + :param no_batchwise_support_metatypes: Metatypes having no batchwise statistics support. """ if is_model_no_batchwise_support( graph, advanced_quantization_parameters, model_type, no_batchwise_support_metatypes @@ -87,7 +87,7 @@ def is_model_no_batchwise_support( :param graph: Model's NNCFGraph. :param advanced_quantization_parameters: AdvancedQuantizationParameters. :param model_type: Model type algorithm option. - :param no_batchwise_support_metatypes: Meatypes having no batchwise statistics support. + :param no_batchwise_support_metatypes: Metatypes having no batchwise statistics support. """ return ( advanced_quantization_parameters diff --git a/nncf/torch/quantization/quantize_model.py b/nncf/torch/quantization/quantize_model.py index 23cb451f5fe..2b260cf3041 100644 --- a/nncf/torch/quantization/quantize_model.py +++ b/nncf/torch/quantization/quantize_model.py @@ -10,12 +10,12 @@ # limitations under the License. from copy import deepcopy -from typing import Optional +from typing import Optional, cast import torch import nncf -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.data import Dataset from nncf.parameters import BackupMode @@ -32,6 +32,7 @@ from nncf.scopes import IgnoredScope from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS from nncf.torch.model_creation import wrap_model +from nncf.torch.nncf_network import NNCFNetwork DEFAULT_RANGE_TYPE = "mean_min_max" @@ -72,12 +73,15 @@ def quantize_impl( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) - graph = nncf_network.nncf.get_graph() - warning_model_no_batchwise_support(graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS) - quantized_model = quantization_algorithm.apply(nncf_network, graph, dataset=calibration_dataset) + model_wrapper = ModelWrapper(nncf_network) + warning_model_no_batchwise_support( + model_wrapper.graph, advanced_parameters, model_type, OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS + ) - quantized_model.nncf.disable_dynamic_graph_building() + quantized_model_wrapper = quantization_algorithm.apply(model_wrapper, dataset=calibration_dataset) + quantized_model = cast(NNCFNetwork, quantized_model_wrapper.model) + quantized_model.nncf.disable_dynamic_graph_building() return quantized_model @@ -117,5 +121,5 @@ def compress_weights_impl( backup_mode, advanced_parameters, ) - graph = NNCFGraphFactory.create(model) - return compression_algorithm.apply(model, graph, dataset=dataset) + model_wrapper = ModelWrapper(model) + return compression_algorithm.apply(model_wrapper, dataset=dataset).model diff --git a/tests/cross_fw/test_templates/test_bias_correction.py b/tests/cross_fw/test_templates/test_bias_correction.py index 81b638eb900..bfdbc2a13c9 100644 --- a/tests/cross_fw/test_templates/test_bias_correction.py +++ b/tests/cross_fw/test_templates/test_bias_correction.py @@ -14,7 +14,7 @@ import pytest -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.data import Dataset from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix @@ -81,7 +81,7 @@ def backend_specific_model(model: TModel, tmp_dir: str) -> TModel: @staticmethod @abstractmethod - def check_bias(model: TModel, ref_biases: Dict) -> None: + def check_bias(model_wrapper: ModelWrapper, ref_biases: Dict) -> None: """ Checks biases values. """ @@ -121,9 +121,8 @@ def quantized_test_model(self, tmpdir) -> TModel: dataset = Dataset(self.get_dataset(model_cls.INPUT_SIZE), self.get_transform_fn()) quantization_algorithm = self.get_quantization_algorithm(disable_bias_correction=True) - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) - modified_model = self.remove_fq_from_inputs(quantized_model) + quantized_model = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset) + modified_model = self.remove_fq_from_inputs(quantized_model.model) return modified_model @pytest.mark.parametrize( @@ -150,17 +149,17 @@ def test_update_bias(self, model_cls, ref_biases, tmpdir): dataset = Dataset(self.get_dataset(model_cls.INPUT_SIZE), self.get_transform_fn()) quantization_algorithm = self.get_quantization_algorithm() - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) + quantized_model_wrapper = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset) mapped_ref_biases = self.map_references(ref_biases, model_cls) - self.check_bias(quantized_model, mapped_ref_biases) + self.check_bias(quantized_model_wrapper, mapped_ref_biases) def test__get_subgraph_data_for_node(self, quantized_test_model, layer_name, ref_data): - nncf_graph = NNCFGraphFactory.create(quantized_test_model) + model_wrapper = ModelWrapper(quantized_test_model) + nncf_graph = model_wrapper.graph bc_algo = self.get_bias_correction_algorithm() - bc_algo._set_backend_entity(quantized_test_model) + bc_algo._set_backend_entity(model_wrapper.backend) node = nncf_graph.get_node_by_name(layer_name) bc_algo._collected_stat_inputs_map.update(ref_data["collected_inputs"]) @@ -171,10 +170,9 @@ def test__get_subgraph_data_for_node(self, quantized_test_model, layer_name, ref def test_verify_collected_stat_inputs_map(self, model_cls, ref_stat_inputs_map, tmpdir): model = self.backend_specific_model(model_cls(), tmpdir) - graph = NNCFGraphFactory.create(model) bc_algo = self.get_bias_correction_algorithm() - bc_algo.get_statistic_points(model, graph) + bc_algo.get_statistic_points(ModelWrapper(model)) collected_stat_inputs_map = getattr(bc_algo, "_collected_stat_inputs_map") assert collected_stat_inputs_map == ref_stat_inputs_map diff --git a/tests/cross_fw/test_templates/test_channel_alignment.py b/tests/cross_fw/test_templates/test_channel_alignment.py index 7995f91961c..ab9d330ec21 100644 --- a/tests/cross_fw/test_templates/test_channel_alignment.py +++ b/tests/cross_fw/test_templates/test_channel_alignment.py @@ -22,6 +22,7 @@ from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationType +from nncf.common.model import ModelWrapper from nncf.common.tensor_statistics.statistic_point import StatisticPoint from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer from nncf.common.tensor_statistics.statistics import MinMaxTensorStatistic @@ -415,7 +416,9 @@ def dims_iter(*args, **kwargs): ref_bias_in_after_scale_align, ) ) - algorithm.apply(None, nncf_graph, statistic_points) + + mocker.patch("nncf.common.model.get_backend", return_value=None) + algorithm.apply(ModelWrapper(None, graph=nncf_graph), statistic_points=statistic_points) if empty_statistics or one_dim_mm: assert algorithm._align_means.call_count == 0 @@ -511,7 +514,8 @@ class MockBackend(backend_cls): MockBackend.get_statistic_collector = mocker.MagicMock(return_value=ref_stat_collector) algorithm._backend_entity = MockBackend - statistic_container = algorithm.get_statistic_points(None, nncf_graph) + mocker.patch("nncf.common.model.get_backend", return_value=None) + statistic_container = algorithm.get_statistic_points(ModelWrapper(None, graph=nncf_graph)) backend_cls = self.get_backend_cls() target_node_name = "/Add_1_0" if num_biases else "/Conv_1_0" diff --git a/tests/cross_fw/test_templates/test_fast_bias_correction.py b/tests/cross_fw/test_templates/test_fast_bias_correction.py index 899be7d9a1a..873226f9a54 100644 --- a/tests/cross_fw/test_templates/test_fast_bias_correction.py +++ b/tests/cross_fw/test_templates/test_fast_bias_correction.py @@ -14,7 +14,7 @@ import pytest -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import OverflowFix from nncf.quantization.algorithms.fast_bias_correction.algorithm import FastBiasCorrection @@ -115,7 +115,6 @@ def test_update_bias(self, model_cls, ref_bias, tmpdir): dataset = get_static_dataset(model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type) quantization_algorithm = self.get_quantization_algorithm() - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) + quantized_model_wrapper = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset) - self.check_bias(quantized_model, ref_bias) + self.check_bias(quantized_model_wrapper, ref_bias) diff --git a/tests/cross_fw/test_templates/test_ptq_params.py b/tests/cross_fw/test_templates/test_ptq_params.py index eacf57652e7..6d2039b6d7f 100644 --- a/tests/cross_fw/test_templates/test_ptq_params.py +++ b/tests/cross_fw/test_templates/test_ptq_params.py @@ -21,6 +21,7 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype from nncf.common.graph.operator_metatypes import OutputNoopMetatype from nncf.common.graph.transformations.commands import TargetType +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig @@ -203,7 +204,7 @@ def test_range_estimator_per_tensor(self, test_params, range_estimator_params): assert min_max_algo._range_estimator_params[QuantizerGroup.ACTIVATIONS] == range_estimator_params params = test_params["test_range_estimator_per_tensor"] - stat_points = min_max_algo.get_statistic_points(params["model"], params["nncf_graph"]) + stat_points = min_max_algo.get_statistic_points(ModelWrapper(params["model"], graph=params["nncf_graph"])) assert len(stat_points) == params["stat_points_num"] for _, stat_point in stat_points.items(): @@ -374,7 +375,10 @@ def test_unified_scales_command_creation(self, mocker): Tensor(self.get_backend_tensor(idx - 1)), Tensor(self.get_backend_tensor(idx + 2)) ) stats.add_statistic_point(StatisticPoint(tp, tc, algo._algorithm_key)) - algo.apply(model, model.nncf_graph, stats) + + mocker.patch("nncf.common.model.get_backend", return_value=None) + model_wrapper = ModelWrapper(model, graph=model.nncf_graph) + algo.apply(model_wrapper, statistic_points=stats) mock_transformer.transform.assert_called_once() layout = mock_transformer.transform.call_args.args[0] self.check_unified_scale_layout(layout, unified_scales_group) @@ -423,7 +427,5 @@ def test_empty_statistics(self, mode, mocker): "nncf.quantization.algorithms.min_max.algorithm.MinMaxQuantization._get_quantization_points_overflow_fix", return_value=mocker.MagicMock(), ) - with pytest.raises(nncf.InternalError) as exc_info: - algo.apply(None, None, stat_points) - - assert str(exc_info.value) == "Statistics were not collected for the node A" + with pytest.raises(nncf.InternalError, match="Statistics were not collected for the node A"): + algo.apply(mocker.MagicMock(), statistic_points=stat_points) diff --git a/tests/cross_fw/test_templates/test_smooth_quant.py b/tests/cross_fw/test_templates/test_smooth_quant.py index f4ea260c14e..83336a57701 100644 --- a/tests/cross_fw/test_templates/test_smooth_quant.py +++ b/tests/cross_fw/test_templates/test_smooth_quant.py @@ -19,6 +19,7 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.factory import StatisticsAggregatorFactory from nncf.common.graph.graph import NNCFNode +from nncf.common.model import ModelWrapper from nncf.experimental.common.tensor_statistics.collectors import AbsMaxReducer from nncf.experimental.common.tensor_statistics.collectors import MaxAggregator from nncf.parameters import ModelType @@ -165,8 +166,7 @@ def test_smooth_quant_algo(self, model_cls, reference_values, tmpdir): dataset = get_static_dataset(model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type) quantization_algorithm = self.get_quantization_algorithm(self.get_ignored_scope(model_cls)) - graph = NNCFGraphFactory.create(model) - quantized_model = quantization_algorithm.apply(model, graph, dataset=dataset) + quantized_model = quantization_algorithm.apply(ModelWrapper(model), dataset=dataset).model self.check_scales(quantized_model, reference_values, model_cls) @@ -244,9 +244,10 @@ def test_empty_stats(self, mocker, tmpdir): model = self.backend_specific_model(model_cls(), tmpdir) dataset = get_static_dataset(model_cls.INPUT_SIZE, self.get_transform_fn(), self.fn_to_type) - graph = NNCFGraphFactory.create(model) + model_wrapper = ModelWrapper(model) + graph = model_wrapper.graph algo = SmoothQuant(subset_size=1, inplace_statistics=False) - algo_statistic_points = algo.get_statistic_points(model, graph) + algo_statistic_points = algo.get_statistic_points(model_wrapper) statistics_aggregator = StatisticsAggregatorFactory.create(model, dataset) statistics_aggregator.register_statistic_points(algo_statistic_points) statistics_aggregator.collect_statistics(model, graph) @@ -260,7 +261,8 @@ def test_empty_stats(self, mocker, tmpdir): mocked_transformer = mocker.MagicMock() mocker.patch("nncf.common.factory.ModelTransformerFactory.create", return_value=mocked_transformer) - algo.apply(model, graph, algo_statistic_points) + mocker.patch("nncf.common.model.get_backend", return_value=None) + algo.apply(model_wrapper, statistic_points=algo_statistic_points) mocked_transformer.transform.assert_called_once() arg = mocked_transformer.transform.call_args.args[0] diff --git a/tests/onnx/quantization/common.py b/tests/onnx/quantization/common.py index 18f36b29ee4..48bf787f2d1 100644 --- a/tests/onnx/quantization/common.py +++ b/tests/onnx/quantization/common.py @@ -16,6 +16,7 @@ import onnx from nncf import Dataset +from nncf.common.model import ModelWrapper from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.onnx_helper import get_edge_dtype @@ -108,7 +109,6 @@ def min_max_quantize_model( ) -> onnx.ModelProto: if convert_model_opset: original_model = convert_opset_version(original_model) - graph = GraphConverter.create_nncf_graph(original_model) dataset = get_random_dataset_for_test(original_model, dataset_has_batch_size) quantization_params = {} if quantization_params is None else quantization_params @@ -123,8 +123,8 @@ def min_max_quantize_model( post_training_quantization = PostTrainingQuantization(subset_size=1, **quantization_params) - quantized_model = post_training_quantization.apply(original_model, graph, dataset=dataset) - return quantized_model + quantized_model = post_training_quantization.apply(ModelWrapper(original_model), dataset=dataset) + return quantized_model.model def ptq_quantize_model( diff --git a/tests/onnx/quantization/test_bias_correction.py b/tests/onnx/quantization/test_bias_correction.py index 8ffe20c6102..5c4187d929e 100644 --- a/tests/onnx/quantization/test_bias_correction.py +++ b/tests/onnx/quantization/test_bias_correction.py @@ -16,7 +16,7 @@ import pytest import torch -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.onnx.graph.model_utils import remove_fq_from_inputs from nncf.onnx.graph.nncf_graph_builder import GraphConverter from nncf.onnx.graph.node_utils import get_bias_value @@ -75,8 +75,8 @@ def compare_nncf_graphs(model: onnx.ModelProto, ref_path: str) -> None: return compare_nncf_graph(model, ref_path) @staticmethod - def check_bias(model: onnx.ModelProto, ref_biases: Dict) -> None: - nncf_graph = NNCFGraphFactory.create(model) + def check_bias(model_wrapper: ModelWrapper, ref_biases: Dict) -> None: + model, nncf_graph = model_wrapper.unwrap() for ref_name, ref_value in ref_biases.items(): node = nncf_graph.get_node_by_name(ref_name) ref_value = np.array(ref_value) diff --git a/tests/onnx/quantization/test_fast_bias_correction.py b/tests/onnx/quantization/test_fast_bias_correction.py index 0ed364ecb80..9dfc92c170f 100644 --- a/tests/onnx/quantization/test_fast_bias_correction.py +++ b/tests/onnx/quantization/test_fast_bias_correction.py @@ -15,7 +15,7 @@ import onnx import torch -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.onnx.graph.node_utils import get_bias_value from nncf.onnx.graph.node_utils import is_node_with_bias from nncf.quantization.algorithms.fast_bias_correction.onnx_backend import ONNXFastBiasCorrectionAlgoBackend @@ -58,9 +58,9 @@ def transform_fn(data_item): return transform_fn @staticmethod - def check_bias(model: onnx.ModelProto, ref_bias: list): + def check_bias(model_wrapper: ModelWrapper, ref_bias: list): ref_bias = np.array(ref_bias) - nncf_graph = NNCFGraphFactory.create(model) + model, nncf_graph = model_wrapper.unwrap() for node in nncf_graph.get_all_nodes(): if not is_node_with_bias(node): continue diff --git a/tests/openvino/native/quantization/test_fq_params_calculation.py b/tests/openvino/native/quantization/test_fq_params_calculation.py index 5751a34f39b..67a024a1fa4 100644 --- a/tests/openvino/native/quantization/test_fq_params_calculation.py +++ b/tests/openvino/native/quantization/test_fq_params_calculation.py @@ -15,8 +15,8 @@ import pytest import torch +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset -from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.statistics.aggregator import OVStatisticsAggregator from nncf.parameters import QuantizationMode from nncf.quantization.advanced_parameters import OverflowFix @@ -66,15 +66,15 @@ def get_fq_nodes_stats_algo(model): def quantize_model(ov_model, q_params): dataset = get_dataset_for_test(ov_model) - graph = GraphConverter.create_nncf_graph(ov_model) + model_wrapper = ModelWrapper(ov_model) min_max_algo = MinMaxQuantization(subset_size=1, **q_params) statistics_aggregator = OVStatisticsAggregator(dataset) - statistic_points = min_max_algo.get_statistic_points(ov_model, graph) + statistic_points = min_max_algo.get_statistic_points(model_wrapper) statistics_aggregator.register_statistic_points(statistic_points) - statistics_aggregator.collect_statistics(ov_model, graph) - quantized_model = min_max_algo.apply(ov_model, graph, statistics_aggregator.statistic_points) - return quantized_model + statistics_aggregator.collect_statistics(model_wrapper.model, model_wrapper.graph) + quantized_model = min_max_algo.apply(model_wrapper, statistic_points=statistics_aggregator.statistic_points) + return quantized_model.model @pytest.fixture(params=[True, False], ids=["inplace", "out_of_place"], name="inplace_statistics") diff --git a/tests/openvino/native/quantization/test_graphs.py b/tests/openvino/native/quantization/test_graphs.py index 7dc3c94c081..9ddf2e0e767 100644 --- a/tests/openvino/native/quantization/test_graphs.py +++ b/tests/openvino/native/quantization/test_graphs.py @@ -18,8 +18,8 @@ import pytest from nncf import Dataset +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset -from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.quantization.quantize_model import quantize_impl from nncf.openvino.statistics.aggregator import OVStatisticsAggregator from nncf.parameters import ModelType @@ -133,14 +133,16 @@ def test_real_models_sq_placement(model_name_params, tmp_path): def smooth_quant_model(ov_model: ov.Model, q_params: Dict, quantize=True): dataset = get_dataset_for_test(ov_model) - graph = GraphConverter.create_nncf_graph(ov_model) + model_wrapper = ModelWrapper(ov_model) smooth_quant_algo = SmoothQuant(subset_size=1) statistics_aggregator = OVStatisticsAggregator(dataset) - statistic_points = smooth_quant_algo.get_statistic_points(ov_model, graph) + statistic_points = smooth_quant_algo.get_statistic_points(model_wrapper) statistics_aggregator.register_statistic_points(statistic_points) - statistics_aggregator.collect_statistics(ov_model, graph) - modified_model = smooth_quant_algo.apply(ov_model, graph, statistics_aggregator.statistic_points) + statistics_aggregator.collect_statistics(model_wrapper.model, model_wrapper.graph) + modified_model = smooth_quant_algo.apply( + model_wrapper, statistic_points=statistics_aggregator.statistic_points + ).model if quantize: modified_model = quantize_model(modified_model, q_params) diff --git a/tests/openvino/native/test_bias_correction.py b/tests/openvino/native/test_bias_correction.py index 711c831facd..b8af94b31fe 100644 --- a/tests/openvino/native/test_bias_correction.py +++ b/tests/openvino/native/test_bias_correction.py @@ -16,7 +16,7 @@ import pytest import torch -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.openvino.graph.model_utils import remove_fq_from_inputs from nncf.openvino.graph.nncf_graph_builder import GraphConverter from nncf.openvino.graph.node_utils import get_bias_value @@ -79,8 +79,8 @@ def compare_nncf_graphs(model: ov.Model, ref_path: str) -> None: return compare_nncf_graphs(model, ref_path) @staticmethod - def check_bias(model: ov.Model, ref_biases: Dict) -> None: - nncf_graph = NNCFGraphFactory.create(model) + def check_bias(model_wrapper: ModelWrapper, ref_biases: Dict) -> None: + model, nncf_graph = model_wrapper.unwrap() for ref_name, ref_value in ref_biases.items(): node = nncf_graph.get_node_by_name(ref_name) ref_value = np.array(ref_value) diff --git a/tests/openvino/native/test_fast_bias_correction.py b/tests/openvino/native/test_fast_bias_correction.py index 6de9523bead..9c51f334d81 100644 --- a/tests/openvino/native/test_fast_bias_correction.py +++ b/tests/openvino/native/test_fast_bias_correction.py @@ -15,7 +15,7 @@ import openvino as ov import torch -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.common.utils.os import is_macos from nncf.openvino.graph.node_utils import get_bias_value from nncf.openvino.graph.node_utils import is_node_with_bias @@ -52,9 +52,9 @@ def transform_fn(data_item): return transform_fn @staticmethod - def check_bias(model: ov.Model, ref_bias: list): + def check_bias(model_wrapper: ModelWrapper, ref_bias: list): ref_bias = np.array(ref_bias) - nncf_graph = NNCFGraphFactory.create(model) + model, nncf_graph = model_wrapper.unwrap() atol = 0.0001 if not is_macos() else 0.01 diff --git a/tests/torch/fx/test_bias_correction.py b/tests/torch/fx/test_bias_correction.py index 06db212fe79..34d5edf0b21 100644 --- a/tests/torch/fx/test_bias_correction.py +++ b/tests/torch/fx/test_bias_correction.py @@ -13,11 +13,10 @@ from typing import Any, Dict, List import numpy as np -import openvino as ov import pytest import torch.fx -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.experimental.torch.fx.model_utils import remove_fq_from_inputs from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter from nncf.experimental.torch.fx.node_utils import get_bias_value @@ -77,8 +76,8 @@ def remove_fq_from_inputs(model: torch.fx.GraphModule) -> torch.fx.GraphModule: return remove_fq_from_inputs(model, graph) @staticmethod - def check_bias(model: ov.Model, ref_biases: Dict) -> None: - nncf_graph = NNCFGraphFactory.create(model) + def check_bias(model_wrapper: ModelWrapper, ref_biases: Dict) -> None: + model, nncf_graph = model_wrapper.unwrap() for ref_name, ref_value in ref_biases.items(): node = nncf_graph.get_node_by_name(ref_name) ref_value = torch.tensor(ref_value) diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py index 0de35aef29e..4ea029ba5cd 100644 --- a/tests/torch/fx/test_compress_weights.py +++ b/tests/torch/fx/test_compress_weights.py @@ -13,6 +13,7 @@ import pytest import torch +from torch.fx import GraphModule import nncf from nncf import BackupMode @@ -52,7 +53,7 @@ def get_model_size(model): def get_compressed_modules_weights( - compressed_model: torch.fx.GraphModule, dtype: torch.dtype, compressed_node_weight_port: Dict[str, int] + compressed_model: GraphModule, dtype: torch.dtype, compressed_node_weight_port: Dict[str, int] ): n_target_modules = 0 n_compressed_weights = 0 diff --git a/tests/torch/fx/test_fast_bias_correction.py b/tests/torch/fx/test_fast_bias_correction.py index 8c94a32cafd..16da7d24003 100644 --- a/tests/torch/fx/test_fast_bias_correction.py +++ b/tests/torch/fx/test_fast_bias_correction.py @@ -15,7 +15,7 @@ import torch import torch.fx -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import FXFastBiasCorrectionAlgoBackend from nncf.torch.model_graph_manager import OPERATORS_WITH_BIAS_METATYPES from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm @@ -49,9 +49,9 @@ def transform_fn(data_item): return transform_fn @staticmethod - def check_bias(model: torch.fx.GraphModule, ref_bias: list): + def check_bias(model_wrapper: ModelWrapper, ref_bias: list): + model, nncf_graph = model_wrapper.unwrap() ref_bias = torch.Tensor(ref_bias) - nncf_graph = NNCFGraphFactory.create(model) for node in nncf_graph.get_all_nodes(): if node.metatype not in OPERATORS_WITH_BIAS_METATYPES: continue @@ -77,7 +77,3 @@ def backend_specific_model(model: bool, tmp_dir: str): @staticmethod def fn_to_type(tensor): return torch.Tensor(tensor).cuda() - - @staticmethod - def check_bias(model: torch.fx.GraphModule, ref_bias: list): - TestTorchFXFBCAlgorithm.check_bias(model, ref_bias) diff --git a/tests/torch/ptq/test_calculation_quantizer_params.py b/tests/torch/ptq/test_calculation_quantizer_params.py index 556b5f9e387..06cbfd32120 100644 --- a/tests/torch/ptq/test_calculation_quantizer_params.py +++ b/tests/torch/ptq/test_calculation_quantizer_params.py @@ -20,6 +20,7 @@ from nncf import Dataset from nncf.common.graph.transformations.commands import TargetType +from nncf.common.model import ModelWrapper from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode from nncf.common.quantization.structs import QuantizerConfig @@ -314,16 +315,16 @@ def test_quantizer_parameters_export(tmp_path: Path, _seed): statistics_aggregator = PTStatisticsAggregator(dataset) nncf_network = wrap_model(model, torch.ones([1, 3, 32, 32]), True) - statistic_points = min_max_algo.get_statistic_points(nncf_network, nncf_network.nncf.get_graph()) + statistic_points = min_max_algo.get_statistic_points(ModelWrapper(nncf_network)) statistics_aggregator.register_statistic_points(statistic_points) statistics_aggregator.collect_statistics(model, nncf_network.nncf.get_graph()) torch_quantized_model = min_max_algo.apply( - nncf_network, nncf_network.nncf.get_graph(), statistics_aggregator.statistic_points + ModelWrapper(nncf_network), statistic_points=statistics_aggregator.statistic_points ) path = str(tmp_path / "torch_ptq_model.onnx") torch.onnx.export( - torch_quantized_model, + torch_quantized_model.model, input_data, path, export_params=True, diff --git a/tests/torch/ptq/test_fast_bias_correction.py b/tests/torch/ptq/test_fast_bias_correction.py index 5b3c6c0ce3c..04c0af32a7f 100644 --- a/tests/torch/ptq/test_fast_bias_correction.py +++ b/tests/torch/ptq/test_fast_bias_correction.py @@ -14,11 +14,10 @@ import pytest import torch -from nncf.common.factory import NNCFGraphFactory +from nncf.common.model import ModelWrapper from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend from nncf.torch.model_graph_manager import get_fused_bias_value from nncf.torch.model_graph_manager import is_node_with_fused_bias -from nncf.torch.nncf_network import NNCFNetwork from tests.cross_fw.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm from tests.torch.ptq.helpers import get_nncf_network @@ -49,9 +48,9 @@ def transform_fn(data_item): return transform_fn @staticmethod - def check_bias(model: NNCFNetwork, ref_bias: list): + def check_bias(model_wrapper: ModelWrapper, ref_bias: list): ref_bias = torch.Tensor(ref_bias) - nncf_graph = NNCFGraphFactory.create(model) + model, nncf_graph = model_wrapper.unwrap() for node in nncf_graph.get_all_nodes(): if not is_node_with_fused_bias(node, nncf_graph): continue @@ -78,9 +77,9 @@ def fn_to_type(tensor): return torch.Tensor(tensor).cuda() @staticmethod - def check_bias(model: NNCFNetwork, ref_bias: list): + def check_bias(model_wrapper: ModelWrapper, ref_bias: list): + model, nncf_graph = model_wrapper.unwrap() ref_bias = torch.Tensor(ref_bias) - nncf_graph = NNCFGraphFactory.create(model) for node in nncf_graph.get_all_nodes(): if not is_node_with_fused_bias(node, nncf_graph): continue diff --git a/tests/torch/ptq/test_fq_params_calculation.py b/tests/torch/ptq/test_fq_params_calculation.py index 6d71760cd33..9c2bbe861b8 100644 --- a/tests/torch/ptq/test_fq_params_calculation.py +++ b/tests/torch/ptq/test_fq_params_calculation.py @@ -16,6 +16,7 @@ import torch import nncf +from nncf.common.model import ModelWrapper from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters from nncf.quantization.advanced_parameters import OverflowFix @@ -58,8 +59,8 @@ def transform_fn(sample): original_model.eval() nncf_network = wrap_model(original_model, torch.ones([1, 1, 10, 10]), trace_parameters=True) - quantized_model = post_training_quantization.apply(nncf_network, nncf_network.nncf.get_graph(), dataset=dataset) - return quantized_model + quantized_model = post_training_quantization.apply(ModelWrapper(nncf_network), dataset=dataset) + return quantized_model.model def get_fq_nodes(model: NNCFNetwork) -> Dict[Scope, torch.nn.Module]: diff --git a/tests/torch/ptq/test_graphs.py b/tests/torch/ptq/test_graphs.py index eba35163c7c..be902f9daa5 100644 --- a/tests/torch/ptq/test_graphs.py +++ b/tests/torch/ptq/test_graphs.py @@ -16,6 +16,7 @@ import torch from nncf import Dataset +from nncf.common.model import ModelWrapper from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters @@ -121,8 +122,7 @@ def test_min_max_classification_quantized_graphs(desc: ModelDesc, quantization_p quantization_algorithm = PostTrainingQuantization(**quantization_parameters) quantized_model = quantization_algorithm.apply( - nncf_network, - nncf_network.nncf.get_graph(), + ModelWrapper(nncf_network), dataset=Dataset([example_input]), - ) + ).model check_graph(quantized_model.nncf.get_graph(), desc.dot_filename(), graph_dir)