From 2e987f53ef954c49d8f4a2851f0a51e6b8e984a7 Mon Sep 17 00:00:00 2001 From: Alexander Suslov Date: Mon, 17 Jul 2023 12:50:40 +0400 Subject: [PATCH 01/10] speed-up calculation of FQ important scores --- nncf/common/factory.py | 12 +- nncf/common/utils/backend.py | 122 +++++++++++++----- nncf/common/utils/os.py | 23 ++++ nncf/common/utils/timer.py | 7 +- nncf/data/dataset.py | 98 ++++++++++++-- nncf/openvino/engine.py | 56 ++++++-- nncf/openvino/quantization/quantize_model.py | 1 + nncf/quantization/advanced_parameters.py | 4 + .../algorithms/accuracy_control/algorithm.py | 70 ++++++++-- .../algorithms/accuracy_control/backend.py | 36 +++++- .../algorithms/accuracy_control/evaluator.py | 73 +++++++++-- .../accuracy_control/openvino_backend.py | 37 ++++++ .../algorithms/accuracy_control/ranker.py | 75 +++++++++-- setup.py | 1 + 14 files changed, 525 insertions(+), 90 deletions(-) diff --git a/nncf/common/factory.py b/nncf/common/factory.py index 6dc355aec90..8ccd054b0c9 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -16,7 +16,9 @@ from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.command_creation import CommandCreator from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_available_backends from nncf.common.utils.backend import get_backend +from nncf.common.utils.backend import is_openvino_compiled_model TModel = TypeVar("TModel") @@ -77,9 +79,15 @@ def create(model: TModel) -> Engine: """ Factory method to create backend-specific Engine instance based on the input model. - :param model: backend-specific model instance - :return: backend-specific Engine instance + :param model: backend-specific model instance. + :return: backend-specific Engine instance. """ + available_backends = get_available_backends() + if BackendType.OPENVINO in available_backends and is_openvino_compiled_model(model): + from nncf.openvino.engine import OVCompiledModelEngine + + return OVCompiledModelEngine(model) + model_backend = get_backend(model) if model_backend == BackendType.ONNX: from nncf.onnx.engine import ONNXEngine diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index fcd315dee22..f4d3e47cd22 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -8,9 +8,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib from copy import deepcopy from enum import Enum -from typing import TypeVar +from typing import List, TypeVar TModel = TypeVar("TModel") @@ -22,58 +23,115 @@ class BackendType(Enum): OPENVINO = "OpenVINO" -def get_backend(model) -> BackendType: +def get_available_backends() -> List[BackendType]: """ - Returns the NNCF backend name string inferred from the type of the model object passed into this function. + Returns a list of available backends. - :param model: The framework-specific object representing the trainable model. - :return: A BackendType representing the correct NNCF backend to be used when working with the framework. + :return: A list of avauilable backends. 
+ """ + frameworks = [ + ("torch", BackendType.TORCH), + ("tensorflow", BackendType.TENSORFLOW), + ("onnx", BackendType.ONNX), + ("openvino.runtime", BackendType.OPENVINO), + ] + + available_backends = [] + for module_name, backend in frameworks: + try: + importlib.import_module(module_name) + available_backends.append(backend) + except ImportError: + pass + + return available_backends + + +def is_torch_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of torch.nn.Module, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of torch.nn.Module, otherwise False. + """ + import torch + + return isinstance(model, torch.nn.Module) + + +def is_tensorflow_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of tensorflow.Module, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of tensorflow.Module, otherwise False. + """ + import tensorflow + + return isinstance(model, tensorflow.Module) + + +def is_onnx_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of onnx.ModelProto, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of onnx.ModelProto, otherwise False. + """ + import onnx + + return isinstance(model, onnx.ModelProto) + + +def is_openvino_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of openvino.runtime.Model, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of openvino.runtime.Model, otherwise False. """ - available_frameworks = [] - try: - import torch + import openvino.runtime as ov - available_frameworks.append("PyTorch") - except ImportError: - torch = None + return isinstance(model, ov.Model) - try: - import tensorflow - available_frameworks.append("Tensorflow") - except ImportError: - tensorflow = None +def is_openvino_compiled_model(model: TModel): + """ + Returns True if the model is an instance of openvino.runtime.CompiledModel, otherwise False. - try: - import onnx + :param model: A target model. + :return: True if the model is an instance of openvino.runtime.CompiledModel, otherwise False. + """ + import openvino.runtime as ov - available_frameworks.append("ONNX") - except ImportError: - onnx = None + return isinstance(model, ov.CompiledModel) - try: - import openvino.runtime as ov - available_frameworks.append("OpenVINO") - except ImportError: - ov = None +def get_backend(model: TModel) -> BackendType: + """ + Returns the NNCF backend name string inferred from the type of the model object passed into this function. + + :param model: The framework-specific object representing the trainable model. + :return: A BackendType representing the correct NNCF backend to be used when working with the framework. 
+ """ + available_backends = get_available_backends() - if torch is not None and isinstance(model, torch.nn.Module): + if BackendType.TORCH in available_backends and is_torch_model(model): return BackendType.TORCH - if tensorflow is not None and isinstance(model, tensorflow.Module): + if BackendType.TENSORFLOW in available_backends and is_tensorflow_model(model): return BackendType.TENSORFLOW - if onnx is not None and isinstance(model, onnx.ModelProto): + if BackendType.ONNX in available_backends and is_onnx_model(model): return BackendType.ONNX - if ov is not None and isinstance(model, ov.Model): + if BackendType.OPENVINO in available_backends and is_openvino_model(model): return BackendType.OPENVINO raise RuntimeError( "Could not infer the backend framework from the model type because " "the framework is not available or the model type is unsupported. " - "The available frameworks found: {}.".format(", ".join(available_frameworks)) + "The available frameworks found: {}.".format(", ".join(available_backends)) ) @@ -82,7 +140,7 @@ def copy_model(model: TModel) -> TModel: Function to create copy of the backend-specific model. :param model: the backend-specific model instance - :return: Copy of the backend-specific model instance + :return: Copy of the backend-specific model instance. """ model_backend = get_backend(model) if model_backend == BackendType.OPENVINO: diff --git a/nncf/common/utils/os.py b/nncf/common/utils/os.py index 6e7775026b1..2732f12d8ed 100644 --- a/nncf/common/utils/os.py +++ b/nncf/common/utils/os.py @@ -8,10 +8,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing import sys from contextlib import contextmanager from pathlib import Path +import psutil + # pylint: disable=W1514 @contextmanager @@ -37,3 +40,23 @@ def is_windows(): def is_linux(): return "linux" in sys.platform + + +def available_cpu_count() -> int: + """ + :return: Logical CPU count + """ + try: + return multiprocessing.cpu_count() + except Exception as e: # pylint: disable=broad-except + return 1 + + +def available_memory_amount() -> int: + """ + :return: Available memory amount (bytes) + """ + try: + return psutil.virtual_memory()[1] + except Exception as e: # pylint: disable=broad-except + return 0 diff --git a/nncf/common/utils/timer.py b/nncf/common/utils/timer.py index 2d273e67d50..6e87676ce07 100644 --- a/nncf/common/utils/timer.py +++ b/nncf/common/utils/timer.py @@ -20,8 +20,9 @@ def timer(): """ Context manager to measure execution time. """ - start_time = time.perf_counter() - yield - elapsed_time = time.perf_counter() - start_time + start_time = end_time = time.perf_counter() + yield lambda: end_time - start_time + end_time = time.perf_counter() + elapsed_time = end_time - start_time time_string = time.strftime("%H:%M:%S", time.gmtime(elapsed_time)) nncf_logger.info(f"Elapsed Time: {time_string}") diff --git a/nncf/data/dataset.py b/nncf/data/dataset.py index 6bf1322c2ec..d7a0f29c654 100644 --- a/nncf/data/dataset.py +++ b/nncf/data/dataset.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from abc import abstractmethod from typing import Callable, Generic, Iterable, List, Optional, TypeVar from nncf.common.utils.api_marker import api @@ -17,21 +18,43 @@ ModelInput = TypeVar("ModelInput") -@api(canonical_alias="nncf.Dataset") -class Dataset(Generic[DataItem, ModelInput]): +class IDataset(Generic[DataItem, ModelInput]): """ - Wrapper for passing custom user datasets into NNCF algorithms. - This class defines the interface by which compression algorithms retrieve data items from the passed data source object. These data items are used for different purposes, for example, model inference and model validation, based on the choice of the exact compression algorithm. + """ - If the data item has been returned from the data source per iteration and it cannot be - used as input for model inference, the transformation function is used to extract the - model's input from this data item. For example, in supervised learning, the data item - usually contains both examples and labels. So transformation function should extract - the examples from the data item. + @abstractmethod + def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]: + """ + Returns the iterable object that contains selected data items from the data source as-is. + + :param indices: The zero-based indices of data items that should be selected from + the data source. The indices should be sorted in ascending order. If indices are + not passed all data items are selected from the data source. + :return: The iterable object that contains selected data items from the data source as-is. + """ + + @abstractmethod + def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]: + """ + Returns the iterable object that contains selected data items from the data source which + can be used as the model's input for model inference. + + :param indices: The zero-based indices of data items that should be selected from + the data source. The indices should be sorted in ascending order. If indices are + not passed all data items are selected from the data source. + :return: The iterable object that contains selected data items from the data source which + can be used as the model's input for model inference. + """ + + +@api(canonical_alias="nncf.Dataset") +class Dataset(IDataset): + """ + Wrapper for passing custom user datasets into NNCF algorithms. :param data_source: The iterable object serving as the source of data items. :param transform_func: The function that is used to extract the model's input @@ -73,6 +96,63 @@ def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[Mo return DataProvider(self._data_source, self._transform_func, indices) +class CountingDatasetWrapper(IDataset): + """ + Dataset wrapper for calculation number of iterations. + + :param dataset: The dataset. + """ + + def __init__(self, dataset: IDataset): + self._dataset = dataset + self._num_iters = 0 + + def tranform_func(x): + self._num_iters += 1 + return x + + self._transform_func = tranform_func + + @property + def num_iters(self): + """ + :return: The number of iterations performed in the last requested object. + """ + return self._num_iters + + def reset_num_iters(self): + """ + Resets iterations counter. + """ + self._num_iters = 0 + + def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]: + """ + Returns the iterable object that contains selected data items from the data source as-is. 
+ + :param indices: The zero-based indices of data items that should be selected from + the data source. The indices should be sorted in ascending order. If indices are + not passed all data items are selected from the data source. + :return: The iterable object that contains selected data items from the data source as-is. + """ + self.reset_num_iters() + return DataProvider(self._dataset.get_data(indices), self._transform_func, None) + + def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]: + """ + Returns the iterable object that contains selected data items from the data source which + can be used as the model's input for model inference. + + :param indices: The zero-based indices of data items that should be selected from + the data source. The indices should be sorted in ascending order. If indices are + not passed all data items are selected from the data source. + :return: The iterable object that contains selected data items from the data source which + can be used as the model's input for model inference. + """ + self.reset_num_iters() + return DataProvider(self._dataset.get_inference_data(indices), self._transform_func, None) + + class DataProvider(Generic[DataItem, ModelInput]): def __init__( self, diff --git a/nncf/openvino/engine.py b/nncf/openvino/engine.py index 697c0cca9a9..f9964f3f750 100644 --- a/nncf/openvino/engine.py +++ b/nncf/openvino/engine.py @@ -18,21 +18,17 @@ from nncf.parameters import TargetDevice -class OVNativeEngine(Engine): +class OVCompiledModelEngine(Engine): """ - Implementation of the engine for OpenVINO backend. + Implementation of the engine to infer OpenVINO compiled model. - OVNativeEngine uses + OVCompiledModelEngine uses [OpenVINO Runtime](https://docs.openvino.ai/latest/openvino_docs_OV_UG_OV_Runtime_User_Guide.html) - to infer the model. + to infer the compiled model. """ - def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.CPU): - if target_device == TargetDevice.ANY: - target_device = TargetDevice.CPU - - ie = ov.Core() - self.compiled_model = ie.compile_model(model, target_device.value) + def __init__(self, model: ov.CompiledModel): + self.compiled_model = model self.input_tensor_names = set() self.number_of_inputs = len(model.inputs) for model_input in model.inputs: @@ -74,3 +70,43 @@ def infer( for tensor_name in tensor.get_names(): output_data[tensor_name] = value return output_data + + +class OVNativeEngine(Engine): + """ + Implementation of the engine for OpenVINO backend. + + OVNativeEngine uses + [OpenVINO Runtime](https://docs.openvino.ai/latest/openvino_docs_OV_UG_OV_Runtime_User_Guide.html) + to infer the model. + """ + + def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.CPU): + self.engine = None + self.model = model + self.target_device = target_device + if target_device == TargetDevice.ANY: + self.target_device = TargetDevice.CPU + + def _get_compiled_model(self) -> ov.CompiledModel: + """ + Returns OpenVINO compiled model + + :return: A compiled model + """ + return ov.Core().compile_model(self.model, self.target_device.value) + + def infer( + self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]] + ) -> Dict[str, np.ndarray]: + """ + Runs model on the provided input via OpenVINO Runtime. + Returns the dictionary of model outputs by node names. + + :param input_data: Inputs for the model. + :return output_data: Model's output. 
+ """ + if self.engine is None: + compiled_model = self._get_compiled_model() + self.engine = OVCompiledModelEngine(compiled_model) + return self.engine.infer(input_data) diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py index ee1f25ee2a6..c1003b5903f 100644 --- a/nncf/openvino/quantization/quantize_model.py +++ b/nncf/openvino/quantization/quantize_model.py @@ -192,6 +192,7 @@ def native_quantize_with_accuracy_control_impl( max_num_iterations=advanced_accuracy_restorer_parameters.max_num_iterations, max_drop=max_drop, drop_type=drop_type, + num_ranking_processes=advanced_accuracy_restorer_parameters.num_ranking_processes, ) quantized_model = accuracy_aware_loop.apply(model, quantized_model, validation_dataset, validation_fn) if compress_weights: diff --git a/nncf/quantization/advanced_parameters.py b/nncf/quantization/advanced_parameters.py index 94bbe59772b..a3ec119ddeb 100644 --- a/nncf/quantization/advanced_parameters.py +++ b/nncf/quantization/advanced_parameters.py @@ -190,11 +190,15 @@ class AdvancedAccuracyRestorerParameters: :param ranking_subset_size: Size of a subset that is used to rank layers by their contribution to the accuracy drop. :type ranking_subset_size: Optional[int] + :param num_ranking_processes: The number of parallel processes that are used to rank + quantization operations. + :type num_ranking_processes: Optional[int] """ max_num_iterations: int = sys.maxsize tune_hyperparams: bool = False ranking_subset_size: Optional[int] = None + num_ranking_processes: Optional[int] = None def changes_asdict(params: Any) -> Dict[str, Any]: diff --git a/nncf/quantization/algorithms/accuracy_control/algorithm.py b/nncf/quantization/algorithms/accuracy_control/algorithm.py index fc413c5f842..543b8685b6a 100644 --- a/nncf/quantization/algorithms/accuracy_control/algorithm.py +++ b/nncf/quantization/algorithms/accuracy_control/algorithm.py @@ -20,7 +20,10 @@ from nncf.common.quantization.quantizer_removal import revert_operations_to_floating_point_precision from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend +from nncf.common.utils.os import available_cpu_count +from nncf.common.utils.os import available_memory_amount from nncf.common.utils.timer import timer +from nncf.data.dataset import CountingDatasetWrapper from nncf.data.dataset import Dataset from nncf.parameters import DropType from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend @@ -29,6 +32,9 @@ TModel = TypeVar("TModel") TTensor = TypeVar("TTensor") +PREPARATION_MODEL_THRESHOLD = 1 +OVERHEAD_COEFFICIENT = 2 +MEMORY_INCREASE_COEFFICIENT = 4 def get_algo_backend(backend: BackendType) -> AccuracyControlAlgoBackend: @@ -139,6 +145,7 @@ def __init__( max_num_iterations: int = sys.maxsize, max_drop: float = 0.01, drop_type: DropType = DropType.ABSOLUTE, + num_ranking_processes: Optional[int] = None, ): """ :param ranking_subset_size: The number of data items that will be selected from @@ -148,11 +155,14 @@ def __init__( :param drop_type: The accuracy drop type, which determines how the maximum accuracy drop between the original model and the compressed model is calculated. + :param num_ranking_processes: The number of parallel processes that are used to rank + quantization operations. 
""" self.ranking_subset_size = ranking_subset_size self.max_num_iterations = max_num_iterations self.max_drop = max_drop self.drop_type = drop_type + self.num_ranking_processes = num_ranking_processes def apply( self, @@ -182,12 +192,17 @@ def apply( # Validate initial and quantized model evaluator = Evaluator(validation_fn, algo_backend) - initial_metric, reference_values_for_each_item = self._collect_metric_and_values( + initial_metric, reference_values_for_each_item, _, _ = self._collect_metric_and_values( initial_model, validation_dataset, evaluator, "initial" ) - quantized_metric, approximate_values_for_each_item = self._collect_metric_and_values( - quantized_model, validation_dataset, evaluator, "quantized" - ) + counting_validation_dataset = CountingDatasetWrapper(validation_dataset) + ( + quantized_metric, + approximate_values_for_each_item, + preperation_time, + validation_time, + ) = self._collect_metric_and_values(quantized_model, counting_validation_dataset, evaluator, "quantized") + validation_dataset_size = counting_validation_dataset.num_iters should_terminate, accuracy_drop = calculate_accuracy_drop( initial_metric, quantized_metric, self.max_drop, self.drop_type @@ -221,7 +236,17 @@ def apply( ) nncf_logger.info(f"Total number of quantized operations in the model: {report.num_quantized_operations}") - ranker = Ranker(self.ranking_subset_size, validation_dataset, algo_backend, evaluator) + # Calculate number of parallel processes for Ranker + num_ranking_processes = self.num_ranking_processes + if num_ranking_processes is None: + model_size = algo_backend.get_model_size(quantized_model) + num_ranking_processes = self.compute_number_ranker_parallel_proc( + model_size, preperation_time, validation_time, validation_dataset_size, self.ranking_subset_size + ) + + nncf_logger.info(f"Number of parallel processes to rank quantized operations: {num_ranking_processes}") + + ranker = Ranker(self.ranking_subset_size, validation_dataset, algo_backend, evaluator, num_ranking_processes) groups_to_rank = ranker.find_groups_of_quantizers_to_rank(quantized_model_graph) ranked_groups = ranker.rank_groups_of_quantizers( groups_to_rank, @@ -383,7 +408,36 @@ def _collect_metric_and_values( model: TModel, dataset: Dataset, evaluator: Evaluator, model_name: str ) -> Tuple[float, Union[None, List[float], List[List[TTensor]]]]: nncf_logger.info(f"Validation of {model_name} model was started") - with timer(): - metric, values_for_each_item = evaluator.validate(model, dataset) + with timer() as preperation_time: + model_for_inference = evaluator.prepare_model_for_inference(model) + with timer() as validation_time: + metric, values_for_each_item = evaluator.validate_model_for_inference(model_for_inference, dataset) nncf_logger.info(f"Metric of {model_name} model: {metric}") - return metric, values_for_each_item + return metric, values_for_each_item, preperation_time(), validation_time() + + @staticmethod + def compute_number_ranker_parallel_proc( + model_size: int, + preperation_time: float, + validation_time: float, + validation_dataset_size: int, + ranking_subset_size: int, + ) -> int: + if preperation_time < PREPARATION_MODEL_THRESHOLD: + return 1 + + # Calculate the number of parallel processes needed to override model preparation and + # metric calculation on the ranking subset + ranking_time = validation_time * ranking_subset_size / validation_dataset_size + n_proc = max(round(preperation_time / ranking_time * OVERHEAD_COEFFICIENT), 2) + + # Apply limitation by number of CPU cores + n_cores = 
available_cpu_count() + n_proc = max(min(n_proc, n_cores // 2), 1) + + # Apply limitation by memmory + ram = available_memory_amount() + n_copies = ram // (model_size * MEMORY_INCREASE_COEFFICIENT) + n_proc = max(min(n_proc, n_copies - 1), 1) + + return n_proc diff --git a/nncf/quantization/algorithms/accuracy_control/backend.py b/nncf/quantization/algorithms/accuracy_control/backend.py index 010bc4cad13..ef2ab709670 100644 --- a/nncf/quantization/algorithms/accuracy_control/backend.py +++ b/nncf/quantization/algorithms/accuracy_control/backend.py @@ -18,6 +18,20 @@ from nncf.common.graph.operator_metatypes import OperatorMetatype TModel = TypeVar("TModel") +TPModel = TypeVar("TPModel") + + +class AsyncPreparedModel(ABC): + @abstractmethod + def get(self, timeout) -> TPModel: + """ + Returns the prepared model for inference when it arrives. If timeout is not None and + the result does not arrive within timeout seconds then TimeoutError is raised. If + the remote call raised an exception then that exception will be reraised by get(). + + :param timeout: timeout + :return: A prepared model for inference + """ class AccuracyControlAlgoBackend(ABC): @@ -127,14 +141,34 @@ def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: :return: Weights input port indices. """ + @staticmethod + @abstractmethod + def get_model_size(model: TModel) -> int: + """ + Returns model size + + :param model: A model + :return: Model size (in bytes) + """ + # Preparation of model @staticmethod @abstractmethod - def prepare_for_inference(model: TModel) -> Any: + def prepare_for_inference(model: TModel) -> TPModel: """ Prepares model for inference. :param model: A model that should be prepared. :return: Prepared model for inference. """ + + @staticmethod + @abstractmethod + def prepare_for_inference_async(model: TModel) -> AsyncPreparedModel: + """ + Prepares model for inference asynchronously. + + :param model: A model that should be prepared. + :return: AsyncPreparedModel opbject. + """ diff --git a/nncf/quantization/algorithms/accuracy_control/evaluator.py b/nncf/quantization/algorithms/accuracy_control/evaluator.py index 20c24347d25..daf4f2f3666 100644 --- a/nncf/quantization/algorithms/accuracy_control/evaluator.py +++ b/nncf/quantization/algorithms/accuracy_control/evaluator.py @@ -16,6 +16,7 @@ from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend TModel = TypeVar("TModel") +TPModel = TypeVar("TPModel") TTensor = TypeVar("TTensor") @@ -49,13 +50,22 @@ def is_metric_mode(self) -> bool: """ return self._metric_mode - def validate( - self, model: TModel, dataset: Dataset, indices: Optional[List[int]] = None - ) -> Tuple[float, Union[None, List[float], List[List[TTensor]]]]: + def prepare_model_for_inference(self, model: TModel) -> TPModel: """ - Validates model. + Prepares model for inference. - :param model: Model to validate. + :param model: A model that should be prepared. + :return: Prepared model for inference. + """ + return self._algo_backend.prepare_for_inference(model) + + def validate_model_for_inference( + self, model_for_inference: TPModel, dataset: Dataset, indices: Optional[List[int]] = None + ): + """ + Validates prepared model for inference. + + :param model: Prepared model to validate. :param dataset: Dataset to validate the model. :param indices: Zero-based indices of data items that should be selected from the dataset. @@ -66,7 +76,6 @@ def validate( Otherwise, if the condition is false, it represents list of logits for each item. 
""" - model_for_inference = self._algo_backend.prepare_for_inference(model) if self._metric_mode is None: self._determine_mode(model_for_inference, dataset) @@ -87,7 +96,27 @@ def validate( return float(metric), values_for_each_item - def _determine_mode(self, model_for_inference: TModel, dataset: Dataset) -> None: + def validate( + self, model: TModel, dataset: Dataset, indices: Optional[List[int]] = None + ) -> Tuple[float, Union[None, List[float], List[List[TTensor]]]]: + """ + Validates model. + + :param model: Model to validate. + :param dataset: Dataset to validate the model. + :param indices: Zero-based indices of data items that should be selected from + the dataset. + :return: A tuple (metric_value, values_for_each_item) where + - metric_values: This is a metric for the model. + - values_for_each_item: If the `Evaluator.is_metric_mode()` condition is true, + then `values_for_each_item` represents the list of metric value for each item. + Otherwise, if the condition is false, it represents list of logits for each + item. + """ + model_for_inference = self.prepare_model_for_inference(model) + return self.validate_model_for_inference(model_for_inference, dataset, indices) + + def _determine_mode(self, model_for_inference: TPModel, dataset: Dataset) -> None: """ Determines mode based on the type of returned value from the validation function. @@ -144,13 +173,13 @@ def _determine_mode(self, model_for_inference: TModel, dataset: Dataset) -> None elif values_for_each_item is not None and not isinstance(values_for_each_item[0], list): raise RuntimeError("Unexpected return value from provided validation function.") - def collect_values_for_each_item( - self, model: TModel, dataset: Dataset, indices: Optional[List[int]] = None + def collect_values_for_each_item_using_model_for_inference( + self, model_for_inference: TPModel, dataset: Dataset, indices: Optional[List[int]] = None ) -> Union[List[float], List[List[TTensor]]]: """ - Collects value for each item from the dataset. If `is_metric_mode()` - returns `True` then i-th value is a metric for i-th data item. It - is an output of the model for i-th data item otherwise. + Collects value for each item from the dataset using prepared model for inference. + If `is_metric_mode()` returns `True` then i-th value is a metric for i-th data item. + It is an output of the model for i-th data item otherwise. :param model: Model to infer. :param dataset: Dataset to collect values. @@ -160,13 +189,12 @@ def collect_values_for_each_item( """ if self._metric_mode: # Collect metrics for each item - model_for_inference = self._algo_backend.prepare_for_inference(model) values_for_each_item = [ self._validation_fn(model_for_inference, [data_item])[0] for data_item in dataset.get_data(indices) ] else: # Collect outputs for each item - engine = EngineFactory.create(model) + engine = EngineFactory.create(model_for_inference) values_for_each_item = [] for data_item in dataset.get_inference_data(indices): @@ -174,3 +202,20 @@ def collect_values_for_each_item( values_for_each_item.append(list(logits.values())) return values_for_each_item + + def collect_values_for_each_item( + self, model: TModel, dataset: Dataset, indices: Optional[List[int]] = None + ) -> Union[List[float], List[List[TTensor]]]: + """ + Collects value for each item from the dataset. If `is_metric_mode()` + returns `True` then i-th value is a metric for i-th data item. It + is an output of the model for i-th data item otherwise. + + :param model: A target model. 
+ :param dataset: Dataset to collect values. + :param indices: The zero-based indices of data items that should be selected from + the dataset. + :return: Collected values. + """ + model_for_inference = self.prepare_model_for_inference(model) + return self.collect_values_for_each_item_using_model_for_inference(model_for_inference, dataset, indices) diff --git a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py index ccdedd9ff59..29f99bdf992 100644 --- a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py +++ b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import multiprocessing from typing import Any, List, Optional import numpy as np @@ -29,6 +30,26 @@ from nncf.openvino.graph.node_utils import get_weight_value from nncf.openvino.graph.node_utils import is_node_with_bias from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend +from nncf.quantization.algorithms.accuracy_control.backend import AsyncPreparedModel + + +def compile_model(model: ov.Model, done_queue: multiprocessing.Queue) -> None: + compiled_model = ov.Core().compile_model(model, "CPU") + model_stream = compiled_model.export_model() + done_queue.put(model_stream) + + +class OVAsyncPreparedModel(AsyncPreparedModel): + def __init__(self, proc: multiprocessing.Process, done_queue: multiprocessing.Queue): + self.proc = proc + self.done_queue = done_queue + + def get(self, timeout=None) -> ov.CompiledModel: + try: + model_stream = self.done_queue.get(timeout=timeout) + except multiprocessing.TimeoutError as ex: + raise TimeoutError() from ex + return ov.Core().import_model(model_stream, "CPU") class OVAccuracyControlAlgoBackend(AccuracyControlAlgoBackend): @@ -80,8 +101,24 @@ def get_weight_value(node_with_weight: NNCFNode, model: ov.Model, port_id: int) def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: return node.layer_attributes.get_const_port_ids() + @staticmethod + def get_model_size(model: ov.Model) -> int: + model_size = 0 + for op in model.get_ops(): + if op.get_type_name() == "Constant": + model_size += op.data.nbytes + + return model_size + # Preparation of model @staticmethod def prepare_for_inference(model: ov.Model) -> Any: return ov.compile_model(model) + + @staticmethod + def prepare_for_inference_async(model: ov.Model) -> Any: + queue = multiprocessing.Queue() + p = multiprocessing.Process(target=compile_model, args=(model, queue)) + p.start() + return OVAsyncPreparedModel(p, queue) diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index f58f47f4f08..353362501ea 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -31,6 +31,7 @@ from nncf.quantization.passes import remove_shapeof_subgraphs TModel = TypeVar("TModel") +TPModel = TypeVar("TPModel") TTensor = TypeVar("TTensor") @@ -83,6 +84,7 @@ def __init__( dataset: Dataset, algo_backend: AccuracyControlAlgoBackend, evaluator: Evaluator, + num_processes: int = 1, ranking_fn: Optional[Callable[[Any, Any], float]] = None, ): """ @@ -105,6 +107,7 @@ def __init__( # because they don't change. So use this attribute to store # them to improve execution time. 
self._reference_values_for_each_item = None + self._num_processes = num_processes def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) -> List[GroupToRank]: """ @@ -195,21 +198,69 @@ def rank_groups_of_quantizers( nncf_logger.info("Calculating ranking score for groups of quantizers") with timer(): # Calculate ranking score for groups of quantizers. - ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] - for current_group in groups_to_rank: - modified_model = revert_operations_to_floating_point_precision( - current_group.operations, current_group.quantizers, quantized_model, quantized_model_graph + if self._num_processes > 1: + ranking_scores = self._multiprocessing_calculation_ranking_score( + quantized_model, quantized_model_graph, groups_to_rank, ranking_subset_indices + ) + + else: + ranking_scores = self._sequential_calculation_ranking_score( + quantized_model, quantized_model_graph, groups_to_rank, ranking_subset_indices ) - # Calculate the ranking score for the current group of quantizers. - ranking_score = self._calculate_ranking_score(modified_model, ranking_subset_indices) - ranking_scores.append(float(ranking_score)) # Rank groups. ranked_groups = [group for _, group in sorted(zip(ranking_scores, groups_to_rank), key=operator.itemgetter(0))] return ranked_groups - def _calculate_ranking_score(self, modified_model: TModel, ranking_subset_indices: List[int]) -> float: + def _sequential_calculation_ranking_score( + self, + quantized_model: TModel, + quantized_model_graph: NNCFGraph, + groups_to_rank: List[GroupToRank], + ranking_subset_indices: List[int], + ): + ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] + for current_group in groups_to_rank: + modified_model = revert_operations_to_floating_point_precision( + current_group.operations, current_group.quantizers, quantized_model, quantized_model_graph + ) + + prepared_model = self._algo_backend.prepare_for_inference(modified_model) + ranking_score = self._calculate_ranking_score(prepared_model, ranking_subset_indices) + ranking_scores.append(float(ranking_score)) + + return ranking_scores + + def _multiprocessing_calculation_ranking_score( + self, + quantized_model: TModel, + quantized_model_graph: NNCFGraph, + groups_to_rank: List[GroupToRank], + ranking_subset_indices: List[int], + ): + ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] + prepared_model_queue = [] + for idx, current_group in enumerate(groups_to_rank): + modified_model = revert_operations_to_floating_point_precision( + current_group.operations, current_group.quantizers, quantized_model, quantized_model_graph + ) + + prepared_model_queue.append(self._algo_backend.prepare_for_inference_async(modified_model)) + + if idx >= (self._num_processes - 1): + prepared_model = prepared_model_queue.pop(0).get() + ranking_score = self._calculate_ranking_score(prepared_model, ranking_subset_indices) + ranking_scores.append(float(ranking_score)) + + for _ in range(self._num_processes - 1): + prepared_model = prepared_model_queue.pop(0).get() + ranking_score = self._calculate_ranking_score(prepared_model, ranking_subset_indices) + ranking_scores.append(float(ranking_score)) + + return ranking_scores + + def _calculate_ranking_score(self, prepared_model: TPModel, ranking_subset_indices: List[int]) -> float: """ Calculates the ranking score for the current group of quantizers. 
@@ -219,11 +270,13 @@ def _calculate_ranking_score(self, modified_model: TModel, ranking_subset_indice """ if self._evaluator.is_metric_mode(): # Calculate ranking score based on metric - ranking_score, _ = self._evaluator.validate(modified_model, self._dataset, ranking_subset_indices) + ranking_score, _ = self._evaluator.validate_model_for_inference( + prepared_model, self._dataset, ranking_subset_indices + ) else: # Calculate ranking score based on differences in logits - approximate_outputs = self._evaluator.collect_values_for_each_item( - modified_model, self._dataset, ranking_subset_indices + approximate_outputs = self._evaluator.collect_values_for_each_item_using_model_for_inference( + prepared_model, self._dataset, ranking_subset_indices ) reference_outputs = [self._reference_values_for_each_item[i] for i in ranking_subset_indices] errors = [self._ranking_fn(a, b) for a, b in zip(reference_outputs, approximate_outputs)] diff --git a/setup.py b/setup.py index 3ff312fcc18..7c3861c1b5f 100644 --- a/setup.py +++ b/setup.py @@ -122,6 +122,7 @@ def find_version(*file_paths): "pandas>=1.1.5,<2.1", "scikit-learn>=0.24.0", "openvino-telemetry", + "psutil", ] From c81790bc82d8d0581d8280fd3d143f86d9f06a71 Mon Sep 17 00:00:00 2001 From: Alexander Suslov Date: Tue, 18 Jul 2023 14:30:09 +0400 Subject: [PATCH 02/10] pylint happy --- nncf/common/utils/os.py | 4 +- .../algorithms/accuracy_control/algorithm.py | 157 +++++++++++++----- 2 files changed, 115 insertions(+), 46 deletions(-) diff --git a/nncf/common/utils/os.py b/nncf/common/utils/os.py index 2732f12d8ed..bbed93a5fdd 100644 --- a/nncf/common/utils/os.py +++ b/nncf/common/utils/os.py @@ -48,7 +48,7 @@ def available_cpu_count() -> int: """ try: return multiprocessing.cpu_count() - except Exception as e: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except return 1 @@ -58,5 +58,5 @@ def available_memory_amount() -> int: """ try: return psutil.virtual_memory()[1] - except Exception as e: # pylint: disable=broad-except + except Exception: # pylint: disable=broad-except return 0 diff --git a/nncf/quantization/algorithms/accuracy_control/algorithm.py b/nncf/quantization/algorithms/accuracy_control/algorithm.py index 543b8685b6a..665b18ae8ed 100644 --- a/nncf/quantization/algorithms/accuracy_control/algorithm.py +++ b/nncf/quantization/algorithms/accuracy_control/algorithm.py @@ -12,6 +12,8 @@ import sys from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union +from attr import dataclass + from nncf.common.factory import NNCFGraphFactory from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -134,6 +136,23 @@ def reverted_operations(self) -> List[NNCFNode]: return operations +@dataclass +class MetricResults: + """ + Results of metrics collection. + + :param metric_value: Aggregated metric value. + :param values_for_each_item: Metric values for each data item. + :param preperation_time: Time that it takes to prepare model for validation. + :param validation_time: Time that it takes to validate model. + """ + + metric_value: float + values_for_each_item: Union[None, List[float], List[List[TTensor]]] + preperation_time: float + validation_time: float + + class QuantizationAccuracyRestorer: """ Implementation of the accuracy-aware loop. 
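The `preperation_time` and `validation_time` fields of `MetricResults` above are filled via the revised `timer()` helper in `nncf/common/utils/timer.py`, which now yields a callable that returns the elapsed time once its `with` block has exited. A minimal self-contained sketch of that pattern (NNCF's logging omitted; an illustration, not the actual helper):

    import time
    from contextlib import contextmanager

    @contextmanager
    def timer():
        # The yielded lambda closes over end_time; end_time is reassigned when
        # the block exits, so calling the lambda afterwards returns the measured
        # duration (inside the block it still returns 0.0).
        start_time = end_time = time.perf_counter()
        yield lambda: end_time - start_time
        end_time = time.perf_counter()

    with timer() as preparation_time:
        time.sleep(0.1)  # stand-in for model preparation
    print(f"Preparation took {preparation_time():.2f}s")  # ~0.10

The rewritten `_collect_metric_and_values` below uses two such timers to fill those fields.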
@@ -192,20 +211,17 @@ def apply( # Validate initial and quantized model evaluator = Evaluator(validation_fn, algo_backend) - initial_metric, reference_values_for_each_item, _, _ = self._collect_metric_and_values( + initial_metric_results = self._collect_metric_and_values( initial_model, validation_dataset, evaluator, "initial" ) counting_validation_dataset = CountingDatasetWrapper(validation_dataset) - ( - quantized_metric, - approximate_values_for_each_item, - preperation_time, - validation_time, - ) = self._collect_metric_and_values(quantized_model, counting_validation_dataset, evaluator, "quantized") + quantized_metric_results = self._collect_metric_and_values( + quantized_model, counting_validation_dataset, evaluator, "quantized" + ) validation_dataset_size = counting_validation_dataset.num_iters should_terminate, accuracy_drop = calculate_accuracy_drop( - initial_metric, quantized_metric, self.max_drop, self.drop_type + initial_metric_results.metric_value, quantized_metric_results.metric_value, self.max_drop, self.drop_type ) if should_terminate: @@ -217,6 +233,48 @@ def apply( if accuracy_drop <= self.max_drop: return quantized_model + return self._apply( + initial_model, + quantized_model, + initial_metric_results, + quantized_metric_results, + algo_backend, + evaluator, + validation_dataset, + validation_dataset_size, + accuracy_drop, + ) + + def _apply( + self, + initial_model: TModel, + quantized_model: TModel, + initial_metric_results: MetricResults, + quantized_metric_results: MetricResults, + algo_backend: AccuracyControlAlgoBackend, + evaluator: Evaluator, + validation_dataset: Dataset, + validation_dataset_size: int, + accuracy_drop: float, + ) -> TModel: + """ + An internal function that implements an iterative approach to restoring the accuracy of + the quantized model by removing the groups of quantizers that contribute the most to + the drop in accuracy. + + :param initial_model: Initial model (not quantized). + :param quantized_model: Quantized model. + :param initial_metric_results: Initial model metrics. + :param quantized_metric_results: Quantized model metrics. + :param algo_backend: The `AccuracyControlAlgoBackend` algo backend. + :param evaluator: The instance of `Evaluator` to validate model and collect values + for each item from dataset. + :param validation_dataset: A dataset for the validation process. + :param validation_dataset_size: Validation dataset size. + :param accuracy_drop: Accuracy drop between initial and quantized models. + :return: The quantized model whose metric `final_metric` is satisfied + the maximum accuracy drop condition. 
+ """ # Accuracy drop is greater than the maximum drop so we need to restore accuracy initial_model_graph = NNCFGraphFactory.create(initial_model) quantized_model_graph = NNCFGraphFactory.create(quantized_model) @@ -240,8 +298,11 @@ def apply( num_ranking_processes = self.num_ranking_processes if num_ranking_processes is None: model_size = algo_backend.get_model_size(quantized_model) - num_ranking_processes = self.compute_number_ranker_parallel_proc( - model_size, preperation_time, validation_time, validation_dataset_size, self.ranking_subset_size + num_ranking_processes = self._calculate_number_ranker_parallel_proc( + model_size, + quantized_metric_results.preperation_time, + quantized_metric_results.validation_time, + validation_dataset_size, ) nncf_logger.info(f"Number of parallel processes to rank quantized operations: {num_ranking_processes}") @@ -253,12 +314,12 @@ def apply( initial_model, quantized_model, quantized_model_graph, - reference_values_for_each_item, - approximate_values_for_each_item, + initial_metric_results.values_for_each_item, + quantized_metric_results.values_for_each_item, ) previous_model = quantized_model - previous_approximate_values_for_each_item = approximate_values_for_each_item + previous_approximate_values_for_each_item = quantized_metric_results.values_for_each_item previous_accuracy_drop = accuracy_drop current_model = None current_approximate_values_for_each_item = None @@ -292,7 +353,7 @@ def apply( ) should_terminate, current_accuracy_drop = calculate_accuracy_drop( - initial_metric, current_metric, self.max_drop, self.drop_type + initial_metric_results.metric_value, current_metric, self.max_drop, self.drop_type ) if not ranked_groups: @@ -336,7 +397,7 @@ def apply( initial_model, current_model, quantized_model_graph, - reference_values_for_each_item, + initial_metric_results.values_for_each_item, current_approximate_values_for_each_item, ) @@ -345,6 +406,41 @@ def apply( return current_model + def _calculate_number_ranker_parallel_proc( + self, + model_size: int, + preperation_time: float, + validation_time: float, + validation_dataset_size: int, + ) -> int: + """ + Calculate the number of parallel ranker processes + + :param model_size: Target model size. + :param preperation_time: The time it takes to prepare the model. + :param validation_time: The time it takes to validate the model. + :param validation_dataset_size: Validation dataset size. 
+ :return: The number of parallel ranker processes + """ + if preperation_time < PREPARATION_MODEL_THRESHOLD: + return 1 + + # Calculate the number of parallel processes needed to override model preparation and + # metric calculation on the ranking subset + ranking_time = validation_time * self.ranking_subset_size / validation_dataset_size + n_proc = max(round(preperation_time / ranking_time * OVERHEAD_COEFFICIENT), 2) + + # Apply limitation by number of CPU cores + n_cores = available_cpu_count() + n_proc = max(min(n_proc, n_cores // 2), 1) + + # Apply limitation by memmory + ram = available_memory_amount() + n_copies = ram // (model_size * MEMORY_INCREASE_COEFFICIENT) + n_proc = max(min(n_proc, n_copies - 1), 1) + + return n_proc + @staticmethod def _collect_original_biases_and_weights( initial_model_graph: NNCFGraph, @@ -406,38 +502,11 @@ def _print_completion_message(accuracy_drop: float, drop_type: DropType) -> None @staticmethod def _collect_metric_and_values( model: TModel, dataset: Dataset, evaluator: Evaluator, model_name: str - ) -> Tuple[float, Union[None, List[float], List[List[TTensor]]]]: + ) -> MetricResults: nncf_logger.info(f"Validation of {model_name} model was started") with timer() as preperation_time: model_for_inference = evaluator.prepare_model_for_inference(model) with timer() as validation_time: metric, values_for_each_item = evaluator.validate_model_for_inference(model_for_inference, dataset) nncf_logger.info(f"Metric of {model_name} model: {metric}") - return metric, values_for_each_item, preperation_time(), validation_time() - - @staticmethod - def compute_number_ranker_parallel_proc( - model_size: int, - preperation_time: float, - validation_time: float, - validation_dataset_size: int, - ranking_subset_size: int, - ) -> int: - if preperation_time < PREPARATION_MODEL_THRESHOLD: - return 1 - - # Calculate the number of parallel processes needed to override model preparation and - # metric calculation on the ranking subset - ranking_time = validation_time * ranking_subset_size / validation_dataset_size - n_proc = max(round(preperation_time / ranking_time * OVERHEAD_COEFFICIENT), 2) - - # Apply limitation by number of CPU cores - n_cores = available_cpu_count() - n_proc = max(min(n_proc, n_cores // 2), 1) - - # Apply limitation by memmory - ram = available_memory_amount() - n_copies = ram // (model_size * MEMORY_INCREASE_COEFFICIENT) - n_proc = max(min(n_proc, n_copies - 1), 1) - - return n_proc + return MetricResults(metric, values_for_each_item, preperation_time(), validation_time()) From ffa15d92032543afeac457376a4db120f8872f6e Mon Sep 17 00:00:00 2001 From: Alexander Suslov Date: Thu, 20 Jul 2023 15:38:54 +0400 Subject: [PATCH 03/10] replied to comments --- nncf/common/utils/backend.py | 6 +++--- nncf/common/utils/os.py | 4 ++-- .../algorithms/accuracy_control/algorithm.py | 13 ++++++------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index f4d3e47cd22..0cca0d0867c 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -27,7 +27,7 @@ def get_available_backends() -> List[BackendType]: """ Returns a list of available backends. - :return: A list of avauilable backends. + :return: A list of available backends. 
""" frameworks = [ ("torch", BackendType.TORCH), @@ -95,7 +95,7 @@ def is_openvino_model(model: TModel) -> bool: return isinstance(model, ov.Model) -def is_openvino_compiled_model(model: TModel): +def is_openvino_compiled_model(model: TModel) -> bool: """ Returns True if the model is an instance of openvino.runtime.CompiledModel, otherwise False. @@ -111,7 +111,7 @@ def get_backend(model: TModel) -> BackendType: """ Returns the NNCF backend name string inferred from the type of the model object passed into this function. - :param model: The framework-specific object representing the trainable model. + :param model: The framework-specific model. :return: A BackendType representing the correct NNCF backend to be used when working with the framework. """ available_backends = get_available_backends() diff --git a/nncf/common/utils/os.py b/nncf/common/utils/os.py index bbed93a5fdd..3d40f387a0a 100644 --- a/nncf/common/utils/os.py +++ b/nncf/common/utils/os.py @@ -42,7 +42,7 @@ def is_linux(): return "linux" in sys.platform -def available_cpu_count() -> int: +def get_available_cpu_count() -> int: """ :return: Logical CPU count """ @@ -52,7 +52,7 @@ def available_cpu_count() -> int: return 1 -def available_memory_amount() -> int: +def get_available_memory_amount() -> int: """ :return: Available memory amount (bytes) """ diff --git a/nncf/quantization/algorithms/accuracy_control/algorithm.py b/nncf/quantization/algorithms/accuracy_control/algorithm.py index 665b18ae8ed..0ac50565078 100644 --- a/nncf/quantization/algorithms/accuracy_control/algorithm.py +++ b/nncf/quantization/algorithms/accuracy_control/algorithm.py @@ -10,10 +10,9 @@ # limitations under the License. import sys +from dataclasses import dataclass from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union -from attr import dataclass - from nncf.common.factory import NNCFGraphFactory from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -22,8 +21,8 @@ from nncf.common.quantization.quantizer_removal import revert_operations_to_floating_point_precision from nncf.common.utils.backend import BackendType from nncf.common.utils.backend import get_backend -from nncf.common.utils.os import available_cpu_count -from nncf.common.utils.os import available_memory_amount +from nncf.common.utils.os import get_available_cpu_count +from nncf.common.utils.os import get_available_memory_amount from nncf.common.utils.timer import timer from nncf.data.dataset import CountingDatasetWrapper from nncf.data.dataset import Dataset @@ -428,14 +427,14 @@ def _calculate_number_ranker_parallel_proc( # Calculate the number of parallel processes needed to override model preparation and # metric calculation on the ranking subset ranking_time = validation_time * self.ranking_subset_size / validation_dataset_size - n_proc = max(round(preperation_time / ranking_time * OVERHEAD_COEFFICIENT), 2) + n_proc = max(round((preperation_time / ranking_time + 1) * OVERHEAD_COEFFICIENT), 2) # Apply limitation by number of CPU cores - n_cores = available_cpu_count() + n_cores = get_available_cpu_count() n_proc = max(min(n_proc, n_cores // 2), 1) # Apply limitation by memmory - ram = available_memory_amount() + ram = get_available_memory_amount() n_copies = ram // (model_size * MEMORY_INCREASE_COEFFICIENT) n_proc = max(min(n_proc, n_copies - 1), 1) From 9e4eb58e51ce27a33a2d9e372bbb4d0d2e05ce39 Mon Sep 17 00:00:00 2001 From: Alexander Suslov Date: Mon, 24 Jul 2023 17:57:05 +0400 Subject: [PATCH 04/10] replied to comments --- 
nncf/data/dataset.py | 98 ++----------------- nncf/openvino/engine.py | 18 +--- .../algorithms/accuracy_control/algorithm.py | 9 +- .../algorithms/accuracy_control/evaluator.py | 54 +++++++++- 4 files changed, 71 insertions(+), 108 deletions(-) diff --git a/nncf/data/dataset.py b/nncf/data/dataset.py index d7a0f29c654..6bf1322c2ec 100644 --- a/nncf/data/dataset.py +++ b/nncf/data/dataset.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod from typing import Callable, Generic, Iterable, List, Optional, TypeVar from nncf.common.utils.api_marker import api @@ -18,43 +17,21 @@ ModelInput = TypeVar("ModelInput") -class IDataset(Generic[DataItem, ModelInput]): +@api(canonical_alias="nncf.Dataset") +class Dataset(Generic[DataItem, ModelInput]): """ + Wrapper for passing custom user datasets into NNCF algorithms. + This class defines the interface by which compression algorithms retrieve data items from the passed data source object. These data items are used for different purposes, for example, model inference and model validation, based on the choice of the exact compression algorithm. - """ - @abstractmethod - def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]: - """ - Returns the iterable object that contains selected data items from the data source as-is. - - :param indices: The zero-based indices of data items that should be selected from - the data source. The indices should be sorted in ascending order. If indices are - not passed all data items are selected from the data source. - :return: The iterable object that contains selected data items from the data source as-is. - """ - - @abstractmethod - def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]: - """ - Returns the iterable object that contains selected data items from the data source which - can be used as the model's input for model inference. - - :param indices: The zero-based indices of data items that should be selected from - the data source. The indices should be sorted in ascending order. If indices are - not passed all data items are selected from the data source. - :return: The iterable object that contains selected data items from the data source which - can be used as the model's input for model inference. - """ - - -@api(canonical_alias="nncf.Dataset") -class Dataset(IDataset): - """ - Wrapper for passing custom user datasets into NNCF algorithms. + If the data item has been returned from the data source per iteration and it cannot be + used as input for model inference, the transformation function is used to extract the + model's input from this data item. For example, in supervised learning, the data item + usually contains both examples and labels. So transformation function should extract + the examples from the data item. :param data_source: The iterable object serving as the source of data items. :param transform_func: The function that is used to extract the model's input @@ -96,63 +73,6 @@ def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[Mo return DataProvider(self._data_source, self._transform_func, indices) -class CountingDatasetWrapper(IDataset): - """ - Dataset wrapper for calculation number of iterations. - - :param dataset: The dataset. 
- """ - - def __init__(self, dataset: IDataset): - self._dataset = dataset - self._num_iters = 0 - - def tranform_func(x): - self._num_iters += 1 - return x - - self._transform_func = tranform_func - - @property - def num_iters(self): - """ - :return: The number of iterations performed in the last requested object. - """ - return self._num_iters - - def reset_num_iters(self): - """ - Resets iterations counter. - """ - self._num_iters = 0 - - def get_data(self, indices: Optional[List[int]] = None) -> Iterable[DataItem]: - """ - Returns the iterable object that contains selected data items from the data source as-is. - - :param indices: The zero-based indices of data items that should be selected from - the data source. The indices should be sorted in ascending order. If indices are - not passed all data items are selected from the data source. - :return: The iterable object that contains selected data items from the data source as-is. - """ - self.reset_num_iters() - return DataProvider(self._dataset.get_data(indices), self._transform_func, None) - - def get_inference_data(self, indices: Optional[List[int]] = None) -> Iterable[ModelInput]: - """ - Returns the iterable object that contains selected data items from the data source which - can be used as the model's input for model inference. - - :param indices: The zero-based indices of data items that should be selected from - the data source. The indices should be sorted in ascending order. If indices are - not passed all data items are selected from the data source. - :return: The iterable object that contains selected data items from the data source which - can be used as the model's input for model inference. - """ - self.reset_num_iters() - return DataProvider(self._dataset.get_inference_data(indices), self._transform_func, None) - - class DataProvider(Generic[DataItem, ModelInput]): def __init__( self, diff --git a/nncf/openvino/engine.py b/nncf/openvino/engine.py index f9964f3f750..decd31a6364 100644 --- a/nncf/openvino/engine.py +++ b/nncf/openvino/engine.py @@ -82,19 +82,12 @@ class OVNativeEngine(Engine): """ def __init__(self, model: ov.Model, target_device: TargetDevice = TargetDevice.CPU): - self.engine = None - self.model = model - self.target_device = target_device if target_device == TargetDevice.ANY: - self.target_device = TargetDevice.CPU + target_device = TargetDevice.CPU - def _get_compiled_model(self) -> ov.CompiledModel: - """ - Returns OpenVINO compiled model - - :return: A compiled model - """ - return ov.Core().compile_model(self.model, self.target_device.value) + ie = ov.Core() + compiled_model = ie.compile_model(model, target_device.value) + self.engine = OVCompiledModelEngine(compiled_model) def infer( self, input_data: Union[np.ndarray, List[np.ndarray], Tuple[np.ndarray], Dict[str, np.ndarray]] @@ -106,7 +99,4 @@ def infer( :param input_data: Inputs for the model. :return output_data: Model's output. 
""" - if self.engine is None: - compiled_model = self._get_compiled_model() - self.engine = OVCompiledModelEngine(compiled_model) return self.engine.infer(input_data) diff --git a/nncf/quantization/algorithms/accuracy_control/algorithm.py b/nncf/quantization/algorithms/accuracy_control/algorithm.py index 0ac50565078..7727b5ba5c1 100644 --- a/nncf/quantization/algorithms/accuracy_control/algorithm.py +++ b/nncf/quantization/algorithms/accuracy_control/algorithm.py @@ -24,7 +24,6 @@ from nncf.common.utils.os import get_available_cpu_count from nncf.common.utils.os import get_available_memory_amount from nncf.common.utils.timer import timer -from nncf.data.dataset import CountingDatasetWrapper from nncf.data.dataset import Dataset from nncf.parameters import DropType from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend @@ -213,11 +212,13 @@ def apply( initial_metric_results = self._collect_metric_and_values( initial_model, validation_dataset, evaluator, "initial" ) - counting_validation_dataset = CountingDatasetWrapper(validation_dataset) + + evaluator.enable_iteration_count() quantized_metric_results = self._collect_metric_and_values( - quantized_model, counting_validation_dataset, evaluator, "quantized" + quantized_model, validation_dataset, evaluator, "quantized" ) - validation_dataset_size = counting_validation_dataset.num_iters + validation_dataset_size = evaluator.num_passed_iterations + evaluator.disable_iteration_count() should_terminate, accuracy_drop = calculate_accuracy_drop( initial_metric_results.metric_value, quantized_metric_results.metric_value, self.max_drop, self.drop_type diff --git a/nncf/quantization/algorithms/accuracy_control/evaluator.py b/nncf/quantization/algorithms/accuracy_control/evaluator.py index daf4f2f3666..b5b8699f783 100644 --- a/nncf/quantization/algorithms/accuracy_control/evaluator.py +++ b/nncf/quantization/algorithms/accuracy_control/evaluator.py @@ -20,6 +20,26 @@ TTensor = TypeVar("TTensor") +class IterationCounter: + """ + A wrapper for counting the passed iterations of iterable objects. + """ + + def __init__(self, iterable): + self._iterable = iterable + self._num_iterations = 0 + + @property + def num_iterations(self) -> int: + return self._num_iterations + + def __iter__(self): + self._num_iterations = 0 + for x in self._iterable: + self._num_iterations += 1 + yield x + + class Evaluator: """ Evaluator encapsulates a logic to validate model and collect values for each item. @@ -39,6 +59,30 @@ def __init__( self._validation_fn = validation_fn self._algo_backend = algo_backend self._metric_mode = None + self._num_passed_iterations = 0 + self._enable_iteration_count = False + + @property + def num_passed_iterations(self) -> int: + """ + Number of passed iterations during last validation process if the iteration count is enabled. + + :return: Number of passed iterations during last validation process. + """ + + return self._num_passed_iterations + + def enable_iteration_count(self) -> None: + """ + Enable the iteration count. + """ + self._enable_iteration_count = True + + def disable_iteration_count(self) -> None: + """ + Disable the iteration count. 
+ """ + self._enable_iteration_count = False def is_metric_mode(self) -> bool: """ @@ -82,7 +126,13 @@ def validate_model_for_inference( if not self.is_metric_mode() and indices is not None: raise ValueError("The `indices` parameter can be used only if Evaluator.is_metric_mode() = True") - metric, values_for_each_item = self._validation_fn(model_for_inference, dataset.get_data(indices)) + validation_dataset = dataset.get_data(indices) + if self._enable_iteration_count: + validation_dataset = IterationCounter(validation_dataset) + + metric, values_for_each_item = self._validation_fn(model_for_inference, validation_dataset) + + self._num_passed_iterations = validation_dataset.num_iterations if self._enable_iteration_count else 0 if self.is_metric_mode() and values_for_each_item is not None: # This casting is necessary to cover the following cases: @@ -201,6 +251,8 @@ def collect_values_for_each_item_using_model_for_inference( logits = engine.infer(data_item) values_for_each_item.append(list(logits.values())) + self._num_passed_iterations = len(values_for_each_item) if self._enable_iteration_count else 0 + return values_for_each_item def collect_values_for_each_item( From bf1513bd8f609002efd8961b975c7b930cc29f56 Mon Sep 17 00:00:00 2001 From: p-wysocki Date: Wed, 6 Sep 2023 15:30:48 +0200 Subject: [PATCH 05/10] debug --- .../algorithms/accuracy_control/openvino_backend.py | 9 ++++++++- nncf/quantization/algorithms/accuracy_control/ranker.py | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py index 29f99bdf992..fafb885ac58 100644 --- a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py +++ b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py @@ -34,7 +34,12 @@ def compile_model(model: ov.Model, done_queue: multiprocessing.Queue) -> None: - compiled_model = ov.Core().compile_model(model, "CPU") + ov_core = ov.Core() + #ov_core.set_property("CPU", {"COMPILATION_NUM_THREADS": 8}) + ov_core.set_property("CPU", {"INFERENCE_NUM_THREADS": 8}) + #print(ov_core.get_property("CPU", "COMPILATION_NUM_THREADS")) + #print(ov_core.get_property("CPU", "INFERENCE_NUM_THREADS")) + compiled_model = ov_core.compile_model(model, "CPU") model_stream = compiled_model.export_model() done_queue.put(model_stream) @@ -114,10 +119,12 @@ def get_model_size(model: ov.Model) -> int: @staticmethod def prepare_for_inference(model: ov.Model) -> Any: + print("\n----\n----\IM RUNNING prepare_for_inference\n----\n----\n") return ov.compile_model(model) @staticmethod def prepare_for_inference_async(model: ov.Model) -> Any: + print("\n----\n----\IM RUNNING prepare_for_inference_async\n----\n----\n") queue = multiprocessing.Queue() p = multiprocessing.Process(target=compile_model, args=(model, queue)) p.start() diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index 353362501ea..c690342f5bc 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -84,7 +84,7 @@ def __init__( dataset: Dataset, algo_backend: AccuracyControlAlgoBackend, evaluator: Evaluator, - num_processes: int = 1, + num_processes: int = 5, ranking_fn: Optional[Callable[[Any, Any], float]] = None, ): """ @@ -196,6 +196,7 @@ def rank_groups_of_quantizers( ranking_subset_indices = get_ranking_subset_indices_pot_version(scores, 
self._ranking_subset_size) nncf_logger.info("Calculating ranking score for groups of quantizers") + print(f"\n\nnum_processes: {self._num_processes}\n\n") with timer(): # Calculate ranking score for groups of quantizers. if self._num_processes > 1: @@ -220,6 +221,7 @@ def _sequential_calculation_ranking_score( groups_to_rank: List[GroupToRank], ranking_subset_indices: List[int], ): + print("\n\nIM RUNNING SEQUENTIAL RANKER\n\n") ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] for current_group in groups_to_rank: modified_model = revert_operations_to_floating_point_precision( @@ -239,6 +241,7 @@ def _multiprocessing_calculation_ranking_score( groups_to_rank: List[GroupToRank], ranking_subset_indices: List[int], ): + print("\n\nIM RUNNING MULTIPROCESSING RANKER\n\n") ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] prepared_model_queue = [] for idx, current_group in enumerate(groups_to_rank): From c8d2b446dbf869312a2ec15c0e6072d73ed6abfe Mon Sep 17 00:00:00 2001 From: p-wysocki Date: Fri, 8 Sep 2023 11:05:45 +0200 Subject: [PATCH 06/10] wip --- nncf/quantization/algorithms/accuracy_control/ranker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index c690342f5bc..65c728c7c8a 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -84,7 +84,7 @@ def __init__( dataset: Dataset, algo_backend: AccuracyControlAlgoBackend, evaluator: Evaluator, - num_processes: int = 5, + num_processes: int = 1, ranking_fn: Optional[Callable[[Any, Any], float]] = None, ): """ From babd750b3c1037847e18d21f3c2ec7540526a31d Mon Sep 17 00:00:00 2001 From: p-wysocki Date: Mon, 11 Sep 2023 13:43:24 +0200 Subject: [PATCH 07/10] cleanup --- .../accuracy_control/openvino_backend.py | 21 +++++-------- .../algorithms/accuracy_control/ranker.py | 30 +++++++++++++++++-- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py index fafb885ac58..ba189634bc4 100644 --- a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py +++ b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py @@ -10,6 +10,7 @@ # limitations under the License. 
import multiprocessing +import concurrent.futures from typing import Any, List, Optional import numpy as np @@ -33,15 +34,10 @@ from nncf.quantization.algorithms.accuracy_control.backend import AsyncPreparedModel -def compile_model(model: ov.Model, done_queue: multiprocessing.Queue) -> None: +def compile_model(model: ov.Model) -> None: ov_core = ov.Core() - #ov_core.set_property("CPU", {"COMPILATION_NUM_THREADS": 8}) - ov_core.set_property("CPU", {"INFERENCE_NUM_THREADS": 8}) - #print(ov_core.get_property("CPU", "COMPILATION_NUM_THREADS")) - #print(ov_core.get_property("CPU", "INFERENCE_NUM_THREADS")) compiled_model = ov_core.compile_model(model, "CPU") - model_stream = compiled_model.export_model() - done_queue.put(model_stream) + return compiled_model class OVAsyncPreparedModel(AsyncPreparedModel): @@ -119,13 +115,10 @@ def get_model_size(model: ov.Model) -> int: @staticmethod def prepare_for_inference(model: ov.Model) -> Any: - print("\n----\n----\IM RUNNING prepare_for_inference\n----\n----\n") return ov.compile_model(model) @staticmethod - def prepare_for_inference_async(model: ov.Model) -> Any: - print("\n----\n----\IM RUNNING prepare_for_inference_async\n----\n----\n") - queue = multiprocessing.Queue() - p = multiprocessing.Process(target=compile_model, args=(model, queue)) - p.start() - return OVAsyncPreparedModel(p, queue) + def prepare_for_inference_async(models: ov.Model) -> Any: + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + results = [i for i in executor.map(compile_model, models)] + return results diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index 65c728c7c8a..c3b95dca4ce 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -15,6 +15,7 @@ from typing import Any, Callable, List, Optional, TypeVar, Union import numpy as np +import time from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode @@ -200,7 +201,7 @@ def rank_groups_of_quantizers( with timer(): # Calculate ranking score for groups of quantizers. 
if self._num_processes > 1: - ranking_scores = self._multiprocessing_calculation_ranking_score( + ranking_scores = self._multithreading_calculation_ranking_score( quantized_model, quantized_model_graph, groups_to_rank, ranking_subset_indices ) @@ -221,7 +222,6 @@ def _sequential_calculation_ranking_score( groups_to_rank: List[GroupToRank], ranking_subset_indices: List[int], ): - print("\n\nIM RUNNING SEQUENTIAL RANKER\n\n") ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] for current_group in groups_to_rank: modified_model = revert_operations_to_floating_point_precision( @@ -241,7 +241,6 @@ def _multiprocessing_calculation_ranking_score( groups_to_rank: List[GroupToRank], ranking_subset_indices: List[int], ): - print("\n\nIM RUNNING MULTIPROCESSING RANKER\n\n") ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] prepared_model_queue = [] for idx, current_group in enumerate(groups_to_rank): @@ -263,6 +262,31 @@ def _multiprocessing_calculation_ranking_score( return ranking_scores + def _multithreading_calculation_ranking_score( + self, + quantized_model: TModel, + quantized_model_graph: NNCFGraph, + groups_to_rank: List[GroupToRank], + ranking_subset_indices: List[int], + ): + + ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i] + modified_models = [] + for current_group in groups_to_rank: + modified_model = revert_operations_to_floating_point_precision( + current_group.operations, current_group.quantizers, quantized_model, quantized_model_graph + ) + + modified_models.append(modified_model) + + results = self._algo_backend.prepare_for_inference_async(modified_models) + + for model in results: + ranking_score = self._calculate_ranking_score(model, ranking_subset_indices) + ranking_scores.append(float(ranking_score)) + + return ranking_scores + def _calculate_ranking_score(self, prepared_model: TPModel, ranking_subset_indices: List[int]) -> float: """ Calculates the ranking score for the current group of quantizers. From 968d2a32c193ec76d3cdaef19d28dc74731cc52c Mon Sep 17 00:00:00 2001 From: p-wysocki Date: Mon, 11 Sep 2023 13:45:45 +0200 Subject: [PATCH 08/10] Remove unnecessary import --- nncf/quantization/algorithms/accuracy_control/ranker.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index c3b95dca4ce..d41c0af39ef 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -15,7 +15,6 @@ from typing import Any, Callable, List, Optional, TypeVar, Union import numpy as np -import time from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode From 82302083e84b8ea3250f66fe67cd61b842c3d23d Mon Sep 17 00:00:00 2001 From: p-wysocki Date: Mon, 11 Sep 2023 13:56:44 +0200 Subject: [PATCH 09/10] cleanup --- nncf/common/utils/os.py | 1 - .../algorithms/accuracy_control/algorithm.py | 18 ------------------ 2 files changed, 19 deletions(-) diff --git a/nncf/common/utils/os.py b/nncf/common/utils/os.py index 61be7880899..fb6f146304b 100644 --- a/nncf/common/utils/os.py +++ b/nncf/common/utils/os.py @@ -8,7 +8,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import multiprocessing import sys from contextlib import contextmanager from pathlib import Path diff --git a/nncf/quantization/algorithms/accuracy_control/algorithm.py b/nncf/quantization/algorithms/accuracy_control/algorithm.py index 34b4f93cb6c..a04575d4fb8 100644 --- a/nncf/quantization/algorithms/accuracy_control/algorithm.py +++ b/nncf/quantization/algorithms/accuracy_control/algorithm.py @@ -134,23 +134,6 @@ def reverted_operations(self) -> List[NNCFNode]: return operations -@dataclass -class MetricResults: - """ - Results of metrics collection. - - :param metric_value: Aggregated metric value. - :param values_for_each_item: Metric values for each data item. - :param preperation_time: Time that it takes to prepare model for validation. - :param validation_time: Time that it takes to validate model. - """ - - metric_value: float - values_for_each_item: Union[None, List[float], List[List[TTensor]]] - preperation_time: float - validation_time: float - - class QuantizationAccuracyRestorer: """ Implementation of the accuracy-aware loop. @@ -179,7 +162,6 @@ def __init__( self.max_num_iterations = max_num_iterations self.max_drop = max_drop self.drop_type = drop_type - self.num_ranking_processes = num_ranking_processes if is_windows(): self.num_ranking_processes = 1 From 023267b37a67a2761549b1ee832abd9d2d90105a Mon Sep 17 00:00:00 2001 From: p-wysocki Date: Mon, 11 Sep 2023 15:04:05 +0200 Subject: [PATCH 10/10] Minor changes --- .../algorithms/accuracy_control/openvino_backend.py | 2 +- nncf/quantization/algorithms/accuracy_control/ranker.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py index 07d64c2f984..461a9701d7c 100644 --- a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py +++ b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py @@ -118,7 +118,7 @@ def prepare_for_inference(model: ov.Model) -> Any: return ov.compile_model(model) @staticmethod - def prepare_for_inference_async(models: ov.Model) -> Any: + def prepare_for_inference_async(models: ov.Model, max_workers: int=20) -> Any: with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: results = [i for i in executor.map(compile_model, models)] return results diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py index fed061deb25..476cca22b86 100644 --- a/nncf/quantization/algorithms/accuracy_control/ranker.py +++ b/nncf/quantization/algorithms/accuracy_control/ranker.py @@ -148,7 +148,6 @@ def rank_groups_of_quantizers( ) nncf_logger.info("Calculating ranking score for groups of quantizers") - print(f"\n\nnum_processes: {self._num_processes}\n\n") with timer(): # Calculate ranking score for groups of quantizers. if self._num_processes > 1:
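
The last three patches replace the multiprocessing-based asynchronous compilation with a thread pool: every model produced by reverting a group of quantizers to floating point is compiled through `compile_model` inside a `concurrent.futures.ThreadPoolExecutor`, and the resulting compiled models are then scored one by one. Below is a minimal, self-contained sketch of that pattern, not the exact NNCF code path: it assumes OpenVINO is installed, the model list and `max_workers` value are placeholders, and the comment about where the time is spent is the presumed rationale for preferring threads over processes here.

    import concurrent.futures
    from typing import List

    import openvino.runtime as ov


    def compile_model(model: ov.Model) -> ov.CompiledModel:
        # Same helper shape as in openvino_backend.py after PATCH 07: compile for CPU
        # and return the compiled model directly instead of exporting a model stream
        # through a multiprocessing queue.
        return ov.Core().compile_model(model, "CPU")


    def compile_models_in_parallel(models: List[ov.Model], max_workers: int = 4) -> List[ov.CompiledModel]:
        # Threads are presumably sufficient because compilation runs inside OpenVINO's
        # native runtime; this also avoids the export/import round-trip that the earlier
        # multiprocessing implementation required.
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(compile_model, models))

With a helper like this, `prepare_for_inference_async(models)` in the OpenVINO accuracy-control backend reduces to a single parallel-compile call, and `_multithreading_calculation_ranking_score` only has to iterate over the returned compiled models and compute one ranking score per group.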