Skip to content

Commit

Permalink
Convert precision use optimized version for bf16 -> f16
Browse files Browse the repository at this point in the history
  • Loading branch information
praasz committed Jul 10, 2024
1 parent b75b040 commit 702a3db
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,26 @@ std::shared_ptr<Node> change_constant_precision<ov::element::Type_t::f32, ov::el
return new_constant;
}

template <>
std::shared_ptr<Node> change_constant_precision<ov::element::Type_t::bf16, ov::element::Type_t::f16>(
std::shared_ptr<ov::op::v0::Constant>& constant) {
using src_type = typename element_type_traits<ov::element::Type_t::bf16>::value_type;
using dst_type = typename element_type_traits<ov::element::Type_t::f16>::value_type;

const auto* src_data = constant->get_data_ptr<src_type>();
const auto size = shape_size(constant->get_shape());

auto new_constant = std::make_shared<ov::op::v0::Constant>(ov::element::Type_t::f16, constant->get_shape());
new_constant->output(0).set_names(constant->output(0).get_names());
auto* dst_data = const_cast<dst_type*>(reinterpret_cast<const dst_type*>(new_constant->get_data_ptr()));
if (dst_data == nullptr)
OPENVINO_THROW("Can't get destination data pointer");

ov::reference::convert_from_bf16_to_f16_with_clamp(src_data, dst_data, size);

return new_constant;
}

template <>
std::shared_ptr<Node> change_constant_precision<ov::element::Type_t::f16, ov::element::Type_t::f32>(
std::shared_ptr<ov::op::v0::Constant>& constant) {
Expand Down Expand Up @@ -1326,6 +1346,8 @@ bool fuse_type_to_constant(const std::shared_ptr<ov::Node>& node,
new_const = change_constant_precision<ov::element::Type_t::f64, ov::element::Type_t::f32>(constant);
} else if (from == ov::element::bf16 && to == ov::element::f32) {
new_const = change_constant_precision<ov::element::Type_t::bf16, ov::element::Type_t::f32>(constant);
} else if (from == ov::element::bf16 && to == ov::element::f16) {
new_const = change_constant_precision<ov::element::Type_t::bf16, ov::element::Type_t::f16>(constant);
} else if (from == ov::element::f32 && to == ov::element::f16) {
new_const = change_constant_precision<ov::element::Type_t::f32, ov::element::Type_t::f16>(constant);
} else if (from == ov::element::f16 && to == ov::element::f32) {
Expand Down
34 changes: 34 additions & 0 deletions src/common/transformations/tests/utils/convert_precision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,40 @@ TEST(TransformationTests, ConvertPrecision_Convert_clamp_1) {
ASSERT_TRUE(res.valid) << res.message;
}

TEST(TransformationTests, ConvertPrecision_Convert_clamp_bf16_f16) {
// fp16 out of range should be clamped to [fp16_min, fp16_max]
std::shared_ptr<Model> model(nullptr), model_ref(nullptr);
{
auto input = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 1000, 3});
auto const_node = opset10::Constant::create(element::bf16, Shape{3}, {100000.0f, -100000.0f, 10.0f});
auto convert = std::make_shared<opset4::Convert>(const_node, element::f16);
auto add_1 = make_shared<opset10::Add>(input, convert);
model = std::make_shared<Model>(NodeVector{add_1}, ParameterVector{input});

pass::Manager manager;
static const precisions_map precisions = {{element::bf16, element::f16}};
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertPrecision>(precisions);
manager.run_passes(model);
}

{
auto max_fp16 = static_cast<float>(std::numeric_limits<ov::float16>::max());
auto input = std::make_shared<opset4::Parameter>(element::f16, Shape{1, 1000, 3});
auto const_node = opset10::Constant::create(element::f16, Shape{3}, {max_fp16, -max_fp16, 10.0f});
auto add_1 = make_shared<opset10::Add>(input, const_node);

model_ref = std::make_shared<Model>(NodeVector{add_1}, ParameterVector{input});
}
ASSERT_NO_THROW(check_rt_info(model));
const auto fc = FunctionsComparator::with_default()
.enable(FunctionsComparator::PRECISIONS)
.enable(FunctionsComparator::CONST_VALUES)
.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
const auto res = fc.compare(model, model_ref);
ASSERT_TRUE(res.valid) << res.message;
}

#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
TEST(TransformationTests, ConvertPrecision_Convert_clamp_2) {
#else
Expand Down
3 changes: 3 additions & 0 deletions src/core/reference/include/openvino/reference/convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,8 @@ size_t count_out_of_f16_range(const float* arg, size_t count);

// Convert values from f32 to f16 with clamping to f16 min/max when value is out of normal finite numbers range
void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t count);

// Convert values from bf16 to f16 with clamping to f16 min/max when value is out of normal finite numbers range
void convert_from_bf16_to_f16_with_clamp(const bfloat16* arg, float16* out, size_t count);
} // namespace reference
} // namespace ov
38 changes: 38 additions & 0 deletions src/core/reference/src/op/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ void jit_convert_vec<bfloat16, float16>(jit::Generator& gen, const Xbyak::RegExp
gen.vmovdqu(gen.xword[dst], f16vec); // move result to destination
}

template <>
void jit_convert_vec<bfloat16, float16, true>(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) {
const auto f32vec = gen.ymm4;
const auto f16vec = gen.xmm3;

auto upper_bound = gen.ymm5;
auto lower_bound = gen.ymm6;

gen.vpmovzxwd(f32vec, gen.yword[src]); // load bf16 into tmp
gen.vpslld(f32vec, f32vec, 16); // convert bf16->f32 by bit shift
gen.vminps(f32vec, f32vec, upper_bound); // clamp f16 max
gen.vmaxps(f32vec, f32vec, lower_bound); // clamp f16 lowest
gen.vcvtps2ph(f16vec, f32vec, 0); // convert f32 -> f16
gen.vmovdqu(gen.xword[dst], f16vec); // move result to destination
}

template <>
void jit_convert_vec<bfloat16, float>(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) {
const auto f32vec = gen.ymm4;
Expand All @@ -92,6 +108,11 @@ void jit_convert_vec_prepare<float, float16, true>(jit::Generator& gen) {
gen.vmovdqu(lower_bound, gen.yword[addr]);
}

template <>
void jit_convert_vec_prepare<bfloat16, float16, true>(jit::Generator& gen) {
jit_convert_vec_prepare<float, float16, true>(gen);
}

template <>
void jit_convert_vec<float, float16, true>(jit::Generator& gen, const Xbyak::RegExp& src, const Xbyak::RegExp& dst) {
auto f16vec = gen.xmm3;
Expand Down Expand Up @@ -552,6 +573,23 @@ void convert_from_f32_to_f16_with_clamp(const float* arg, float16* out, size_t c
#endif // defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
}

void convert_from_bf16_to_f16_with_clamp(const bfloat16* arg, float16* out, size_t count) {
#if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
convert_impl<bfloat16, float16, true>(arg, out, count);
#else
// FIXME CVS-125496: duplicate and stub for ARM, provide optimized solution
for (size_t i = 0; i < count; ++i) {
if (arg[i] > std::numeric_limits<ov::float16>::max()) {
out[i] = std::numeric_limits<ov::float16>::max();
} else if (arg[i] < std::numeric_limits<ov::float16>::lowest()) {
out[i] = std::numeric_limits<ov::float16>::lowest();
} else {
out[i] = static_cast<ov::float16>(arg[i]);
}
}
#endif // defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64)
}

size_t count_out_of_f16_range(const float* arg, size_t count) {
size_t num_out_of_range = 0;

Expand Down

0 comments on commit 702a3db

Please sign in to comment.