diff --git a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
index 5b33b9306915b9..1c1f5dc2acec0e 100644
--- a/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
+++ b/src/frontends/onnx/frontend/src/op/com.microsoft/group_query_attention.cpp
@@ -31,6 +31,7 @@
 #include "openvino/op/range.hpp"
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/scaled_dot_product_attention.hpp"
+#include "openvino/op/scatter_update.hpp"
 #include "openvino/op/select.hpp"
 #include "openvino/op/shape_of.hpp"
 #include "openvino/op/softmax.hpp"
@@ -53,48 +54,56 @@ namespace com_microsoft {
 namespace detail {
 namespace {
+
+using v1::Split;
+using v0::Constant;
+using v1::Multiply;
+using v1::Add;
+using v8::Slice;
+using v0::Concat;
+using v1::Subtract;
+using v3::ShapeOf;
+using v3::Broadcast;
+using v1::Reshape;
+using v0::Unsqueeze;
+using v4::Range;
+using v3::ScatterUpdate;
+using v15::Squeeze;
+using Output = ov::Output<ov::Node>;
+using std::make_shared;
+
 // FIXME: Reuse the same function from file attention.cpp, but it requires a bit of adaptation -- I have redesigned part of the inputs a bit here and in the helper functions below
-ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
+ov::NodeVector split_to_QKV(const Output& node,
                             int64_t num_heads,
                             const std::vector<int64_t>& qkv_hidden_sizes);
 
-ov::Output<ov::Node> get_elements(const ov::Output<ov::Node>& shape, const std::vector<int>& dims) {
+Output get_elements(const Output& shape, const std::vector<int>& dims) {
     static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
     const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
     return std::make_shared<v8::Gather>(shape, dims_const, zero);
 }
 
-ov::Output<ov::Node> get_dimensions(const ov::Output<ov::Node>& node, const std::vector<int>& dims) {
-    return get_elements(std::make_shared<v3::ShapeOf>(node), dims);
+Output get_dimensions(const Output& node, const std::vector<int>& dims) {
+    return get_elements(std::make_shared<ShapeOf>(node, element::i32), dims);
 }
 
-ov::Output<ov::Node> rope(
-    const Output<ov::Node>& x,
-    Output<ov::Node> cos,
-    Output<ov::Node> sin,
+Output rope(
+    const Output& x,
+    const Output& cos_cache,
+    const Output& sin_cache,
     bool interleaved,
-    const Output<ov::Node>& head_size,
-    const Output<ov::Node>& pos_id_begin,
-    const Output<ov::Node>& pos_id_end
+    const Output& head_size,
+    const Output& pos_id_begin,
+    const Output& pos_id_end
 ) {
     OPENVINO_ASSERT(!interleaved, "rotary_interleaved is not supported");  // TODO: Support interleaved mode
 
-    using v1::Split;
-    using v0::Constant;
-    using v1::Multiply;
-    using v1::Add;
-    using v8::Slice;
-    using v0::Concat;
-    using v1::Subtract;
-    using Output = Output<ov::Node>;
-    using std::make_shared;
-
     Output zero = Constant::create(element::i32, Shape{1}, {0});
     Output step = Constant::create(element::i32, Shape{1}, {1});
 
     // cut for the current sequence length
-    cos = make_shared<Slice>(cos, pos_id_begin, pos_id_end, step, zero);
-    sin = make_shared<Slice>(sin, pos_id_begin, pos_id_end, step, zero);
+    Output cos = make_shared<Slice>(cos_cache, pos_id_begin, pos_id_end, step, zero);
+    Output sin = make_shared<Slice>(sin_cache, pos_id_begin, pos_id_end, step, zero);
 
     OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();
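Note (illustrative, not part of the patch): the diff elides the unchanged middle of rope(), where the res_0 and res_1 referenced by the return statement in the next hunk are computed. A minimal sketch of what that elided code presumably does, assuming the classic non-interleaved "rotate-half" RoPE formulation and the usings introduced above; names other than x_split, res_0 and res_1 are hypothetical:

    // x_split[0] and x_split[1] are the two halves of the head dimension;
    // cos and sin are already sliced to positions [pos_id_begin, pos_id_end).
    Output x0 = x_split[0];
    Output x1 = x_split[1];
    // rotate-half: (x0, x1) -> (x0*cos - x1*sin, x0*sin + x1*cos)
    Output res_0 = make_shared<Subtract>(make_shared<Multiply>(x0, cos), make_shared<Multiply>(x1, sin));
    Output res_1 = make_shared<Add>(make_shared<Multiply>(x0, sin), make_shared<Multiply>(x1, cos));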
@@ -111,7 +120,7 @@ ov::Output<ov::Node> rope(
     return make_shared<Concat>(OutputVector{res_0, res_1}, -1);
 }
 
-ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int num_kv_heads, const int num_heads) {
+Output broadcast_groups(const Output& cache, const int num_kv_heads, const int num_heads) {
     if(num_kv_heads == 1 || num_kv_heads == num_heads) {
         // No broadcast or there is the broadcast that SDPA broadcastability can handle
         return cache;
@@ -120,41 +129,53 @@ ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int n
     OPENVINO_ASSERT(num_heads % num_kv_heads == 0);
     const auto broadcast_multiplier = num_heads/num_kv_heads;
 
-    auto unsqueeze = std::make_shared<v0::Unsqueeze>(cache, v0::Constant::create(element::i32, Shape{}, {2}));
-    auto shapeof = std::make_shared<v3::ShapeOf>(cache, element::i32);
+    auto unsqueeze = make_shared<Unsqueeze>(cache, Constant::create(element::i32, Shape{}, {2}));
+    auto shapeof = make_shared<ShapeOf>(cache, element::i32);
 
-    auto broadcast_shape = std::make_shared<v0::Concat>(OutputVector{
+    auto broadcast_shape = make_shared<Concat>(OutputVector{
         get_elements(shapeof, {0, 1}),
-        v0::Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
+        Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
         get_elements(shapeof, {2, 3})
     }, 0);
-    auto broadcast = std::make_shared<v3::Broadcast>(unsqueeze, broadcast_shape);
+    auto broadcast = make_shared<Broadcast>(unsqueeze, broadcast_shape);
 
-    auto reshape_shape = std::make_shared<v0::Concat>(OutputVector{
-        v0::Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
+    auto reshape_shape = make_shared<Concat>(OutputVector{
+        Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
         get_elements(shapeof, {3})
     }, 0);
-    auto reshape = std::make_shared<v1::Reshape>(broadcast, reshape_shape, true);
+    auto reshape = make_shared<Reshape>(broadcast, reshape_shape, true);
 
     return reshape;
 }
 
+Output concat_cache(const Output& past, const Output& current) {
+    return make_shared<Concat>(ov::OutputVector{past, current}, 2);  // 2 is the dimension index that corresponds to sequence len
+}
 
-}  // namespace
-}  // namespace detail
+Output squeeze_1d(const Output& x) {
+    return make_shared<Squeeze>(x, Constant::create(element::i32, Shape{0}, {1}));
+}
 
-namespace opset_1 {
-ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
-    auto nodes = node.get_ov_inputs();
-    const auto& input = nodes[0];
+Output update_cache(const Output& past, const Output& current, const Output& past_len) {
+    Output update_len = get_dimensions(current, {2});
+    Output update_end = make_shared<Add>(past_len, update_len);
+    Output update_indices = make_shared<Range>(squeeze_1d(past_len), squeeze_1d(update_end), Constant::create(element::i32, Shape{}, {1}), element::i32);
+    return make_shared<ScatterUpdate>(past, update_indices, current, Constant::create(element::i32, Shape{1}, {2}));  // 2 is the dimension index that corresponds to sequence len
+}
 
-    ov::Output<ov::Node> Q, K, V, head_size;
+ov::OutputVector group_query_attention_decomposition(
+    const ov::OutputVector& inputs,
+    int num_heads,
+    bool rotary_interleaved,
+    int kv_num_heads
+) {
+    const auto& input = inputs[0];
 
-    const auto num_heads = static_cast<int>(node.get_attribute_value<int64_t>("num_heads"));
+    Output Q, K, V, head_size;
 
-    if(ov::op::util::is_null(nodes[1]) || ov::op::util::is_null(nodes[2])) {
+    if(ov::op::util::is_null(inputs[1]) || ov::op::util::is_null(inputs[2])) {
         const auto split_result = detail::split_to_QKV(input, num_heads, {});
         Q = split_result[0];
         K = split_result[1];
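For illustration (not part of the patch), a usage sketch of the new update_cache() helper with hypothetical shapes. Given a statically-shaped, preallocated cache [batch, kv_heads, max_seq_len, head_size], Range builds the row indices [past_len, past_len + current_len) and ScatterUpdate writes the new rows into the buffer along axis 2, so the cache keeps its static shape instead of growing:

    // Hypothetical shapes; assumes openvino/op/parameter.hpp is included.
    auto past = make_shared<v0::Parameter>(element::f32, PartialShape{1, 4, 8, 64});     // preallocated buffer, max_seq_len = 8
    auto current = make_shared<v0::Parameter>(element::f32, PartialShape{1, 4, 2, 64});  // 2 freshly computed rows
    auto past_len = Constant::create(element::i32, Shape{1}, {3});                       // 3 rows are already valid
    // Builds Range(3, 5, 1) -> {3, 4}, then scatters `current` into past[:, :, 3:5, :];
    // the result keeps the static shape [1, 4, 8, 64].
    auto updated = update_cache(past, current, past_len);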
@@ -162,37 +183,43 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
         head_size = split_result[3];
     } else {
         Q = input;
-        K = nodes[1];
-        V = nodes[2];
+        K = inputs[1];
+        V = inputs[2];
         head_size = detail::get_dimensions(Q, {-1});
     }
 
-    const auto& past_K = nodes[3];
-    const auto& past_V = nodes[4];
-    const auto& seqlens_k = nodes[5];
-    const auto& total_sequence_length = nodes[6];
-    const auto& cos = nodes[7];
-    const auto& sin = nodes[8];
-    const bool rope_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);
+    const auto& past_K = inputs[3];
+    const auto& past_V = inputs[4];
+    const auto& seqlens_k = inputs[5];
+    const auto& total_sequence_length = inputs[6];
+    const auto& cos = inputs[7];
+    const auto& sin = inputs[8];
 
     // FIXME: It works only when KV cache is dynamically growing and doesn't have unused space inside. So it is not compatible with statically-shaped KV cache.
     // const auto past_seq_len = detail::get_dimensions(past_K, {0});
 
     // TODO: GQA spec is not compatible with test model. Spec supposes 1D tensor, in the test model we have 2D tensor, flattening to work in both cases.
     // FIXME: Unaligned elements in KV cache are not supported.
-    // We just get one of the seq lens as a common value for all past sequences
+    // We just take the first of the seq lens as a common value for all past sequences, ignoring the others, under the assumption that they are all the same
     const auto& past_seq_len = detail::get_elements(std::make_shared<v1::Reshape>(seqlens_k, v0::Constant::create(element::i32, Shape{1}, {-1}), false), {0});
 
-    Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
-    K = detail::rope(K, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
+    Q = rope(Q, cos, sin, rotary_interleaved, head_size, past_seq_len, total_sequence_length);
+    K = rope(K, cos, sin, rotary_interleaved, head_size, past_seq_len, total_sequence_length);
 
-    K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
-    V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
+    if(past_K.get_partial_shape()[2].is_dynamic()) {
+        K = concat_cache(past_K, K);
+    } else {
+        K = update_cache(past_K, K, past_seq_len);
+    }
 
-    const auto num_kv_heads = node.get_attribute_value<int64_t>("kv_num_heads");
+    if(past_V.get_partial_shape()[2].is_dynamic()) {
+        V = concat_cache(past_V, V);
+    } else {
+        V = update_cache(past_V, V, past_seq_len);
+    }
 
-    K = detail::broadcast_groups(K, num_kv_heads, num_heads);
-    V = detail::broadcast_groups(V, num_kv_heads, num_heads);
+    K = broadcast_groups(K, kv_num_heads, num_heads);
+    V = broadcast_groups(V, kv_num_heads, num_heads);
 
     // FIXME: Unaligned batch of sequences is not supported. All past key-value are assumed to have the same length.
     // That means all input sequence lengths should be the same and match input.shape[2]
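A short note on the dispatch introduced above (not part of the patch): the sequence dimension (index 2) of the incoming past state selects the cache strategy at conversion time. A hedged sketch of the two cases, with hypothetical shapes:

    // Dynamic sequence dimension -> the cache grows via Concat (concat_cache path):
    auto growing = make_shared<v0::Parameter>(element::f32, PartialShape{1, 4, Dimension::dynamic(), 64});
    // growing->get_partial_shape()[2].is_dynamic() == true

    // Static sequence dimension -> new rows are scattered into the preallocated
    // buffer via Range + ScatterUpdate (update_cache path):
    auto preallocated = make_shared<v0::Parameter>(element::f32, PartialShape{1, 4, 2048, 64});
    // preallocated->get_partial_shape()[2].is_dynamic() == false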
@@ -204,11 +231,27 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
     // It requires more complex tensor manipulations that are introduce overhead into pure tensor-value data flow and should be implemented if we really have demand for that.
     // Also inplace KV-cache modification logic is not supported efficiently in any plugins (CPU, GPU and NPU).
 
-    auto output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
+    auto output = make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
 
     return {output, K, V};
 }
+
+
+}  // namespace
+}  // namespace detail
+
+namespace opset_1 {
+ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
+    return detail::group_query_attention_decomposition(
+        node.get_ov_inputs(),
+        static_cast<int>(node.get_attribute_value<int64_t>("num_heads")),
+        static_cast<bool>(node.get_attribute_value<int64_t>("rotary_interleaved", 0)),
+        static_cast<int>(node.get_attribute_value<int64_t>("kv_num_heads"))
+    );
+}
+
 ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN);
+
 }  // namespace opset_1
 
 namespace detail {
@@ -225,7 +268,7 @@ std::shared_ptr<ov::Node> get_hidden_size(const std::shared_ptr<v3::ShapeOf>& no
     return hidden_size;
 }
 
-ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
+ov::NodeVector split_to_QKV(const Output& node,
                             int64_t num_heads,
                             const std::vector<int64_t>& qkv_hidden_sizes) {
     ov::OutputVector split;
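A worked example (not part of the patch) of what broadcast_groups() from the hunks above builds for grouped-query attention, assuming num_heads = 8 and kv_num_heads = 2, so broadcast_multiplier = 4; the shapes are hypothetical:

    // cache: [batch, 2, seq, head]
    //   Unsqueeze(axis = 2)                        -> [batch, 2, 1, seq, head]
    //   Broadcast to [batch, 2, 4, seq, head]      -> [batch, 2, 4, seq, head]
    //   Reshape {0, 8, -1, head}, special_zero     -> [batch, 8, seq, head]
    auto kv = make_shared<v0::Parameter>(element::f32, PartialShape{1, 2, 16, 64});
    auto expanded = broadcast_groups(kv, /*num_kv_heads=*/2, /*num_heads=*/8);  // -> [1, 8, 16, 64]
    // Each KV head is repeated 4 times, so SDPA sees matching Q and KV head counts.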