Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][Conformance] Ultralytics yolov8n and yolo11n #3116

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,18 @@ torchvision/swin_v2_s_backend_OV:
metric_value: 0.83638
torchvision/swin_v2_s_backend_FX_TORCH:
metric_value: 0.8360
ultralytics/yolov8n_backend_FP32:
metric_value: 0.6056
ultralytics/yolov8n_backend_FX_TORCH:
metric_value: 0.61417
ultralytics/yolov8n_backend_OV:
metric_value: 0.6188
ultralytics/yolo11n_backend_FP32:
metric_value: 0.6770
ultralytics/yolo11n_backend_FX_TORCH:
metric_value: 0.6735
ultralytics/yolo11n_backend_OV:
metric_value: 0.6752
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.7275
timm/crossvit_9_240_backend_FP32:
Expand Down
80 changes: 80 additions & 0 deletions tests/post_training/model_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from tests.post_training.pipelines.image_classification_torchvision import ImageClassificationTorchvision
from tests.post_training.pipelines.lm_weight_compression import LMWeightCompression
from tests.post_training.pipelines.masked_language_modeling import MaskedLanguageModelingHF
from tests.post_training.pipelines.ultralytics_detection import UltralyticsDetection

QUANTIZATION_MODELS = [
# HF models
Expand Down Expand Up @@ -123,6 +124,85 @@
"backends": [BackendType.FX_TORCH, BackendType.OV],
"batch_size": 1,
},
# Ultralytics models
{
"reported_name": "ultralytics/yolov8n",
"model_id": "yolov8n",
"pipeline_cls": UltralyticsDetection,
"compression_params": {
"preset": nncf.QuantizationPreset.MIXED,
"ignored_scope": nncf.IgnoredScope(
types=["mul", "sub", "sigmoid", "__getitem__"],
subgraphs=[
nncf.Subgraph(
inputs=["cat_13", "cat_14", "cat_15"],
outputs=["output"],
)
],
),
},
"backends": [BackendType.FX_TORCH],
"batch_size": 1,
},
{
"reported_name": "ultralytics/yolov8n",
"model_id": "yolov8n",
"pipeline_cls": UltralyticsDetection,
"compression_params": {
"preset": QuantizationPreset.MIXED,
"ignored_scope": nncf.IgnoredScope(
types=["Multiply", "Subtract", "Sigmoid"],
subgraphs=[
nncf.Subgraph(
inputs=["/model.22/Concat", "/model.22/Concat_1", "/model.22/Concat_2"],
outputs=["output0/sink_port_0"],
)
],
),
},
"backends": [BackendType.OV],
"batch_size": 1,
},
{
"reported_name": "ultralytics/yolo11n",
"model_id": "yolo11n",
"pipeline_cls": UltralyticsDetection,
"compression_params": {
"model_type": nncf.ModelType.TRANSFORMER,
"preset": QuantizationPreset.MIXED,
"ignored_scope": nncf.IgnoredScope(
types=["mul", "sub", "sigmoid", "__getitem__"],
subgraphs=[
nncf.Subgraph(
inputs=["cat_13", "cat_14", "cat_15"],
outputs=["output"],
)
],
),
},
"backends": [BackendType.FX_TORCH],
"batch_size": 1,
},
{
"reported_name": "ultralytics/yolo11n",
"model_id": "yolo11n",
"pipeline_cls": UltralyticsDetection,
"compression_params": {
"model_type": nncf.ModelType.TRANSFORMER,
"preset": QuantizationPreset.MIXED,
"ignored_scope": nncf.IgnoredScope(
types=["Multiply", "Subtract", "Sigmoid"],
subgraphs=[
nncf.Subgraph(
inputs=["/model.23/Concat", "/model.23/Concat_1", "/model.23/Concat_2"],
outputs=["output0/sink_port_0"],
)
],
),
},
"backends": [BackendType.OV],
"batch_size": 1,
},
# Timm models
{
"reported_name": "timm/crossvit_9_240",
Expand Down
128 changes: 128 additions & 0 deletions tests/post_training/pipelines/ultralytics_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Dict, Tuple

import openvino as ov
import torch
from ultralytics import YOLO
from ultralytics.data.utils import check_det_dataset
from ultralytics.engine.validator import BaseValidator as Validator
from ultralytics.utils.torch_utils import de_parallel

import nncf
from nncf.torch import disable_patching
from tests.post_training.pipelines.base import OV_BACKENDS
from tests.post_training.pipelines.base import BackendType
from tests.post_training.pipelines.base import PTQTestPipeline


class UltralyticsDetection(PTQTestPipeline):
"""Pipeline for Yolo detection models from the Ultralytics repository"""

def prepare_model(self) -> None:
if self.batch_size != 1:
raise RuntimeError("Batch size > 1 is not supported")

model_path = f"{self.fp32_model_dir}/{self.model_id}"
yolo = YOLO(f"{model_path}.pt")
self.validator, self.data_loader = self._prepare_validation(yolo, "coco128.yaml")
self.dummy_tensor = torch.ones((1, 3, 640, 640))

if self.backend in OV_BACKENDS + [BackendType.FP32]:
onnx_model_path = Path(f"{model_path}.onnx")
ir_model_path = self.fp32_model_dir / "model_fp32.xml"
yolo.export(format="onnx", dynamic=True, half=False)
ov.save_model(ov.convert_model(onnx_model_path), ir_model_path)
self.model = ov.Core().read_model(ir_model_path)

if self.backend == BackendType.FX_TORCH:
pt_model = yolo.model
# Run mode one time to initialize all
# internal variables
pt_model(self.dummy_tensor)

with torch.no_grad():
with disable_patching():
self.model = torch.export.export(pt_model, args=(self.dummy_tensor,), strict=False).module()

def prepare_preprocessor(self) -> None:
pass

@staticmethod
def _validate_fx(
model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
compiled_model = torch.compile(model, backend="openvino")
for batch_i, batch in enumerate(data_loader):
if num_samples is not None and batch_i == num_samples:
break
batch = validator.preprocess(batch)
preds = compiled_model(batch["img"])
preds = validator.postprocess(preds)
validator.update_metrics(preds, batch)
stats = validator.get_stats()
return stats, validator.seen, validator.nt_per_class.sum()

@staticmethod
def _validate_ov(
model: ov.Model, data_loader: torch.utils.data.DataLoader, validator: Validator, num_samples: int = None
) -> Tuple[Dict, int, int]:
model.reshape({0: [1, 3, -1, -1]})
compiled_model = ov.compile_model(model)
output_layer = compiled_model.output(0)
for batch_i, batch in enumerate(data_loader):
if num_samples is not None and batch_i == num_samples:
break
batch = validator.preprocess(batch)
preds = torch.from_numpy(compiled_model(batch["img"])[output_layer])
preds = validator.postprocess(preds)
validator.update_metrics(preds, batch)
stats = validator.get_stats()
return stats, validator.seen, validator.nt_per_class.sum()

def get_transform_calibration_fn(self):
def transform_func(batch):
return self.validator.preprocess(batch)["img"]

return transform_func

def prepare_calibration_dataset(self):
self.calibration_dataset = nncf.Dataset(self.data_loader, self.get_transform_calibration_fn())

@staticmethod
def _prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
custom = {"rect": False, "batch": 1} # method defaults
args = {**model.overrides, **custom, "mode": "val"} # highest priority args on the right

validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks)
stride = 32 # default stride
validator.stride = stride # used in get_dataloader() for padding
validator.data = check_det_dataset(data)
validator.init_metrics(de_parallel(model))

data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch)

return validator, data_loader

def _validate(self):
if self.backend == BackendType.FP32:
stats, _, _ = self._validate_ov(self.model, self.data_loader, self.validator)
elif self.backend in OV_BACKENDS:
stats, _, _ = self._validate_ov(self.compressed_model, self.data_loader, self.validator)
elif self.backend == BackendType.FX_TORCH:
stats, _, _ = self._validate_fx(self.compressed_model, self.data_loader, self.validator)
else:
raise RuntimeError(f"Backend {self.backend} is not supported in UltralyticsDetection")

self.run_info.metric_name = "mAP50(B)"
self.run_info.metric_value = stats["metrics/mAP50(B)"]
1 change: 1 addition & 0 deletions tests/post_training/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ timm==0.9.2
transformers==4.38.2
whowhatbench @ git+https://github.com/andreyanufr/who_what_benchmark@456d3584ce628f6c8605f37cd9a3ab2db1ebf933
datasets==2.21.0
ultralytics==8.3.27