From 9269453eceafd04705443013f256edbe2d5ac8ae Mon Sep 17 00:00:00 2001
From: Sergey Lyalin
Date: Mon, 11 Nov 2024 16:08:07 +0000
Subject: [PATCH 1/6] Draft of GQA decomposition: half/half RoPE, SDPA. Not
 implemented: real GQA, in-place KV cache, unbalanced batches,
 rope_interleaved, etc.

---
 .../com.microsoft/group_query_attention.cpp   | 237 ++++++++++++++++++
 1 file changed, 237 insertions(+)
 create mode 100644 src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
new file mode 100644
index 00000000000000..cca20a26b86743
--- /dev/null
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -0,0 +1,237 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "core/null_node.hpp"
+#include "core/operator_set.hpp"
+#include "openvino/frontend/exception.hpp"
+
+#include "openvino/op/add.hpp"
+#include "openvino/op/broadcast.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/divide.hpp"
+#include "openvino/op/equal.hpp"
+#include "openvino/op/floor.hpp"
+#include "openvino/op/floor_mod.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/greater.hpp"
+#include "openvino/op/greater_eq.hpp"
+#include "openvino/op/less.hpp"
+#include "openvino/op/log.hpp"
+#include "openvino/op/logical_not.hpp"
+#include "openvino/op/logical_or.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/maximum.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/pad.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/scaled_dot_product_attention.hpp"
+#include "openvino/op/select.hpp"
+#include "openvino/op/shape_of.hpp"
+#include "openvino/op/softmax.hpp"
+#include "openvino/op/split.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/op/sqrt.hpp"
+#include "openvino/op/squeeze.hpp"
+#include "openvino/op/subtract.hpp"
+#include "openvino/op/transpose.hpp"
+#include "openvino/op/unsqueeze.hpp"
+
+#include "utils/split.hpp"
+
+using namespace ov::op;
+using ov::Shape;
+
+namespace ov {
+namespace frontend {
+namespace onnx {
+namespace com_microsoft {
+namespace detail {
+namespace {
+
+// FIXME: Reuse the same function from file attention.cpp
+ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
+                            int64_t num_heads,
+                            const std::vector<int64_t>& qkv_hidden_sizes);
+
+ov::Output<ov::Node> get_elements(const ov::Output<ov::Node>& shape, const std::vector<int>& dims) {
+    static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
+    const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
+    return std::make_shared<v8::Gather>(shape, dims_const, zero);
+}
+
+ov::Output<ov::Node> get_dimensions(const ov::Output<ov::Node>& node, const std::vector<int>& dims) {
+    return get_elements(std::make_shared<v3::ShapeOf>(node), dims);
+}
+
+ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Output<ov::Node> sin, bool interleaved, const Output<ov::Node>& head_size, const Output<ov::Node>& seq_len) {
+    OPENVINO_ASSERT(!interleaved, "rotary_interleaved is not supported");  // TODO: Support interleaved mode
+
+    using v1::Split;
+    using v0::Constant;
+    using v1::Multiply;
+    using v1::Add;
+    using v8::Slice;
+    using v0::Concat;
+    using v1::Subtract;
+    using Output = ov::Output<ov::Node>;
+    using std::make_shared;
+
+    Output zero = Constant::create(element::i32, Shape{1}, {0});
+    Output step = Constant::create(element::i32, Shape{1}, {1});
+
+    // cut for the current sequence length
+    cos = make_shared<Slice>(cos, zero, seq_len, step, zero);
+    sin = make_shared<Slice>(sin, zero, seq_len, step, zero);
+
+    Output cos_multiplier = make_shared<Concat>(OutputVector{cos, cos}, 1);
+    OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();
+
+    Output res_0 = make_shared<Subtract>(
+        make_shared<Multiply>(x_split[0], cos),
+        make_shared<Multiply>(x_split[1], sin)
+    );
+
+    Output res_1 = make_shared<Add>(
+        make_shared<Multiply>(x_split[0], sin),
+        make_shared<Multiply>(x_split[1], cos)
+    );
+
+    return make_shared<Concat>(OutputVector{res_0, res_1}, -1);
+}
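+
+// A rough worked example of the half/half RoPE layout above (illustrative values, not
+// taken from the GQA spec): for head_size = 4, x is split into halves x0 = x[..., 0:2]
+// and x1 = x[..., 2:4], and each pair (x0[i], x1[i]) is rotated as in a standard 2D rotation:
+//   out0[i] = x0[i] * cos[i] - x1[i] * sin[i]
+//   out1[i] = x0[i] * sin[i] + x1[i] * cos[i]
+// e.g. with cos = 0 and sin = 1 the halves are swapped with a sign flip: (x0, x1) -> (-x1, x0).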
+
+} // namespace
+} // namespace detail
+
+namespace opset_1 {
+ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
+    auto nodes = node.get_ov_inputs();
+    const auto& input = nodes[0];
+
+    ov::Output<ov::Node> Q, K, V, head_size;
+
+    const auto num_heads = node.get_attribute_value<int64_t>("num_heads");
+
+    if(ov::op::util::is_null(nodes[1]) || ov::op::util::is_null(nodes[2])) {
+        const auto split_result = detail::split_to_QKV(input, num_heads, {});
+        Q = split_result[0];
+        K = split_result[1];
+        V = split_result[2];
+        head_size = split_result[3];
+    } else {
+        Q = input;
+        K = nodes[1];
+        V = nodes[2];
+        head_size = detail::get_dimensions(Q, {-1});
+    }
+
+    const auto& past_K = nodes[3];
+    const auto& past_V = nodes[4];
+
+    const auto& total_sequence_length = nodes[6];
+    const auto& cos = nodes[7];
+    const auto& sin = nodes[8];
+    const bool rope_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);
+
+    Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, total_sequence_length);
+    K = detail::rope(K, cos, sin, rope_interleaved, head_size, total_sequence_length);
+
+    K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
+    V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
+
+    // FIXME: Unaligned batch of sequences is not supported.
+    // That means all input sequence lengths should be the same and match input.shape[1]
+    // We do not check that here because it depends on runtime values.
+
+    // FIXME: The same tensor at input/output of past/present K and V is not supported.
+    // It requires more complex tensor manipulations that introduce overhead into pure tensor-value data flow and should be implemented if we really have demand for that.
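+
+    // Illustrative shapes at this point (hypothetical values, assuming batch = 2,
+    // num_heads = 8, past_len = 16, current_len = 4, head_size = 64): Q is [2, 8, 4, 64],
+    // K and V after the concatenation above are [2, 8, 20, 64], and the causal SDPA below
+    // produces an output of shape [2, 8, 4, 64].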
+
+    auto output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
+
+    return {output, K, V};
+}
+
+ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN);
+} // namespace opset_1
+
+namespace detail {
+namespace {
+
+std::shared_ptr<ov::Node> get_hidden_size(const std::shared_ptr<ov::Node>& node_shape) {
+    // node has shape (batch_size, sequence_length, 3 * hidden_size)
+    const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
+    const auto hidden_size_x3 = get_elements(node_shape, {2});
+    const auto three = v0::Constant::create(ov::element::i64, ov::Shape{}, {3});
+    const auto hidden_size = std::make_shared<v1::Divide>(hidden_size_x3, three);
+    return hidden_size;
+}
+
+ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
+                            int64_t num_heads,
+                            const std::vector<int64_t>& qkv_hidden_sizes) {
+    ov::OutputVector split;
+    std::shared_ptr<ov::Node> head_size = nullptr;
+    const auto& node_type = node.get_element_type();
+    const auto node_shape = std::make_shared<v3::ShapeOf>(node);
+    // node has shape (batch_size, sequence_length, 3 * hidden_size)
+    // fetch the first two dimensions
+    const auto batch_size_seq_len = get_elements(node_shape, {0, 1});
+    const auto num_heads_node = v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads});
+    if (qkv_hidden_sizes.size() == 0) {
+        const auto hidden_size = get_hidden_size(node_shape);
+        // head_size = hidden_size / num_heads
+        head_size = std::make_shared<v1::Divide>(hidden_size, num_heads_node);
+        // split the node into 3 even parts Q, K, V with shape (batch_size, sequence_len, hidden_size)
+        split = ov::op::util::make_split(node, 3, 2);
+        // and reshape each part to new shape (batch_size, sequence_len, num_heads, head_size)
+        auto new_shape = std::make_shared<v0::Concat>(ov::OutputVector{batch_size_seq_len, num_heads_node, head_size}, 0);
+        for (size_t i = 0; i < split.size(); i++) {
+            split[i] = std::make_shared<v1::Reshape>(split[i], new_shape, false);
+        }
+        head_size = std::make_shared<v0::Convert>(head_size, node_type);
+    } else {
+        // in this case, weights have shape
+        // (input_hidden_size, qkv_hidden_sizes[0] + qkv_hidden_sizes[1] + qkv_hidden_sizes[2])
+        // so user specified hidden_sizes for Q, K and V
+        FRONT_END_GENERAL_CHECK(qkv_hidden_sizes.size() == 3, "qkv_hidden_sizes attribute needs to have 3 values");
+        FRONT_END_GENERAL_CHECK(qkv_hidden_sizes[0] == qkv_hidden_sizes[1],
+                                "qkv_hidden_sizes first element should be the same as the second");
+        // split the node into 3 parts Q, K, V with shapes
+        // Q: (batch_size, sequence_len, qkv_hidden_sizes[0])
+        // K: (batch_size, sequence_len, qkv_hidden_sizes[1])
+        // V: (batch_size, sequence_len, qkv_hidden_sizes[2])
+        split = ov::op::util::make_split(node, qkv_hidden_sizes, 2);
+        // and reshape each part to new shape (batch_size, sequence_len, num_heads, head_size)
+        for (size_t i = 0; i < split.size(); i++) {
+            auto new_shape = std::make_shared<v0::Concat>(
+                ov::OutputVector{batch_size_seq_len,
+                                 num_heads_node,
+                                 v0::Constant::create(ov::element::i64, ov::Shape{1}, {qkv_hidden_sizes[i] / num_heads})},
+                0);
+            split[i] = std::make_shared<v1::Reshape>(split[i], new_shape, false);
+        }
+        float head_size_val = qkv_hidden_sizes[0] > 0 ? static_cast<float>(qkv_hidden_sizes[0]) / num_heads
+                                                      : static_cast<float>(qkv_hidden_sizes[2]) / num_heads;
+        head_size = v0::Constant::create(node_type, ov::Shape{1}, {head_size_val});
+    }
+
+    // transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size)
+    auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3});
+    auto Q = std::make_shared<v1::Transpose>(split[0], perm);
+    auto K = std::make_shared<v1::Transpose>(split[1], perm);
+    auto V = std::make_shared<v1::Transpose>(split[2], perm);
+
+    return {Q, K, V, head_size};
+}
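+
+// Illustrative shapes for split_to_QKV above (hypothetical values): with input
+// [B, S, 3 * 768] and num_heads = 12, the even split yields three [B, S, 768] tensors,
+// each reshaped to [B, S, 12, 64] and transposed to [B, 12, S, 64]; head_size is
+// returned as a one-element tensor holding 64.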
+
+} // namespace
+} // namespace detail
+} // namespace com_microsoft
+} // namespace onnx
+} // namespace frontend
+} // namespace ov

From 7eb07a600ea32bc5fea9b364fdd6f3c1459757f5 Mon Sep 17 00:00:00 2001
From: Sergey Lyalin
Date: Tue, 12 Nov 2024 15:29:41 +0000
Subject: [PATCH 2/6] Group broadcast is implemented via the commonly used UBR
 (Unsqueeze-Broadcast-Reshape) pattern

---
 .../com.microsoft/group_query_attention.cpp   | 45 +++++++++++++++++--
 1 file changed, 42 insertions(+), 3 deletions(-)

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index cca20a26b86743..946a8b8093d5fd 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -6,6 +6,8 @@
 #include "core/operator_set.hpp"
 #include "openvino/frontend/exception.hpp"
 
+// TODO: Filter out unused headers
+
 #include "openvino/op/add.hpp"
 #include "openvino/op/broadcast.hpp"
 #include "openvino/op/concat.hpp"
@@ -51,7 +53,7 @@ namespace com_microsoft {
 namespace detail {
 namespace {
 
-// FIXME: Reuse the same function from file attention.cpp
+// FIXME: Reuse the same function from file attention.cpp; it requires a bit of adaptation -- part of the inputs has been redesigned here and in the helper functions below
 ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
                             int64_t num_heads,
                             const std::vector<int64_t>& qkv_hidden_sizes);
@@ -86,7 +88,6 @@ ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Output<ov::Node> sin, bool interleaved, const Output<ov::Node>& head_size, const Output<ov::Node>& seq_len) {
     cos = make_shared<Slice>(cos, zero, seq_len, step, zero);
     sin = make_shared<Slice>(sin, zero, seq_len, step, zero);
 
-    Output cos_multiplier = make_shared<Concat>(OutputVector{cos, cos}, 1);
     OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();
 
     Output res_0 = make_shared<Subtract>(
@@ -102,6 +103,36 @@ ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Output<ov::Node> sin, bool interleaved, const Output<ov::Node>& head_size, const Output<ov::Node>& seq_len) {
     return make_shared<Concat>(OutputVector{res_0, res_1}, -1);
 }
 
+ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int num_kv_heads, const int num_heads) {
+    if(num_kv_heads == 1 || num_kv_heads == num_heads) {
+        // No broadcast or there is the broadcast that SDPA broadcastability can handle
+        return cache;
+    }
+
+    OPENVINO_ASSERT(num_heads % num_kv_heads == 0);
+    const auto broadcast_multiplier = num_heads/num_kv_heads;
+
+    auto unsqueeze = std::make_shared<v0::Unsqueeze>(cache, v0::Constant::create(element::i32, Shape{}, {2}));
+    auto shapeof = std::make_shared<v3::ShapeOf>(cache, element::i32);
+
+    auto broadcast_shape = std::make_shared<v0::Concat>(OutputVector{
+        get_elements(shapeof, {0, 1}),
+        v0::Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
+        get_elements(shapeof, {2, 3})
+    }, 0);
+
+    auto broadcast = std::make_shared<v3::Broadcast>(unsqueeze, broadcast_shape);
+
+    auto reshape_shape = std::make_shared<v0::Concat>(OutputVector{
+        v0::Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
+        get_elements(shapeof, {3})
+    }, 0);
+
+    auto reshape = std::make_shared<v1::Reshape>(broadcast, reshape_shape, true);
+
+    return reshape;
+}
+
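+// Illustrative shape walk-through for the UBR pattern above (hypothetical values,
+// assuming num_kv_heads = 4 and num_heads = 32, i.e. broadcast_multiplier = 8):
+//   cache [B, 4, L, S] -> Unsqueeze(axis 2) -> [B, 4, 1, L, S]
+//   -> Broadcast -> [B, 4, 8, L, S] -> Reshape([0, 32, -1] + [S]) -> [B, 32, L, S]
+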
 } // namespace
 } // namespace detail
 
@@ -142,12 +173,20 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
     K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
     V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
 
+    const auto num_kv_heads = node.get_attribute_value<int64_t>("kv_num_heads");
+
+    K = detail::broadcast_groups(K, num_kv_heads, num_heads);
+    V = detail::broadcast_groups(V, num_kv_heads, num_heads);
+
     // FIXME: Unaligned batch of sequences is not supported.
-    // That means all input sequence lengths should be the same and match input.shape[1]
+    // That means all input sequence lengths should be the same and match input.shape[2]
     // We do not check that here because it depends on runtime values.
+    // If we want to implement a non-aligned batch of sequences, we have to form a non-uniform causal mask for attention, which
+    // adds a significant portion of the code.
 
     // FIXME: The same tensor at input/output of past/present K and V is not supported.
     // It requires more complex tensor manipulations that introduce overhead into pure tensor-value data flow and should be implemented if we really have demand for that.
+    // Also, in-place KV-cache modification logic is not efficiently supported in any of the plugins (CPU, GPU and NPU).
 
     auto output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
 

From 6452b3d395fcc8583ce068ec993f813febf397a6 Mon Sep 17 00:00:00 2001
From: Sergey Lyalin
Date: Wed, 13 Nov 2024 15:38:23 +0000
Subject: [PATCH 3/6] Correct handling of kv-cache position range (with
 limitations and assumptions mentioned in the code).

---
 .../com.microsoft/group_query_attention.cpp   | 30 ++++++++++++++-----
 1 file changed, 23 insertions(+), 7 deletions(-)

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index 946a8b8093d5fd..c1b0dfac0cce14 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -68,7 +68,15 @@ ov::Output<ov::Node> get_dimensions(const ov::Output<ov::Node>& node, const std::vector<int>& dims) {
     return get_elements(std::make_shared<v3::ShapeOf>(node), dims);
 }
 
-ov::Output<ov::Node> rope(const Output<ov::Node>& x, Output<ov::Node> cos, Output<ov::Node> sin, bool interleaved, const Output<ov::Node>& head_size, const Output<ov::Node>& seq_len) {
+ov::Output<ov::Node> rope(
+    const Output<ov::Node>& x,
+    Output<ov::Node> cos,
+    Output<ov::Node> sin,
+    bool interleaved,
+    const Output<ov::Node>& head_size,
+    const Output<ov::Node>& pos_id_begin,
+    const Output<ov::Node>& pos_id_end
+) {
     OPENVINO_ASSERT(!interleaved, "rotary_interleaved is not supported");  // TODO: Support interleaved mode
 
     using v1::Split;
@@ -85,8 +93,8 @@ ov::Output<ov::Node> rope(
     Output step = Constant::create(element::i32, Shape{1}, {1});
 
     // cut for the current sequence length
-    cos = make_shared<Slice>(cos, zero, seq_len, step, zero);
-    sin = make_shared<Slice>(sin, zero, seq_len, step, zero);
+    cos = make_shared<Slice>(cos, pos_id_begin, pos_id_end, step, zero);
+    sin = make_shared<Slice>(sin, pos_id_begin, pos_id_end, step, zero);
 
     OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();
 
@@ -161,14 +169,22 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
 
     const auto& past_K = nodes[3];
     const auto& past_V = nodes[4];
-
+    const auto& seqlens_k = nodes[5];
     const auto& total_sequence_length = nodes[6];
     const auto& cos = nodes[7];
     const auto& sin = nodes[8];
     const bool rope_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);
 
-    Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, total_sequence_length);
-    K = detail::rope(K, cos, sin, rope_interleaved, head_size, total_sequence_length);
+    // FIXME: It works only when the KV cache is dynamically growing and doesn't have unused space inside. So it is not compatible with a statically-shaped KV cache.
+    // const auto past_seq_len = detail::get_dimensions(past_K, {0});
+    // TODO: The GQA spec is not compatible with the test model. The spec assumes a 1D tensor, but in the test model we have a 2D tensor, so flatten to work in both cases.
+
+    // FIXME: Unaligned elements in KV cache are not supported.
+    // We just get one of the seq lens as a common value for all past sequences
+    const auto& past_seq_len = detail::get_elements(std::make_shared<v1::Reshape>(seqlens_k, v0::Constant::create(element::i32, Shape{1}, {-1}), false), {0});
+
+    Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
+    K = detail::rope(K, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
 
     K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
     V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
@@ -178,7 +194,7 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
     K = detail::broadcast_groups(K, num_kv_heads, num_heads);
     V = detail::broadcast_groups(V, num_kv_heads, num_heads);
 
-    // FIXME: Unaligned batch of sequences is not supported.
+    // FIXME: Unaligned batch of sequences is not supported. All past key-value sequences are assumed to have the same length.
     // That means all input sequence lengths should be the same and match input.shape[2]
     // We do not check that here because it depends on runtime values.
     // If we want to implement a non-aligned batch of sequences, we have to form a non-uniform causal mask for attention, which

From 033ab9dc6ff6570a12a9a199239b00749395b002 Mon Sep 17 00:00:00 2001
From: Sergey Lyalin
Date: Thu, 14 Nov 2024 13:32:16 +0000
Subject: [PATCH 4/6] Fix warning

---
 .../frontend/src/op/com.microsoft/group_query_attention.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index c1b0dfac0cce14..c35f303789483f 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -155,7 +155,7 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
     const auto num_heads = node.get_attribute_value<int64_t>("num_heads");
 
     if(ov::op::util::is_null(nodes[1]) || ov::op::util::is_null(nodes[2])) {
-        const auto split_result = detail::split_to_QKV(input, num_heads, {});
+        const auto split_result = detail::split_to_QKV(input, static_cast<int64_t>(num_heads), {});
         Q = split_result[0];
         K = split_result[1];
         V = split_result[2];

From 557a1ba47b3f17bc08caf34b59be5797137a975d Mon Sep 17 00:00:00 2001
From: Sergey Lyalin
Date: Thu, 14 Nov 2024 14:58:39 +0000
Subject: [PATCH 5/6] Another attempt to fix the warning

---
 .../frontend/src/op/com.microsoft/group_query_attention.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index c35f303789483f..5b33b9306915b9 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -152,10 +152,10 @@ 
ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
 
     ov::Output<ov::Node> Q, K, V, head_size;
 
-    const auto num_heads = node.get_attribute_value<int64_t>("num_heads");
+    const auto num_heads = static_cast<int>(node.get_attribute_value<int64_t>("num_heads"));
 
     if(ov::op::util::is_null(nodes[1]) || ov::op::util::is_null(nodes[2])) {
-        const auto split_result = detail::split_to_QKV(input, static_cast<int64_t>(num_heads), {});
+        const auto split_result = detail::split_to_QKV(input, num_heads, {});
         Q = split_result[0];
         K = split_result[1];
         V = split_result[2];

From d33579dd6297ee33569a88f1b4f227e94dd44b31 Mon Sep 17 00:00:00 2001
From: Sergey Lyalin
Date: Fri, 15 Nov 2024 15:16:42 +0000
Subject: [PATCH 6/6] Implemented update of statically shaped KV cache.

---
 .../com.microsoft/group_query_attention.cpp   | 163 +++++++++++-------
 1 file changed, 103 insertions(+), 60 deletions(-)

diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index 5b33b9306915b9..1c1f5dc2acec0e 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -31,6 +31,7 @@
 #include "openvino/op/range.hpp"
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/scaled_dot_product_attention.hpp"
+#include "openvino/op/scatter_update.hpp"
 #include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/softmax.hpp"
@@ -53,48 +54,56 @@ namespace com_microsoft {
 namespace detail {
 namespace {
 
+using v1::Split;
+using v0::Constant;
+using v1::Multiply;
+using v1::Add;
+using v8::Slice;
+using v0::Concat;
+using v1::Subtract;
+using v3::ShapeOf;
+using v3::Broadcast;
+using v1::Reshape;
+using v0::Unsqueeze;
+using v4::Range;
+using v3::ScatterUpdate;
+using v15::Squeeze;
+using Output = ov::Output<ov::Node>;
+using std::make_shared;
+
 // FIXME: Reuse the same function from file attention.cpp; it requires a bit of adaptation -- part of the inputs has been redesigned here and in the helper functions below
-ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
+ov::NodeVector split_to_QKV(const Output& node,
                             int64_t num_heads,
                             const std::vector<int64_t>& qkv_hidden_sizes);
 
-ov::Output<ov::Node> get_elements(const ov::Output<ov::Node>& shape, const std::vector<int>& dims) {
+Output get_elements(const Output& shape, const std::vector<int>& dims) {
     static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
     const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
     return std::make_shared<v8::Gather>(shape, dims_const, zero);
 }
 
-ov::Output<ov::Node> get_dimensions(const ov::Output<ov::Node>& node, const std::vector<int>& dims) {
-    return get_elements(std::make_shared<v3::ShapeOf>(node), dims);
+Output get_dimensions(const Output& node, const std::vector<int>& dims) {
+    return get_elements(std::make_shared<ShapeOf>(node, element::i32), dims);
 }
 
-ov::Output<ov::Node> rope(
-    const Output<ov::Node>& x,
-    Output<ov::Node> cos,
-    Output<ov::Node> sin,
+Output rope(
+    const Output& x,
+    const Output& cos_cache,
+    const Output& sin_cache,
     bool interleaved,
-    const Output<ov::Node>& head_size,
-    const Output<ov::Node>& pos_id_begin,
-    const Output<ov::Node>& pos_id_end
+    const Output& head_size,
+    const Output& pos_id_begin,
+    const Output& pos_id_end
 ) {
     OPENVINO_ASSERT(!interleaved, "rotary_interleaved is not supported");  // TODO: Support interleaved mode
 
-    using v1::Split;
-    using v0::Constant;
-    using v1::Multiply;
-    using v1::Add;
-    using v8::Slice;
-    using v0::Concat;
-    using v1::Subtract;
-    using Output = ov::Output<ov::Node>;
-    using std::make_shared;
-
     Output zero = Constant::create(element::i32, Shape{1}, {0});
     Output step = Constant::create(element::i32, Shape{1}, {1});
 
     // cut for the current sequence length
-    cos = make_shared<Slice>(cos, pos_id_begin, pos_id_end, step, zero);
-    sin = make_shared<Slice>(sin, pos_id_begin, pos_id_end, step, zero);
+    Output cos = make_shared<Slice>(cos_cache, pos_id_begin, pos_id_end, step, zero);
+    Output sin = make_shared<Slice>(sin_cache, pos_id_begin, pos_id_end, step, zero);
 
     OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();
 
@@ -111,7 +120,7 @@ Output rope(
     return make_shared<Concat>(OutputVector{res_0, res_1}, -1);
 }
 
-ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int num_kv_heads, const int num_heads) {
+Output broadcast_groups(const Output& cache, const int num_kv_heads, const int num_heads) {
     if(num_kv_heads == 1 || num_kv_heads == num_heads) {
         // No broadcast or there is the broadcast that SDPA broadcastability can handle
         return cache;
@@ -120,41 +129,53 @@ Output broadcast_groups(const Output& cache, const int num_kv_heads, const int num_heads) {
     OPENVINO_ASSERT(num_heads % num_kv_heads == 0);
     const auto broadcast_multiplier = num_heads/num_kv_heads;
 
-    auto unsqueeze = std::make_shared<v0::Unsqueeze>(cache, v0::Constant::create(element::i32, Shape{}, {2}));
-    auto shapeof = std::make_shared<v3::ShapeOf>(cache, element::i32);
+    auto unsqueeze = make_shared<Unsqueeze>(cache, Constant::create(element::i32, Shape{}, {2}));
+    auto shapeof = make_shared<ShapeOf>(cache, element::i32);
 
-    auto broadcast_shape = std::make_shared<v0::Concat>(OutputVector{
+    auto broadcast_shape = make_shared<Concat>(OutputVector{
         get_elements(shapeof, {0, 1}),
-        v0::Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
+        Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
         get_elements(shapeof, {2, 3})
     }, 0);
 
-    auto broadcast = std::make_shared<v3::Broadcast>(unsqueeze, broadcast_shape);
+    auto broadcast = make_shared<Broadcast>(unsqueeze, broadcast_shape);
 
-    auto reshape_shape = std::make_shared<v0::Concat>(OutputVector{
-        v0::Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
+    auto reshape_shape = make_shared<Concat>(OutputVector{
+        Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
         get_elements(shapeof, {3})
     }, 0);
 
-    auto reshape = std::make_shared<v1::Reshape>(broadcast, reshape_shape, true);
+    auto reshape = make_shared<Reshape>(broadcast, reshape_shape, true);
 
     return reshape;
 }
 
+Output concat_cache(const Output& past, const Output& current) {
+    return make_shared<Concat>(ov::OutputVector{past, current}, 2);  // 2 is the dimension index that corresponds to sequence len
+}
+
+Output squeeze_1d(const Output& x) {
+    return make_shared<Squeeze>(x, Constant::create(element::i32, Shape{1}, {0}));
+}
+
+Output update_cache(const Output& past, const Output& current, const Output& past_len) {
+    Output update_len = get_dimensions(current, {2});
+    Output update_end = make_shared<Add>(past_len, update_len);
+    Output update_indices = make_shared<Range>(squeeze_1d(past_len), squeeze_1d(update_end), Constant::create(element::i32, Shape{}, {1}), element::i32);
+    return make_shared<ScatterUpdate>(past, update_indices, current, Constant::create(element::i32, Shape{1}, {2}));  // 2 is the dimension index that corresponds to sequence len
+}
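+
+// Illustrative behaviour of update_cache above (hypothetical values): with a static cache
+// past = [B, H, 32, S], past_len = [16] and current = [B, H, 4, S], update_indices becomes
+// Range(16, 20) = {16, 17, 18, 19}, and ScatterUpdate overwrites exactly those positions
+// along axis 2, leaving the rest of the preallocated cache untouched.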
+
-} // namespace
-} // namespace detail
-
-namespace opset_1 {
-ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
-    auto nodes = node.get_ov_inputs();
-    const auto& input = nodes[0];
+ov::OutputVector group_query_attention_decomposition(
+    const ov::OutputVector& inputs,
+    int num_heads,
+    bool rotary_interleaved,
+    int kv_num_heads
+) {
+    const auto& input = inputs[0];
 
     ov::Output<ov::Node> Q, K, V, head_size;
 
-    const auto num_heads = static_cast<int>(node.get_attribute_value<int64_t>("num_heads"));
-
-    if(ov::op::util::is_null(nodes[1]) || ov::op::util::is_null(nodes[2])) {
+    if(ov::op::util::is_null(inputs[1]) || ov::op::util::is_null(inputs[2])) {
         const auto split_result = detail::split_to_QKV(input, num_heads, {});
         Q = split_result[0];
         K = split_result[1];
         V = split_result[2];
         head_size = split_result[3];
     } else {
         Q = input;
-        K = nodes[1];
-        V = nodes[2];
+        K = inputs[1];
+        V = inputs[2];
         head_size = detail::get_dimensions(Q, {-1});
     }
 
-    const auto& past_K = nodes[3];
-    const auto& past_V = nodes[4];
-    const auto& seqlens_k = nodes[5];
-    const auto& total_sequence_length = nodes[6];
-    const auto& cos = nodes[7];
-    const auto& sin = nodes[8];
-    const bool rope_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);
+    const auto& past_K = inputs[3];
+    const auto& past_V = inputs[4];
+    const auto& seqlens_k = inputs[5];
+    const auto& total_sequence_length = inputs[6];
+    const auto& cos = inputs[7];
+    const auto& sin = inputs[8];
 
     // FIXME: It works only when the KV cache is dynamically growing and doesn't have unused space inside. So it is not compatible with a statically-shaped KV cache.
     // const auto past_seq_len = detail::get_dimensions(past_K, {0});
     // TODO: The GQA spec is not compatible with the test model. The spec assumes a 1D tensor, but in the test model we have a 2D tensor, so flatten to work in both cases.
 
     // FIXME: Unaligned elements in KV cache are not supported.
-    // We just get one of the seq lens as a common value for all past sequences
+    // We just get the first of the seq lens as a common value for all past sequences, ignoring the others, under the assumption that they are all the same
     const auto& past_seq_len = detail::get_elements(std::make_shared<v1::Reshape>(seqlens_k, v0::Constant::create(element::i32, Shape{1}, {-1}), false), {0});
 
-    Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
-    K = detail::rope(K, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
+    Q = rope(Q, cos, sin, rotary_interleaved, head_size, past_seq_len, total_sequence_length);
+    K = rope(K, cos, sin, rotary_interleaved, head_size, past_seq_len, total_sequence_length);
 
-    K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
-    V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
+    if(past_K.get_partial_shape()[2].is_dynamic()) {
+        K = concat_cache(past_K, K);
+    } else {
+        K = update_cache(past_K, K, past_seq_len);
+    }
 
-    const auto num_kv_heads = node.get_attribute_value<int64_t>("kv_num_heads");
+    if(past_V.get_partial_shape()[2].is_dynamic()) {
+        V = concat_cache(past_V, V);
+    } else {
+        V = update_cache(past_V, V, past_seq_len);
+    }
 
-    K = detail::broadcast_groups(K, num_kv_heads, num_heads);
-    V = detail::broadcast_groups(V, num_kv_heads, num_heads);
+    K = broadcast_groups(K, kv_num_heads, num_heads);
+    V = broadcast_groups(V, kv_num_heads, num_heads);
 
     // FIXME: Unaligned batch of sequences is not supported. All past key-value sequences are assumed to have the same length.
     // That means all input sequence lengths should be the same and match input.shape[2]
     // We do not check that here because it depends on runtime values.
     // If we want to implement a non-aligned batch of sequences, we have to form a non-uniform causal mask for attention, which
     // adds a significant portion of the code.
 
     // FIXME: The same tensor at input/output of past/present K and V is not supported.
     // It requires more complex tensor manipulations that introduce overhead into pure tensor-value data flow and should be implemented if we really have demand for that.
     // Also, in-place KV-cache modification logic is not efficiently supported in any of the plugins (CPU, GPU and NPU).
 
-    auto output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
+    auto output = make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
 
     return {output, K, V};
 }
+
+} // namespace
+} // namespace detail
+
+namespace opset_1 {
+ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
+    return detail::group_query_attention_decomposition(
+        node.get_ov_inputs(),
+        static_cast<int>(node.get_attribute_value<int64_t>("num_heads")),
+        static_cast<bool>(node.get_attribute_value<int64_t>("rotary_interleaved", 0)),
+        static_cast<int>(node.get_attribute_value<int64_t>("kv_num_heads"))
+    );
+}
+
 ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN);
+
 } // namespace opset_1
 
 namespace detail {
 namespace {
 
 std::shared_ptr<ov::Node> get_hidden_size(const std::shared_ptr<ov::Node>& node_shape) {
     // node has shape (batch_size, sequence_length, 3 * hidden_size)
     const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
     const auto hidden_size_x3 = get_elements(node_shape, {2});
     const auto three = v0::Constant::create(ov::element::i64, ov::Shape{}, {3});
     const auto hidden_size = std::make_shared<v1::Divide>(hidden_size_x3, three);
     return hidden_size;
 }
 
-ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
+ov::NodeVector split_to_QKV(const Output& node,
                             int64_t num_heads,
                             const std::vector<int64_t>& qkv_hidden_sizes) {
     ov::OutputVector split;