From 8a83597721310d0bdad3a5cc8fdd6610a08a48d8 Mon Sep 17 00:00:00 2001
From: Nikita Savelyev
Date: Thu, 5 Sep 2024 17:21:29 +0200
Subject: [PATCH] BF16 support

---
 docs/api/source/conf.py                            |  1 +
 nncf/openvino/graph/node_utils.py                  |  5 ++-
 .../quantization/compression_primitives.py         |  7 ++--
 .../weight_compression/weight_lowering.py          |  9 +++--
 nncf/tensor/definitions.py                         |  1 +
 nncf/tensor/functions/__init__.py                  |  3 ++
 nncf/tensor/functions/ov.py                        | 40 +++++++++++++++++++
 weight_compression.py                              | 12 +++---
 8 files changed, 65 insertions(+), 13 deletions(-)
 create mode 100644 nncf/tensor/functions/ov.py

diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py
index 4637c7875d4..098e091a0f2 100644
--- a/docs/api/source/conf.py
+++ b/docs/api/source/conf.py
@@ -143,6 +143,7 @@ def collect_api_entities() -> APIInfo:
         "nncf.tensor.functions.numpy_linalg",
         "nncf.tensor.functions.torch_numeric",
         "nncf.tensor.functions.torch_linalg",
+        "nncf.tensor.functions.ov",
     ]
 
     with mock(mock_modules):
diff --git a/nncf/openvino/graph/node_utils.py b/nncf/openvino/graph/node_utils.py
index e39c4e49467..bef4badca32 100644
--- a/nncf/openvino/graph/node_utils.py
+++ b/nncf/openvino/graph/node_utils.py
@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
 import numpy as np
@@ -116,6 +116,9 @@ def get_const_value(const_node: ov.Node) -> np.ndarray:
     :return: The constant value.
     """
     if const_node.get_element_type() == ov.Type.bf16:
+        INPUT_DTYPE = os.environ.get("INPUT_DTYPE", "fp32")
+        if INPUT_DTYPE == "bf16":
+            return ov.Tensor(const_node.output(0))
         # Fixed FP32 data type as the result for BF16 constant
         return const_node.get_data(dtype=np.float32)
     return const_node.data
diff --git a/nncf/openvino/quantization/compression_primitives.py b/nncf/openvino/quantization/compression_primitives.py
index 29284be2b7e..a9582988c53 100644
--- a/nncf/openvino/quantization/compression_primitives.py
+++ b/nncf/openvino/quantization/compression_primitives.py
@@ -86,15 +86,16 @@ def _build_compress_model(
         invert_scale: Optional[bool] = False,
         return_nodes: bool = False,
     ):
-        FP16_INPUT = bool(int(os.environ.get("FP16_INPUT", "0")))
+        INPUT_DTYPE = os.environ.get("INPUT_DTYPE", "fp32")
         INT8_OUTPUT = bool(int(os.environ.get("INT8_OUTPUT", "0")))
         SHARE_OUTPUTS = bool(int(os.environ.get("SHARE_OUTPUTS", "0")))
 
-        w = opset.parameter(weight_shape, name="w", dtype=np.float16 if FP16_INPUT else np.float32)
+        input_dtype = ov.Type.f32 if INPUT_DTYPE == "fp32" else ov.Type.f16 if INPUT_DTYPE == "fp16" else ov.Type.bf16
+        w = opset.parameter(weight_shape, name="w", dtype=input_dtype)
         s = opset.parameter(scale_shape, name="s")
         parameters = [w, s]
 
-        if FP16_INPUT:
+        if input_dtype != ov.Type.f32:
             w = opset.convert(w, ov.Type.f32)
 
         compressed_w = w * (1 / s) if invert_scale else w / s
diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
index 83b5845103d..0a4124e5760 100644
--- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py
+++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py
@@ -309,7 +309,7 @@ def calculate_quantized_weight(
         log_once(logging.INFO, "Compression time may improve after installing OpenVINO")
 
     NUMPY_COMPRESSION = bool(int(os.environ.get("NUMPY_COMPRESSION", "0")))
-    if weight.backend == TensorBackend.numpy and is_openvino_available() and not NUMPY_COMPRESSION:
+    if weight.backend in [TensorBackend.numpy, TensorBackend.ov] and is_openvino_available() and not NUMPY_COMPRESSION:
         from nncf.openvino.quantization.compression_primitives import OV_COMPRESSION_PRIMITIVE_CACHE
 
         zero_point_shape = None if zero_point is None else zero_point.shape
@@ -317,7 +317,8 @@ def calculate_quantized_weight(
             config, weight.shape, scale.shape, zero_point_shape
         )
 
-        assert weight.data.flags["C_CONTIGUOUS"]
+        if hasattr(weight.data, "flags"):
+            assert weight.data.flags["C_CONTIGUOUS"]
         input_tensors = weight.data, scale.data
         if zero_point is not None:
             input_tensors += (zero_point.data,)
@@ -410,8 +411,8 @@ def do_int_quantization(
     assert config.is_integer(), "The function supports integer quantization only"
     group_size = config.group_size
 
-    FP16_INPUT = bool(int(os.environ.get("FP16_INPUT", "0")))
-    if weight.dtype != TensorDataType.float32 and not FP16_INPUT:
+    INPUT_DTYPE = os.environ.get("INPUT_DTYPE", "fp32")
+    if weight.dtype != TensorDataType.float32 and INPUT_DTYPE == "fp32":
         weight = weight.astype(TensorDataType.float32)
 
     if group_size != -1:
diff --git a/nncf/tensor/definitions.py b/nncf/tensor/definitions.py
index 447a6dd8bb5..a8b05991c2e 100644
--- a/nncf/tensor/definitions.py
+++ b/nncf/tensor/definitions.py
@@ -54,6 +54,7 @@ class TensorBackend(Enum):
 
     numpy = auto()
     torch = auto()
+    ov = auto()
 
 
 @dataclass
diff --git a/nncf/tensor/functions/__init__.py b/nncf/tensor/functions/__init__.py
index 5a286a6fc13..9affab79c90 100644
--- a/nncf/tensor/functions/__init__.py
+++ b/nncf/tensor/functions/__init__.py
@@ -75,5 +75,8 @@ def _initialize_backends():
         import nncf.tensor.functions.torch_linalg
         import nncf.tensor.functions.torch_numeric  # noqa: F401
 
+    with contextlib.suppress(ImportError):
+        import nncf.tensor.functions.ov  # noqa: F401
+
 
 _initialize_backends()
diff --git a/nncf/tensor/functions/ov.py b/nncf/tensor/functions/ov.py
new file mode 100644
index 00000000000..5122dea825f
--- /dev/null
+++ b/nncf/tensor/functions/ov.py
@@ -0,0 +1,40 @@
+import numpy as np
+import openvino as ov
+
+from nncf.tensor import TensorDataType
+from nncf.tensor.functions import numeric
+from .numpy_numeric import DTYPE_MAP as NP_DTYPE_MAP
+from ..definitions import TensorBackend
+
+DTYPE_MAP = {
+    TensorDataType.float16: ov.Type.f16,
+    TensorDataType.bfloat16: ov.Type.bf16,
+    TensorDataType.float32: ov.Type.f32,
+    TensorDataType.float64: ov.Type.f64,
+    TensorDataType.int8: ov.Type.i8,
+    TensorDataType.int32: ov.Type.i32,
+    TensorDataType.int64: ov.Type.i64,
+    TensorDataType.uint8: ov.Type.u8,
+}
+
+DTYPE_MAP_REV = {v: k for k, v in DTYPE_MAP.items()}
+
+
+@numeric.backend.register(ov.Tensor)
+def _(a: ov.Tensor) -> TensorBackend:
+    return TensorBackend.ov
+
+
+@numeric.astype.register(ov.Tensor)
+def _(a: ov.Tensor, dtype: TensorDataType) -> np.ndarray:
+    return a.data.astype(NP_DTYPE_MAP[dtype])
+
+
+@numeric.dtype.register(ov.Tensor)
+def _(a: ov.Tensor) -> TensorDataType:
+    return DTYPE_MAP_REV[a.get_element_type()]
+
+
+@numeric.size.register(ov.Tensor)
+def _(a: ov.Tensor) -> int:
+    return a.size
diff --git a/weight_compression.py b/weight_compression.py
index 6b68c6c0f4a..864fcd60d7e 100644
--- a/weight_compression.py
+++ b/weight_compression.py
@@ -36,7 +36,9 @@ def parse_arguments():
 
     parser.add_argument("--dynamic-compression", action="store_true", help="Enable dynamic compression")
 
-    parser.add_argument("--fp16-input", action="store_true", help="Enable FP16 input mode")
+    parser.add_argument("--input-dtype", type=str, choices=["fp32", "fp16", "bf16"], default="fp32", help="OV model input dtype")
+
+    parser.add_argument("--bf16-input", action="store_true", help="Enable BF16 input mode")
 
     parser.add_argument("--int8-output", action="store_true", help="Output in int8")
 
@@ -61,7 +63,7 @@ def main(args):
     numpy_compression = args.numpy_compression
     dynamic_compression = args.dynamic_compression
-    fp16_input = args.fp16_input
+    input_dtype = args.input_dtype
     int8_output = args.int8_output
     recompile = args.recompile
     share_outputs = args.share_outputs
 
@@ -71,7 +73,7 @@ def main(args):
     else:
         log_dir_suffix = "ov-dynamic" if dynamic_compression else "ov-static"
         log_dir_suffix = f"{log_dir_suffix}_{('output-int8' if int8_output else 'output-fp32')}"
-        log_dir_suffix = f"{log_dir_suffix}_{('input-fp16' if fp16_input else 'input-fp32')}"
+        log_dir_suffix = f"{log_dir_suffix}_{f'input-{input_dtype}'}"
         if recompile:
             log_dir_suffix = f"{log_dir_suffix}_recompile"
         if share_outputs:
@@ -89,7 +91,7 @@ def main(args):
 
     os.environ["NUMPY_COMPRESSION"] = f"{int(numpy_compression)}"
     os.environ["DYNAMIC_COMPRESSION"] = f"{int(dynamic_compression)}"
-    os.environ["FP16_INPUT"] = f"{int(fp16_input)}"
+    os.environ["INPUT_DTYPE"] = input_dtype
     os.environ["INT8_OUTPUT"] = f"{int(int8_output)}"
     os.environ["RECOMPILE"] = f"{int(recompile)}"
     os.environ["SHARE_OUTPUTS"] = f"{int(share_outputs)}"
@@ -149,7 +151,7 @@ def main(args):
             f"{model_path},"
             f"{numpy_compression},"
            f"{'-' if numpy_compression else 'Dynamic' if dynamic_compression else 'Static'},"
-            f"{'-' if numpy_compression else 'FP16' if fp16_input else 'FP32'},"
+            f"{'-' if numpy_compression else input_dtype.upper()},"
             f"{'-' if numpy_compression else 'INT8' if int8_output else 'FP32'},"
             f"{compression_time:.2f},"
             f"{peak_memory:.2f},"