
Update calibrate tool (#2174)
### Changes

- Add support for models that don't have per-sample metrics.

### Reason for changes

Some models are validated with metrics that do not provide per-sample values. The calibrate tool raised an error for these models.

### Related tickets

N/A

### Tests

andrey-churkin authored Oct 6, 2023
1 parent 2c417f5 commit 403597c
Showing 1 changed file with 30 additions and 3 deletions.
tests/openvino/tools/calibrate.py (30 additions, 3 deletions)
@@ -125,22 +125,43 @@ class ACValidationFunction:
        "ndcg": "sigmoid_recom_loss",
    }

    SPECIAL_METRICS = [
        "cmc",
        "reid_map",
        "pairwise_accuracy_subsets",
        "pairwise_accuracy",
        "normalized_embedding_accuracy",
        "face_recognition_tafa_pair_metric",
        "localization_recall",
        "coco_orig_keypoints_precision",
        "coco_orig_segm_precision",
        "spearman_correlation_coef",
        "pearson_correlation_coef",
    ]

    def __init__(
        self, model_evaluator: ModelEvaluator, metric_name: str, metric_type: str, requests_number: Optional[int] = None
    ):
"""
:param model_evaluator: Model Evaluator.
:param metric_name: Name of a metric.
:param metric_type: Type of a metric.
:param requests_number: A number of infer requests. If it is `None`,
the count will be selected automatically.
"""
self._model_evaluator = model_evaluator
self._metric_name = metric_name
self._metric_type = metric_type
self._persample_metric_name = self.METRIC_TO_PERSAMPLE_METRIC.get(self._metric_name, self._metric_name)
registered_metrics = model_evaluator.get_metrics_attributes()
if self._persample_metric_name not in registered_metrics:
self._model_evaluator.register_metric(self._persample_metric_name)
self._requests_number = requests_number
self._values_for_each_item = []

self._collect_outputs = self._metric_type in self.SPECIAL_METRICS

    def __call__(self, compiled_model: ov.CompiledModel, indices: Optional[Iterable[int]] = None) -> float:
        """
        Calculates metrics for the provided model.
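
For context, a minimal sketch of the gating set up in `__init__` above (the `SPECIAL_METRICS` copy here is abbreviated, and the metric names are illustrative): it is the metric *type*, not its display name, that decides whether raw outputs are collected.

```python
# Illustrative only: the metric type drives the membership test, so a
# renamed metric (name="top1", type="cmc") still takes the raw-output path.
SPECIAL_METRICS = ["cmc", "reid_map"]  # abbreviated copy for illustration

metric_name, metric_type = "top1", "cmc"
collect_outputs = metric_type in SPECIAL_METRICS
print(collect_outputs)  # True
```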
@@ -203,6 +224,11 @@ def _output_callback(self, raw_predictions, **kwargs):
            return

        for sample_id, results in metrics_result.items():
            if self._collect_outputs:
                # Special metrics have no per-sample values; keep the first
                # raw model output for this sample instead.
                output = list(raw_predictions.values())[0]
                self._values_for_each_item.append({"sample_id": sample_id, "metric_value": output})
                continue

            for metric_result in results:
                if metric_result.metric_name != self._persample_metric_name:
                    continue
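
A hedged illustration of what `_values_for_each_item` ends up holding on the new path (the output name and values are fabricated; in a real run `raw_predictions` maps model output names to result tensors):

```python
import numpy as np

# Fabricated stand-in for one inference result: output name -> tensor.
raw_predictions = {"output_0": np.array([0.1, 0.9])}

values_for_each_item = []
for sample_id in (0,):
    output = list(raw_predictions.values())[0]  # first model output
    values_for_each_item.append({"sample_id": sample_id, "metric_value": output})

print(values_for_each_item)
# [{'sample_id': 0, 'metric_value': array([0.1, 0.9])}]
```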
@@ -940,10 +966,11 @@ def quantize_model_with_accuracy_control(
    )
    model_evaluator.load_network([{"model": ov_model}])

    metric_type = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0]["type"]
    metric_name = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0].get("name", None)
    if metric_name is None:
        metric_name = metric_type
    validation_fn = ACValidationFunction(model_evaluator, metric_name, metric_type)
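
A hypothetical sketch of the nested accuracy-checker config shape these lookups assume (real configs are YAML files parsed into this structure; `name` is optional, which is why the code falls back to `type`):

```python
# Hypothetical config fragment; the keys mirror the lookups above.
accuracy_checker_config = {
    "models": [
        {"datasets": [{"metrics": [{"type": "accuracy", "name": "top1"}]}]},
    ]
}

metric = accuracy_checker_config["models"][0]["datasets"][0]["metrics"][0]
metric_type = metric["type"]
metric_name = metric.get("name", metric_type)  # condensed form of the fallback above
print(metric_name, metric_type)  # top1 accuracy
```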
    name_to_quantization_impl_map = {
        "pot": pot_quantize_with_native_accuracy_control,
