-
Notifications
You must be signed in to change notification settings - Fork 48
/
config.py
48 lines (35 loc) · 1.41 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from ...import_utils import openvino_version
from ..config import BackendConfig
@dataclass
class OVConfig(BackendConfig):
    """Configuration for the OpenVINO backend (``OVBackend``).

    Extends ``BackendConfig`` with OpenVINO-specific load, export, compilation,
    quantization and calibration options. ``__post_init__`` normalizes the
    device name and rejects unsupported option combinations.
    """

    name: str = "openvino"
    # Computed once, when the class body executes (i.e. at module import time).
    version: Optional[str] = openvino_version()
    # Dotted import path of the backend class this config targets
    # (presumably consumed by a Hydra-style instantiator — confirm).
    _target_: str = "optimum_benchmark.backends.openvino.backend.OVBackend"

    # load options
    no_weights: bool = False

    # export options
    export: bool = True
    use_cache: bool = True
    use_merged: bool = False

    # openvino config
    openvino_config: Dict[str, Any] = field(default_factory=dict)

    # compilation options
    half: bool = False
    reshape: bool = False

    # quantization options
    quantization: bool = False
    quantization_config: Dict[str, Any] = field(default_factory=dict)

    # calibration options
    calibration: bool = False
    calibration_config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Normalize and validate the configuration after dataclass init.

        ``device`` and ``intra_op_num_threads`` are inherited from
        ``BackendConfig`` (defined outside this view).

        Raises:
            ValueError: if ``device`` (case-insensitively) is not ``"cpu"`` or
                ``"gpu"``, or if ``quantization`` is enabled without
                ``calibration``.
            NotImplementedError: if ``intra_op_num_threads`` is set.
        """
        super().__post_init__()

        # Lower-case so that e.g. "CPU" and "cpu" are treated identically.
        self.device = self.device.lower()
        if self.device not in ("cpu", "gpu"):
            # Both CPU and GPU are accepted above, so the message names both.
            raise ValueError(f"OVBackend only supports CPU and GPU devices, got {self.device}")

        if self.intra_op_num_threads is not None:
            raise NotImplementedError("OVBackend does not support intra_op_num_threads")

        # Quantization is only allowed when calibration is also enabled.
        if self.quantization and not self.calibration:
            raise ValueError("OpenVINO quantization requires enabling calibration.")