Skip to content

Commit

Permalink
Move the ranker to a multi-threaded approach (openvinotoolkit#2134)
Browse files Browse the repository at this point in the history
### Changes

- Use multithreading instead of multiprocessing to calculate quantizer
ranking score
- Rename num_ranking_process to num_ranking_workers

### Reason for changes

- Support parallel calculations of quantizer ranking score for Windows
- Introducing more general name of parameter of number parallel workers.

### Related tickets

ref: 119274

### Tests

N/A
  • Loading branch information
alexsu52 authored Sep 25, 2023
1 parent b95e71c commit f43e933
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 93 deletions.
2 changes: 1 addition & 1 deletion nncf/openvino/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def native_quantize_with_accuracy_control_impl(
advanced_accuracy_restorer_parameters.max_num_iterations,
max_drop,
drop_type,
advanced_accuracy_restorer_parameters.num_ranking_processes,
advanced_accuracy_restorer_parameters.num_ranking_workers,
)
quantized_model = accuracy_restorer.apply(
model,
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/advanced_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,9 @@ class AdvancedAccuracyRestorerParameters:
:param ranking_subset_size: Size of a subset that is used to rank layers by their
contribution to the accuracy drop.
:type ranking_subset_size: Optional[int]
:param num_ranking_processes: The number of parallel processes that are used to rank
:param num_ranking_workers: The number of parallel workers that are used to rank
quantization operations.
:type num_ranking_processes: Optional[int]
:type num_ranking_workers: Optional[int]
:param intermediate_model_dir: Path to the folder where the model, which was fully
quantized with initial parameters, should be saved.
:type intermediate_model_dir: Optional[str]
Expand All @@ -201,7 +201,7 @@ class AdvancedAccuracyRestorerParameters:
max_num_iterations: int = sys.maxsize
tune_hyperparams: bool = False
ranking_subset_size: Optional[int] = None
num_ranking_processes: Optional[int] = None
num_ranking_workers: Optional[int] = None
intermediate_model_dir: Optional[str] = None


Expand Down
42 changes: 16 additions & 26 deletions nncf/quantization/algorithms/accuracy_control/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from nncf.common.utils.backend import get_backend
from nncf.common.utils.os import get_available_cpu_count
from nncf.common.utils.os import get_available_memory_amount
from nncf.common.utils.os import is_windows
from nncf.data.dataset import Dataset
from nncf.parameters import DropType
from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend
Expand All @@ -34,7 +33,7 @@
TTensor = TypeVar("TTensor")
PREPARATION_MODEL_THRESHOLD = 1
OVERHEAD_COEFFICIENT = 2
MEMORY_INCREASE_COEFFICIENT = 4
MEMORY_INCREASE_COEFFICIENT = 2


def get_algo_backend(backend: BackendType) -> AccuracyControlAlgoBackend:
Expand Down Expand Up @@ -145,7 +144,7 @@ def __init__(
max_num_iterations: int = sys.maxsize,
max_drop: float = 0.01,
drop_type: DropType = DropType.ABSOLUTE,
num_ranking_processes: Optional[int] = None,
num_ranking_workers: Optional[int] = None,
):
"""
:param ranking_subset_size: The number of data items that will be selected from
Expand All @@ -155,23 +154,14 @@ def __init__(
:param drop_type: The accuracy drop type, which determines how the maximum
accuracy drop between the original model and the compressed model is
calculated.
:param num_ranking_processes: The number of parallel processes that are used to rank
:param num_ranking_workers: The number of parallel workers that are used to rank
quantization operations.
"""
self.ranking_subset_size = ranking_subset_size
self.max_num_iterations = max_num_iterations
self.max_drop = max_drop
self.drop_type = drop_type

if is_windows():
self.num_ranking_processes = 1
if num_ranking_processes is not None and num_ranking_processes > 1:
nncf_logger.info(
"Number of parallel processes to rank quantized operations > 1 is not supported on Windows OS. "
"num_ranking_processes = 1 will be used."
)
else:
self.num_ranking_processes = num_ranking_processes
self.num_ranking_workers = num_ranking_workers

def apply(
self,
Expand Down Expand Up @@ -272,19 +262,19 @@ def _apply(
nncf_logger.info(f"Total number of quantized operations in the model: {report.num_quantized_operations}")

# Calculate number of parallel processes for Ranker
num_ranking_processes = self.num_ranking_processes
if num_ranking_processes is None:
num_ranking_workers = self.num_ranking_workers
if num_ranking_workers is None:
model_size = algo_backend.get_model_size(quantized_model)
num_ranking_processes = self._calculate_number_ranker_parallel_proc(
num_ranking_workers = self._calculate_number_ranker_workers(
model_size,
quantized_metric_results.preparation_time,
quantized_metric_results.validation_time,
validation_dataset_size,
)

nncf_logger.info(f"Number of parallel processes to rank quantized operations: {num_ranking_processes}")
nncf_logger.info(f"Number of parallel workers to rank quantized operations: {num_ranking_workers}")

ranker = Ranker(self.ranking_subset_size, validation_dataset, algo_backend, evaluator, num_ranking_processes)
ranker = Ranker(self.ranking_subset_size, validation_dataset, algo_backend, evaluator, num_ranking_workers)
groups_to_rank = ranker.find_groups_of_quantizers_to_rank(quantized_model_graph)
ranked_groups = ranker.rank_groups_of_quantizers(
groups_to_rank,
Expand Down Expand Up @@ -386,40 +376,40 @@ def _apply(

return current_model

def _calculate_number_ranker_parallel_proc(
def _calculate_number_ranker_workers(
self,
model_size: int,
preparation_time: float,
validation_time: float,
validation_dataset_size: int,
) -> int:
"""
Calculate the number of parallel ranker processes
Calculate the number of parallel ranker workers
:param model_size: Target model size.
:param preparation_time: The time it takes to prepare the model.
:param validation_time: The time it takes to validate the model.
:param validation_dataset_size: Validation dataset size.
:return: The number of parallel ranker processes
:return: The number of parallel ranker workers
"""
if preparation_time < PREPARATION_MODEL_THRESHOLD:
return 1

# Calculate the number of parallel processes needed to override model preparation and
# metric calculation on the ranking subset
ranking_time = validation_time * self.ranking_subset_size / validation_dataset_size
n_proc = max(round((preparation_time / ranking_time + 1) * OVERHEAD_COEFFICIENT), 2)
n_workers = max(round((preparation_time / ranking_time + 1) * OVERHEAD_COEFFICIENT), 2)

# Apply limitation by number of CPU cores
n_cores = get_available_cpu_count(logical=True)
n_proc = max(min(n_proc, n_cores // 2), 1)
n_workers = max(min(n_workers, n_cores // 2), 1)

# Apply limitation by memory
ram = get_available_memory_amount()
n_copies = ram // (model_size * MEMORY_INCREASE_COEFFICIENT)
n_proc = max(min(n_proc, n_copies - 1), 1)
n_workers = max(min(n_workers, n_copies - 1), 1)

return n_proc
return n_workers

@staticmethod
def _collect_original_biases_and_weights(
Expand Down
23 changes: 0 additions & 23 deletions nncf/quantization/algorithms/accuracy_control/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,6 @@
TPModel = TypeVar("TPModel")


class AsyncPreparedModel(ABC):
@abstractmethod
def get(self, timeout) -> TPModel:
"""
Returns the prepared model for inference when it arrives. If timeout is not None and
the result does not arrive within timeout seconds then TimeoutError is raised. If
the remote call raised an exception then that exception will be reraised by get().
:param timeout: timeout
:return: A prepared model for inference
"""


class AccuracyControlAlgoBackend(ABC):
# Metatypes

Expand Down Expand Up @@ -162,13 +149,3 @@ def prepare_for_inference(model: TModel) -> TPModel:
:param model: A model that should be prepared.
:return: Prepared model for inference.
"""

@staticmethod
@abstractmethod
def prepare_for_inference_async(model: TModel) -> AsyncPreparedModel:
"""
Prepares model for inference asynchronously.
:param model: A model that should be prepared.
:return: AsyncPreparedModel opbject.
"""
32 changes: 2 additions & 30 deletions nncf/quantization/algorithms/accuracy_control/openvino_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing
from typing import Any, List, Optional
from typing import List, Optional

import numpy as np
import openvino.runtime as ov
Expand All @@ -30,26 +29,6 @@
from nncf.openvino.graph.node_utils import get_weight_value
from nncf.openvino.graph.node_utils import is_node_with_bias
from nncf.quantization.algorithms.accuracy_control.backend import AccuracyControlAlgoBackend
from nncf.quantization.algorithms.accuracy_control.backend import AsyncPreparedModel


def compile_model(model: ov.Model, done_queue: multiprocessing.Queue) -> None:
compiled_model = ov.Core().compile_model(model, "CPU")
model_stream = compiled_model.export_model()
done_queue.put(model_stream)


class OVAsyncPreparedModel(AsyncPreparedModel):
def __init__(self, proc: multiprocessing.Process, done_queue: multiprocessing.Queue):
self.proc = proc
self.done_queue = done_queue

def get(self, timeout=None) -> ov.CompiledModel:
try:
model_stream = self.done_queue.get(timeout=timeout)
except multiprocessing.TimeoutError as ex:
raise TimeoutError() from ex
return ov.Core().import_model(model_stream, "CPU")


class OVAccuracyControlAlgoBackend(AccuracyControlAlgoBackend):
Expand Down Expand Up @@ -113,12 +92,5 @@ def get_model_size(model: ov.Model) -> int:
# Preparation of model

@staticmethod
def prepare_for_inference(model: ov.Model) -> Any:
def prepare_for_inference(model: ov.Model) -> ov.CompiledModel:
return ov.compile_model(model)

@staticmethod
def prepare_for_inference_async(model: ov.Model) -> Any:
queue = multiprocessing.Queue()
p = multiprocessing.Process(target=compile_model, args=(model, queue))
p.start()
return OVAsyncPreparedModel(p, queue)
24 changes: 14 additions & 10 deletions nncf/quantization/algorithms/accuracy_control/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# limitations under the License.

import operator
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypeVar, Union
Expand Down Expand Up @@ -59,7 +60,7 @@ def __init__(
dataset: Dataset,
algo_backend: AccuracyControlAlgoBackend,
evaluator: Evaluator,
num_processes: int = 1,
num_workers: int = 1,
ranking_fn: Optional[Callable[[Any, Any], float]] = None,
):
"""
Expand All @@ -70,6 +71,8 @@ def __init__(
:param dataset: Dataset for the ranking process.
:param algo_backend: The `AccuracyControlAlgoBackend` algo backend.
:param evaluator: Evaluator to validate model.
:param num_workers: The number of parallel workers that are used to rank quantization
operations.
:param ranking_fn: a function that compares values returned by
`Evaluator.collect_values_for_each_item()` method for initial and quantized model.
"""
Expand All @@ -78,7 +81,7 @@ def __init__(
self._algo_backend = algo_backend
self._evaluator = evaluator
self._ranking_fn = ranking_fn
self._num_processes = num_processes
self._num_workers = num_workers

def find_groups_of_quantizers_to_rank(self, quantized_model_graph: NNCFGraph) -> List[GroupToRank]:
"""
Expand Down Expand Up @@ -150,8 +153,8 @@ def rank_groups_of_quantizers(
nncf_logger.info("Calculating ranking score for groups of quantizers")
with timer():
# Calculate ranking score for groups of quantizers.
if self._num_processes > 1:
ranking_scores = self._multiprocessing_calculation_ranking_score(
if self._num_workers > 1:
ranking_scores = self._multithreading_calculation_ranking_score(
quantized_model,
quantized_model_graph,
groups_to_rank,
Expand Down Expand Up @@ -195,7 +198,7 @@ def _sequential_calculation_ranking_score(

return ranking_scores

def _multiprocessing_calculation_ranking_score(
def _multithreading_calculation_ranking_score(
self,
quantized_model: TModel,
quantized_model_graph: NNCFGraph,
Expand All @@ -205,22 +208,23 @@ def _multiprocessing_calculation_ranking_score(
):
ranking_scores = [] # ranking_scores[i] is the ranking score for groups_to_rank[i]
prepared_model_queue = []
executor = ThreadPoolExecutor(max_workers=self._num_workers)
for idx, current_group in enumerate(groups_to_rank):
modified_model = revert_operations_to_floating_point_precision(
current_group.operations, current_group.quantizers, quantized_model, quantized_model_graph
)

prepared_model_queue.append(self._algo_backend.prepare_for_inference_async(modified_model))
prepared_model_queue.append(executor.submit(self._algo_backend.prepare_for_inference, modified_model))

if idx >= (self._num_processes - 1):
prepared_model = prepared_model_queue.pop(0).get()
if idx >= (self._num_workers - 1):
prepared_model = prepared_model_queue.pop(0).result()
ranking_score = self._calculate_ranking_score(
prepared_model, ranking_subset_indices, reference_values_for_each_item
)
ranking_scores.append(float(ranking_score))

for _ in range(self._num_processes - 1):
prepared_model = prepared_model_queue.pop(0).get()
for _ in range(self._num_workers - 1):
prepared_model = prepared_model_queue.pop(0).result()
ranking_score = self._calculate_ranking_score(
prepared_model, ranking_subset_indices, reference_values_for_each_item
)
Expand Down

0 comments on commit f43e933

Please sign in to comment.