[PTQ][OV] BF16 support (#2307)
### Changes

- Added BF16 type support
- Added FakeQuantize (FQ) parameter generation based on the data type
- Extended the list of supported types for OpenVINO input data with `ov.Tensor` (see the sketch after this list)
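
For context, a minimal sketch of what the extended input handling enables: passing calibration samples as `ov.Tensor` objects (the model path and input shape here are hypothetical):

```python
import numpy as np
import openvino as ov

import nncf

core = ov.Core()
model = core.read_model("model.xml")  # hypothetical model path

# Calibration samples may now be ov.Tensor instances, not only numpy arrays.
samples = [ov.Tensor(np.random.rand(1, 3, 224, 224).astype(np.float32))]

calibration_dataset = nncf.Dataset(samples)
quantized_model = nncf.quantize(model, calibration_dataset)
```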

### Reason for changes

- BF16 support

### Related tickets

- 126782

### Tests

- Updated existing tests with BF16
- manual/post_training_weight_compression/99 - no regressions (failure due to a CI issue)
- manual/post_training_quantization/421 - no regressions (failure due to a CI issue)
KodiaqQ authored Jul 12, 2024
1 parent d033c6a commit 6926cf1

Showing 12 changed files with 313 additions and 166 deletions.
135 changes: 68 additions & 67 deletions nncf/openvino/graph/model_transformer.py

Large diffs are not rendered by default.

11 changes: 6 additions & 5 deletions nncf/openvino/graph/node_utils.py
```diff
@@ -107,17 +107,18 @@ def cnt_if_op(model: ov.Model, cnt: int) -> int:
     return cnt_if_op(model, 0)
 
 
-def get_const_value(const_node: ov.Node, dtype: Optional[np.dtype] = None) -> np.ndarray:
+def get_const_value(const_node: ov.Node) -> np.ndarray:
     """
     Returns the constant tensor for the node.
+    This method is applicable only for the floating-point constant data.
 
     :param const_node: OpenVINO node.
-    :param dtype: Destination type.
     :return: The constant value.
     """
-    if dtype is None:
-        return const_node.data
-    return const_node.get_data(dtype=dtype)
+    if const_node.get_element_type() == ov.Type.bf16:
+        # Fixed FP32 data type as the result for BF16 constant
+        return const_node.get_data(dtype=np.float32)
+    return const_node.data
 
 
 def get_bias_value(node_with_bias: NNCFNode, nncf_graph: NNCFGraph, model: ov.Model) -> np.ndarray:
```
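
The BF16 branch above exists because numpy has no native bfloat16 dtype, so BF16 constants must be materialized in a destination type; FP32 represents every BF16 value exactly. A minimal sketch of the behavior (toy constant, hypothetical shape):

```python
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

# Build a BF16 constant from FP32 values (toy data).
const_node = opset.constant(np.ones((2, 2), dtype=np.float32), dtype=ov.Type.bf16)
assert const_node.get_element_type() == ov.Type.bf16

# get_const_value() now returns such constants as FP32.
data = const_node.get_data(dtype=np.float32)
assert data.dtype == np.float32
```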
6 changes: 6 additions & 0 deletions nncf/openvino/graph/transformations/commands.py
```diff
@@ -48,6 +48,10 @@ def __init__(self, target_point: OVTargetPoint):
 
 
 class OVOutputInsertionCommand(OVInsertionCommand):
+    def __init__(self, target_point: OVTargetPoint, output_dtype: ov.Type = ov.Type.f32):
+        super().__init__(target_point)
+        self.output_dtype = output_dtype
+
     def union(self, other: "TransformationCommand") -> "TransformationCommand":
         # Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
         raise NotImplementedError()
@@ -60,11 +64,13 @@ def __init__(
         inplace_op_fn: InplaceInsertionFnType,
         fn_output_port_id: int,
         last_inplace_node_name: str,
+        output_dtype: ov.Type = ov.Type.f32,
     ):
         super().__init__(target_point)
         self.inplace_op_fn = inplace_op_fn
         self.fn_output_port_id = fn_output_port_id
         self.last_inplace_node_name = last_inplace_node_name
+        self.output_dtype = output_dtype
 
     def union(self, other: "TransformationCommand") -> "TransformationCommand":
         # Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
```
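
A sketch of how the new `output_dtype` argument is meant to be used when building an insertion command (the target node name here is hypothetical; `OVTargetPoint` and `TargetType` are NNCF-internal classes):

```python
import openvino as ov

from nncf.common.graph.transformations.commands import TargetType
from nncf.openvino.graph.transformations.commands import OVOutputInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint

target_point = OVTargetPoint(TargetType.PRE_LAYER_OPERATION, "MatMul", 0)
# Request the extra output in BF16 instead of the f32 default.
command = OVOutputInsertionCommand(target_point, output_dtype=ov.Type.bf16)
assert command.output_dtype == ov.Type.bf16
```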
8 changes: 4 additions & 4 deletions nncf/openvino/quantization/quantize_ifmodel.py
```diff
@@ -300,8 +300,8 @@ def create_output_insertion_commands_if_node(model: ov.Model, if_node: NNCFNode)
     commands = []
     name_to_node_mapping = {op.get_friendly_name(): op for op in model.get_ops()}
     ov_node = name_to_node_mapping[if_node.node_name]
-    for port_id in range(len(ov_node.inputs())):
-        commands.append(
-            OVOutputInsertionCommand(OVTargetPoint(TargetType.PRE_LAYER_OPERATION, if_node.node_name, port_id))
-        )
+    for port_id, ov_input in enumerate(ov_node.inputs()):
+        target_point = OVTargetPoint(TargetType.PRE_LAYER_OPERATION, if_node.node_name, port_id)
+        ov_input_dtype = ov_input.get_element_type()
+        commands.append(OVOutputInsertionCommand(target_point, output_dtype=ov_input_dtype))
     return commands
```
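
The loop now propagates each input's element type into the command. A minimal sketch of where those types come from (toy model, hypothetical names):

```python
import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

param = opset.parameter([1, 4], dtype=ov.Type.bf16, name="Input")
weights = opset.constant(np.ones((4, 2), dtype=np.float32), dtype=ov.Type.bf16)
matmul = opset.matmul(param, weights, transpose_a=False, transpose_b=False)

# Per-port element types, as read by create_output_insertion_commands_if_node.
for port_id, ov_input in enumerate(matmul.inputs()):
    print(port_id, ov_input.get_element_type())  # bf16 for both ports
```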
```diff
@@ -10,7 +10,6 @@
 # limitations under the License.
 from typing import Dict, Iterable, List, Optional, Tuple
 
-import numpy as np
 import openvino as ov
 from openvino.runtime import opset13 as opset
 
@@ -114,7 +113,12 @@ def set_weight(
     const_port = node_with_const.input(weight_port_id)
     const_node = node_with_const.input_value(weight_port_id).get_node()
 
-    new_const_node = ov.runtime.op.Constant(weight.data, shared_memory=True)
+    shared_memory = True
+    if const_node.get_element_type() == ov.Type.bf16:
+        # Shared memory does not work for BF16 precision
+        shared_memory = False
+
+    new_const_node = ov.runtime.op.Constant(weight.data, shared_memory=shared_memory)
     new_const_node.set_friendly_name(const_node.get_friendly_name())
     const_port.replace_source_output(new_const_node.output(0))
 
@@ -167,7 +171,7 @@ def transform_model(
                 should_add_convert_node = True
                 break
 
-        weight = Tensor(get_const_value(const_node, np.float32 if const_dtype == ov.Type.bf16 else None))
+        weight = Tensor(get_const_value(const_node))
         original_shape = weight.shape
         compressed_weight = compress_weight(
             weight,
```
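
The `shared_memory` switch above reflects that zero-copy `Constant` creation is not available when the original constant is BF16 (per the comment in the change), so the weight data is copied in that case. A sketch of the pattern under that assumption:

```python
import numpy as np
import openvino as ov

weight_data = np.ones((4, 4), dtype=np.float32)  # toy weight
original_is_bf16 = True  # e.g. const_node.get_element_type() == ov.Type.bf16

# Wrap zero-copy when allowed; fall back to copying for BF16 originals.
shared_memory = not original_is_bf16
new_const = ov.runtime.op.Constant(weight_data, shared_memory=shared_memory)
```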
3 changes: 2 additions & 1 deletion tests/openvino/native/common.py
```diff
@@ -47,7 +47,8 @@ def get_dataset_for_test(model):
     input_data = {}
     for param in model.get_parameters():
         input_shape = param.partial_shape.get_max_shape()
-        input_data[param.get_output_tensor(0).get_any_name()] = rng.uniform(0, 1, input_shape)
+        tensor = param.get_output_tensor(0)
+        input_data[tensor.get_any_name()] = rng.uniform(0, 1, input_shape).astype(tensor.get_element_type().to_dtype())
 
     dataset = Dataset([input_data])
     return dataset
```
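
The test now generates inputs in each parameter's own precision via `ov.Type.to_dtype()`. A quick illustration:

```python
import openvino as ov

for ov_type in (ov.Type.f32, ov.Type.f16, ov.Type.bf16):
    # to_dtype() yields the numpy dtype OpenVINO maps this element type to.
    print(ov_type, "->", ov_type.to_dtype())
```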
A new reference graph file in DOT format:

```diff
@@ -0,0 +1,13 @@
+strict digraph {
+"0 Parameter_MatMul.0" [id=0, type=Parameter];
+"1 Convert_430" [id=1, type=Convert];
+"2 MatMul" [id=2, type=MatMul];
+"3 Convert_431" [id=3, type=Convert];
+"4 Result_MatMul.0" [id=4, type=Result];
+"5 MatMul_const" [id=5, type=Constant];
+"0 Parameter_MatMul.0" -> "1 Convert_430" [label="[1, 3, 4, 2]", style=solid];
+"1 Convert_430" -> "2 MatMul" [label="[1, 3, 4, 2]", style=solid];
+"2 MatMul" -> "3 Convert_431" [label="[1, 3, 2, 5]", style=solid];
+"3 Convert_431" -> "4 Result_MatMul.0" [label="[1, 3, 2, 5]", style=solid];
+"5 MatMul_const" -> "2 MatMul" [label="[1, 3, 4, 5]", style=solid];
+}
```
14 changes: 7 additions & 7 deletions tests/openvino/native/models.py
```diff
@@ -282,21 +282,21 @@ def _create_ov_model(self):
 
 
 class FPModel(OVReferenceModel):
-    def __init__(self, const_dtype="FP32", input_dtype="FP32"):
-        self.const_dtype = np.float32 if const_dtype == "FP32" else np.float16
-        self.input_dtype = np.float32 if input_dtype == "FP32" else np.float16
+    def __init__(self, const_dtype: ov.Type = ov.Type.f32, input_dtype: ov.Type = ov.Type.f32):
+        self.const_dtype = const_dtype
+        self.input_dtype = input_dtype
         super().__init__()
 
     def _create_ov_model(self):
         input_shape = [1, 3, 4, 2]
         input_1 = opset.parameter(input_shape, name="Input", dtype=self.input_dtype)
-        data = self._rng.random((1, 3, 4, 5)).astype(self.const_dtype)
+        data = opset.constant(value=self._rng.random((1, 3, 4, 5)), dtype=self.const_dtype, name="MatMul_const")
         if self.const_dtype != self.input_dtype:
-            data = opset.convert(data, self.input_dtype)
+            data = opset.convert(data, self.input_dtype.to_string())
         matmul = opset.matmul(input_1, data, transpose_a=True, transpose_b=False, name="MatMul")
-        bias = self._rng.random((1, 3, 1, 1)).astype(self.const_dtype)
+        bias = opset.constant(value=self._rng.random((1, 3, 1, 1)), dtype=self.const_dtype, name="MatMul_bias")
         if self.const_dtype != self.input_dtype:
-            bias = opset.convert(bias, self.input_dtype)
+            bias = opset.convert(bias, self.input_dtype.to_string())
         add = opset.add(matmul, bias, name="Add")
         result = opset.result(add, name="Result_Add")
         result.get_output_tensor(0).set_names(set(["Result_Add"]))
```
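
With the `ov.Type`-based constructor, the BF16 case used by the updated tests can be built directly; a sketch, assuming the reference model exposes its graph as an `ov_model` attribute:

```python
import openvino as ov

from tests.openvino.native.models import FPModel

# BF16 constants feeding an FP32 graph, as exercised by the new tests.
model = FPModel(const_dtype=ov.Type.bf16, input_dtype=ov.Type.f32).ov_model
```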
13 changes: 6 additions & 7 deletions tests/openvino/native/quantization/test_fq_params_calculation.py
```diff
@@ -11,7 +11,6 @@
 
 from pathlib import Path
 
-import numpy as np
 import openvino as ov
 import pytest
 import torch
@@ -163,8 +162,8 @@ def test_synthetic_models_fq_shapes(model_creator_func, ref_shapes, inplace_statistics):
     assert node["output_high"].shape == ref_shapes[node_name]
 
 
-@pytest.mark.parametrize("const_dtype", ["FP16", "FP32"])
-@pytest.mark.parametrize("input_dtype", ["FP16", "FP32"])
+@pytest.mark.parametrize("const_dtype", [ov.Type.f16, ov.Type.f32, ov.Type.bf16])
+@pytest.mark.parametrize("input_dtype", [ov.Type.f16, ov.Type.f32, ov.Type.bf16])
 def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistics):
     model = FPModel(const_dtype, input_dtype)
     quantized_model = quantize_model(
@@ -174,10 +173,10 @@ def test_fq_precision_orig_fp32model(const_dtype, input_dtype, inplace_statistics):
         if op.get_type_name() == "FakeQuantize":
             inp_node = op.input(0)
             fq_input_node = inp_node.get_source_output().get_node()
-            if fq_input_node.get_element_type() == "Constant":
-                assert op.get_element_type() == ov.Type(np.float32 if input_dtype == "FP32" else np.float16)
+            if fq_input_node.get_type_name() == "Constant":
+                assert op.get_element_type() == const_dtype
         elif op.get_type_name() == "Convert":
             inp_node = op.input(0)
             fq_input_node = inp_node.get_source_output().get_node()
-            if fq_input_node.get_element_type() == "Constant":
-                assert op.get_element_type() == ov.Type(np.float32 if const_dtype == "FP32" else np.float16)
+            if fq_input_node.get_type_name() == "Constant":
+                assert op.get_element_type() == input_dtype
```
```diff
@@ -48,7 +48,8 @@ def test_compress_weights(model_creator_func, ref_nodes):
 
     fq_nodes = get_nodes_by_type(quantized_model, type_name="FakeQuantize")
     assert len(fq_nodes) == len(ref_fqs_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fqs_names
 
     for op in quantized_model.get_ops():
@@ -76,7 +77,8 @@ def test_overflow_fix_applied(model_creator_func, ref_nodes):
 
     fq_nodes = get_nodes_by_type(quantized_model, type_name="FakeQuantize")
    assert len(fq_nodes) == len(ref_fqs_names)
-    for fq_name in fq_nodes:
+    for fq_node in fq_nodes:
+        fq_name = fq_node.get_friendly_name()
         assert fq_name in ref_fqs_names
 
     for op in quantized_model.get_ops():
```
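
The fix matters because the helper returns `ov.Node` objects rather than names, so the old loop compared node objects against a list of name strings. A minimal sketch of the corrected pattern:

```python
import openvino as ov

def node_names_by_type(model: ov.Model, type_name: str) -> list:
    # Collect friendly names of all ops of the given type.
    return [op.get_friendly_name() for op in model.get_ops() if op.get_type_name() == type_name]
```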