diff --git a/nncf/openvino/quantization/quantize_model.py b/nncf/openvino/quantization/quantize_model.py
index cefa084cd7d..2faac879d86 100644
--- a/nncf/openvino/quantization/quantize_model.py
+++ b/nncf/openvino/quantization/quantize_model.py
@@ -244,7 +244,7 @@ def native_quantize_with_accuracy_control_impl(
         advanced_accuracy_restorer_parameters.max_num_iterations,
         max_drop,
         drop_type,
-        advanced_accuracy_restorer_parameters.num_ranking_processes,
+        advanced_accuracy_restorer_parameters.num_ranking_workers,
     )
     quantized_model = accuracy_restorer.apply(
         model,
diff --git a/nncf/quantization/advanced_parameters.py b/nncf/quantization/advanced_parameters.py
index a8ed96c8ff8..d0409fab8ea 100644
--- a/nncf/quantization/advanced_parameters.py
+++ b/nncf/quantization/advanced_parameters.py
@@ -190,9 +190,9 @@ class AdvancedAccuracyRestorerParameters:
     :param ranking_subset_size: Size of a subset that is used to rank layers by their
         contribution to the accuracy drop.
     :type ranking_subset_size: Optional[int]
-    :param num_ranking_processes: The number of parallel processes that are used to rank
+    :param num_ranking_workers: The number of parallel workers that are used to rank
         quantization operations.
-    :type num_ranking_processes: Optional[int]
+    :type num_ranking_workers: Optional[int]
     :param intermediate_model_dir: Path to the folder where the model, which was fully
         quantized with initial parameters, should be saved.
     :type intermediate_model_dir: Optional[str]
@@ -201,7 +201,7 @@
     max_num_iterations: int = sys.maxsize
     tune_hyperparams: bool = False
     ranking_subset_size: Optional[int] = None
-    num_ranking_processes: Optional[int] = None
+    num_ranking_workers: Optional[int] = None
     intermediate_model_dir: Optional[str] = None
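For context, the renamed field reaches the algorithm through the public accuracy-aware quantization entry point. A minimal sketch of the caller-side wiring; the model, datasets, and validation function are hypothetical placeholders, not part of this change:

```python
import nncf
from nncf.quantization.advanced_parameters import AdvancedAccuracyRestorerParameters

# model, calibration_dataset, validation_dataset, and validate_fn are
# hypothetical stand-ins; only the parameter wiring is the point here.
quantized_model = nncf.quantize_with_accuracy_control(
    model,
    calibration_dataset,
    validation_dataset,
    validation_fn=validate_fn,
    max_drop=0.01,
    advanced_accuracy_restorer_parameters=AdvancedAccuracyRestorerParameters(
        num_ranking_workers=4,  # previously num_ranking_processes
    ),
)
```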
""" self.ranking_subset_size = ranking_subset_size self.max_num_iterations = max_num_iterations self.max_drop = max_drop self.drop_type = drop_type - - if is_windows(): - self.num_ranking_processes = 1 - if num_ranking_processes is not None and num_ranking_processes > 1: - nncf_logger.info( - "Number of parallel processes to rank quantized operations > 1 is not supported on Windows OS. " - "num_ranking_processes = 1 will be used." - ) - else: - self.num_ranking_processes = num_ranking_processes + self.num_ranking_workers = num_ranking_workers def apply( self, @@ -272,19 +262,19 @@ def _apply( nncf_logger.info(f"Total number of quantized operations in the model: {report.num_quantized_operations}") # Calculate number of parallel processes for Ranker - num_ranking_processes = self.num_ranking_processes - if num_ranking_processes is None: + num_ranking_workers = self.num_ranking_workers + if num_ranking_workers is None: model_size = algo_backend.get_model_size(quantized_model) - num_ranking_processes = self._calculate_number_ranker_parallel_proc( + num_ranking_workers = self._calculate_number_ranker_workers( model_size, quantized_metric_results.preparation_time, quantized_metric_results.validation_time, validation_dataset_size, ) - nncf_logger.info(f"Number of parallel processes to rank quantized operations: {num_ranking_processes}") + nncf_logger.info(f"Number of parallel workers to rank quantized operations: {num_ranking_workers}") - ranker = Ranker(self.ranking_subset_size, validation_dataset, algo_backend, evaluator, num_ranking_processes) + ranker = Ranker(self.ranking_subset_size, validation_dataset, algo_backend, evaluator, num_ranking_workers) groups_to_rank = ranker.find_groups_of_quantizers_to_rank(quantized_model_graph) ranked_groups = ranker.rank_groups_of_quantizers( groups_to_rank, @@ -386,7 +376,7 @@ def _apply( return current_model - def _calculate_number_ranker_parallel_proc( + def _calculate_number_ranker_workers( self, model_size: int, preparation_time: float, @@ -394,13 +384,13 @@ def _calculate_number_ranker_parallel_proc( validation_dataset_size: int, ) -> int: """ - Calculate the number of parallel ranker processes + Calculate the number of parallel ranker workers :param model_size: Target model size. :param preparation_time: The time it takes to prepare the model. :param validation_time: The time it takes to validate the model. :param validation_dataset_size: Validation dataset size. 
diff --git a/nncf/quantization/algorithms/accuracy_control/backend.py b/nncf/quantization/algorithms/accuracy_control/backend.py
index ef2ab709670..2133673eb9b 100644
--- a/nncf/quantization/algorithms/accuracy_control/backend.py
+++ b/nncf/quantization/algorithms/accuracy_control/backend.py
@@ -21,19 +21,6 @@
 TPModel = TypeVar("TPModel")


-class AsyncPreparedModel(ABC):
-    @abstractmethod
-    def get(self, timeout) -> TPModel:
-        """
-        Returns the prepared model for inference when it arrives. If timeout is not None and
-        the result does not arrive within timeout seconds then TimeoutError is raised. If
-        the remote call raised an exception then that exception will be reraised by get().
-
-        :param timeout: timeout
-        :return: A prepared model for inference
-        """
-
-
 class AccuracyControlAlgoBackend(ABC):

     # Metatypes
@@ -162,13 +149,3 @@ def prepare_for_inference(model: TModel) -> TPModel:
         :param model: A model that should be prepared.
         :return: Prepared model for inference.
         """
-
-    @staticmethod
-    @abstractmethod
-    def prepare_for_inference_async(model: TModel) -> AsyncPreparedModel:
-        """
-        Prepares model for inference asynchronously.
-
-        :param model: A model that should be prepared.
-        :return: AsyncPreparedModel opbject.
-        """
diff --git a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py
index c5c3190d4b4..c8344f4fec7 100644
--- a/nncf/quantization/algorithms/accuracy_control/openvino_backend.py
+++ b/nncf/quantization/algorithms/accuracy_control/openvino_backend.py
@@ -9,8 +9,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import multiprocessing
-from typing import Any, List, Optional
+from typing import List, Optional

 import numpy as np
 import openvino.runtime as ov
@@ -30,26 +29,6 @@
 from nncf.openvino.graph.node_utils import get_weight_value
 from nncf.openvino.graph.node_utils import is_node_with_bias
 from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend
-from nncf.quantization.algorithms.accuracy_control.backend import AsyncPreparedModel
-
-
-def compile_model(model: ov.Model, done_queue: multiprocessing.Queue) -> None:
-    compiled_model = ov.Core().compile_model(model, "CPU")
-    model_stream = compiled_model.export_model()
-    done_queue.put(model_stream)
-
-
-class OVAsyncPreparedModel(AsyncPreparedModel):
-    def __init__(self, proc: multiprocessing.Process, done_queue: multiprocessing.Queue):
-        self.proc = proc
-        self.done_queue = done_queue
-
-    def get(self, timeout=None) -> ov.CompiledModel:
-        try:
-            model_stream = self.done_queue.get(timeout=timeout)
-        except multiprocessing.TimeoutError as ex:
-            raise TimeoutError() from ex
-        return ov.Core().import_model(model_stream, "CPU")


 class OVAccuracyControlAlgoBackend(AccuracyControlAlgoBackend):
@@ -113,12 +92,5 @@ def get_model_size(model: ov.Model) -> int:

     # Preparation of model
     @staticmethod
-    def prepare_for_inference(model: ov.Model) -> Any:
+    def prepare_for_inference(model: ov.Model) -> ov.CompiledModel:
         return ov.compile_model(model)
-
-    @staticmethod
-    def prepare_for_inference_async(model: ov.Model) -> Any:
-        queue = multiprocessing.Queue()
-        p = multiprocessing.Process(target=compile_model, args=(model, queue))
-        p.start()
-        return OVAsyncPreparedModel(p, queue)
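With the process-based `OVAsyncPreparedModel` machinery removed, asynchronous preparation reduces to submitting the remaining synchronous `prepare_for_inference` to a thread pool: compilation happens in native OpenVINO code, so no export/import round-trip through a `multiprocessing.Queue` is needed. A minimal sketch of that pattern; the helper name is ours, and how much compilation actually overlaps depends on the OpenVINO runtime releasing the GIL during `compile_model`:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List

import openvino.runtime as ov


def compile_models_concurrently(models: List[ov.Model], num_workers: int = 2) -> List[ov.CompiledModel]:
    """Compile several models, overlapping the native compilation work."""
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Futures resolve directly to ov.CompiledModel objects; no serialization
        # is required because threads share the parent process's memory.
        futures = [executor.submit(ov.compile_model, model) for model in models]
        return [future.result() for future in futures]
```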
diff --git a/nncf/quantization/algorithms/accuracy_control/ranker.py b/nncf/quantization/algorithms/accuracy_control/ranker.py
index 97ed8a9d91e..58e78807685 100644
--- a/nncf/quantization/algorithms/accuracy_control/ranker.py
+++ b/nncf/quantization/algorithms/accuracy_control/ranker.py
@@ -10,6 +10,7 @@
 # limitations under the License.

 import operator
+from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Callable, List, Optional, TypeVar, Union
@@ -59,7 +60,7 @@ def __init__(
         dataset: Dataset,
         algo_backend: AccuracyControlAlgoBackend,
         evaluator: Evaluator,
-        num_processes: int = 1,
+        num_workers: int = 1,
         ranking_fn: Optional[Callable[[Any, Any], float]] = None,
     ):
         """
@@ -70,6 +71,8 @@ def __init__(
         :param dataset: Dataset for the ranking process.
         :param algo_backend: The `AccuracyControlAlgoBackend` algo backend.
         :param evaluator: Evaluator to validate model.
+        :param num_workers: The number of parallel workers that are used to rank quantization
+            operations.
         :param ranking_fn: a function that compares values returned by
             `Evaluator.collect_values_for_each_item()` method for initial and quantized model.
         """
@@ -78,7 +81,7 @@ def __init__(
         self._algo_backend = algo_backend
         self._evaluator = evaluator
         self._ranking_fn = ranking_fn
-        self._num_processes = num_processes
+        self._num_workers = num_workers

     def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) -> List[GroupToRank]:
         """
@@ -150,8 +153,8 @@ def rank_groups_of_quantizers(
         nncf_logger.info("Calculating ranking score for groups of quantizers")
         with timer():
             # Calculate ranking score for groups of quantizers.
-            if self._num_processes > 1:
-                ranking_scores = self._multiprocessing_calculation_ranking_score(
+            if self._num_workers > 1:
+                ranking_scores = self._multithreading_calculation_ranking_score(
                     quantized_model,
                     quantized_model_graph,
                     groups_to_rank,
@@ -195,7 +198,7 @@ def _sequential_calculation_ranking_score(

         return ranking_scores

-    def _multiprocessing_calculation_ranking_score(
+    def _multithreading_calculation_ranking_score(
         self,
         quantized_model: TModel,
         quantized_model_graph: NNCFGraph,
@@ -205,22 +208,23 @@
     ):
         ranking_scores = []  # ranking_scores[i] is the ranking score for groups_to_rank[i]
         prepared_model_queue = []
+        executor = ThreadPoolExecutor(max_workers=self._num_workers)
         for idx, current_group in enumerate(groups_to_rank):
             modified_model = revert_operations_to_floating_point_precision(
                 current_group.operations, current_group.quantizers, quantized_model, quantized_model_graph
             )
-            prepared_model_queue.append(self._algo_backend.prepare_for_inference_async(modified_model))
+            prepared_model_queue.append(executor.submit(self._algo_backend.prepare_for_inference, modified_model))

-            if idx >= (self._num_processes - 1):
-                prepared_model = prepared_model_queue.pop(0).get()
+            if idx >= (self._num_workers - 1):
+                prepared_model = prepared_model_queue.pop(0).result()
                 ranking_score = self._calculate_ranking_score(
                     prepared_model, ranking_subset_indices, reference_values_for_each_item
                 )
                 ranking_scores.append(float(ranking_score))

-        for _ in range(self._num_processes - 1):
-            prepared_model = prepared_model_queue.pop(0).get()
+        for _ in range(self._num_workers - 1):
+            prepared_model = prepared_model_queue.pop(0).result()
             ranking_score = self._calculate_ranking_score(
                 prepared_model, ranking_subset_indices, reference_values_for_each_item
             )
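The rewritten loop in `_multithreading_calculation_ranking_score` is a bounded pipeline: keep at most `num_workers` preparations in flight, consume results in submission order, then drain whatever is left. The same pattern reduced to its essentials, where `prepare` and `consume` are hypothetical stand-ins for `prepare_for_inference` and `_calculate_ranking_score`; unlike the patch, this sketch shuts the executor down with a `with` block, whereas the patch leaves it to be reclaimed together with the `Ranker`:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, List


def pipelined_map(prepare: Callable, consume: Callable, items: Iterable, num_workers: int) -> List:
    results = []
    in_flight = []  # oldest future first, at most num_workers entries
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        for idx, item in enumerate(items):
            in_flight.append(executor.submit(prepare, item))
            # Once the pipeline is full, consume the oldest result while
            # later preparations keep running on the worker threads.
            if idx >= num_workers - 1:
                results.append(consume(in_flight.pop(0).result()))
        # Drain the preparations still in flight.
        while in_flight:
            results.append(consume(in_flight.pop(0).result()))
    return results


# pipelined_map(lambda x: x * x, lambda y: y + 1, range(5), num_workers=2)
# -> [1, 2, 5, 10, 17]
```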