Implemented update of statically shaped kv cache.
slyalin committed Nov 15, 2024
1 parent 557a1ba commit d33579d
Showing 1 changed file with 103 additions and 60 deletions.
@@ -31,6 +31,7 @@
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scaled_dot_product_attention.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/select.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/softmax.hpp"
@@ -53,48 +54,56 @@ namespace com_microsoft {
namespace detail {
namespace {


using v1::Split;
using v0::Constant;
using v1::Multiply;
using v1::Add;
using v8::Slice;
using v0::Concat;
using v1::Subtract;
using v3::ShapeOf;
using v3::Broadcast;
using v1::Reshape;
using v0::Unsqueeze;
using v4::Range;
using v3::ScatterUpdate;
using v15::Squeeze;
using Output = ov::Output<ov::Node>;
using std::make_shared;

// FIXME: Reuse the same function from attention.cpp; it needs a bit of adaptation because part of the inputs has been redesigned here and in the helper functions below
ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
ov::NodeVector split_to_QKV(const Output& node,
int64_t num_heads,
const std::vector<int64_t>& qkv_hidden_sizes);

ov::Output<ov::Node> get_elements(const ov::Output<ov::Node>& shape, const std::vector<int>& dims) {
Output get_elements(const Output& shape, const std::vector<int>& dims) {
static const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
return std::make_shared<v8::Gather>(shape, dims_const, zero);
}

ov::Output<ov::Node> get_dimensions(const ov::Output<ov::Node>& node, const std::vector<int>& dims) {
return get_elements(std::make_shared<v3::ShapeOf>(node), dims);
Output get_dimensions(const Output& node, const std::vector<int>& dims) {
return get_elements(std::make_shared<v3::ShapeOf>(node, element::i32), dims);
}
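// Illustrative example (assumed shapes, not from the sources): for a tensor x of shape
// [batch, num_heads, seq_len, head_size], get_dimensions(x, {-1}) returns a one-element
// i32 tensor holding head_size, and get_dimensions(x, {0, 1}) returns [batch, num_heads];
// the i32 results can be concatenated directly into shape tensors for Reshape/Broadcast.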

ov::Output<ov::Node> rope(
const Output<ov::Node>& x,
Output<ov::Node> cos,
Output<ov::Node> sin,
Output rope(
const Output& x,
const Output& cos_cache,
const Output& sin_cache,
bool interleaved,
const Output<ov::Node>& head_size,
const Output<ov::Node>& pos_id_begin,
const Output<ov::Node>& pos_id_end
const Output& head_size,
const Output& pos_id_begin,
const Output& pos_id_end
) {
OPENVINO_ASSERT(!interleaved, "rotary_interleaved is not supported"); // TODO: Support interleaved mode

using v1::Split;
using v0::Constant;
using v1::Multiply;
using v1::Add;
using v8::Slice;
using v0::Concat;
using v1::Subtract;
using Output = ov::Output<ov::Node>;
using std::make_shared;

Output zero = Constant::create(element::i32, Shape{1}, {0});
Output step = Constant::create(element::i32, Shape{1}, {1});

// cut for the current sequence length
cos = make_shared<Slice>(cos, pos_id_begin, pos_id_end, step, zero);
sin = make_shared<Slice>(sin, pos_id_begin, pos_id_end, step, zero);
Output cos = make_shared<Slice>(cos_cache, pos_id_begin, pos_id_end, step, zero);
Output sin = make_shared<Slice>(sin_cache, pos_id_begin, pos_id_end, step, zero);

OutputVector x_split = make_shared<Split>(x, Constant::create(element::i32, Shape{}, {-1}), 2)->outputs();

@@ -111,7 +120,7 @@ ov::Output<ov::Node> rope(
return make_shared<Concat>(OutputVector{res_0, res_1}, -1);
}
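// Sketch of the non-interleaved rotation above (illustrative, head_size = 4 assumed):
// x = [x0, x1, x2, x3] is split into halves h0 = [x0, x1] and h1 = [x2, x3], and the
// result is [h0*cos - h1*sin, h1*cos + h0*sin], i.e. the standard RoPE rotation applied
// to the pairs (x_i, x_{i + head_size/2}), with cos/sin sliced to positions
// [pos_id_begin, pos_id_end).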

ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int num_kv_heads, const int num_heads) {
Output broadcast_groups(const Output& cache, const int num_kv_heads, const int num_heads) {
if(num_kv_heads == 1 || num_kv_heads == num_heads) {
// No broadcast or there is the broadcast that SDPA broadcastability can handle
return cache;
@@ -120,79 +129,97 @@ ov::Output<ov::Node> broadcast_groups(const Output<ov::Node>& cache, const int n
OPENVINO_ASSERT(num_heads % num_kv_heads == 0);
const auto broadcast_multiplier = num_heads/num_kv_heads;

auto unsqueeze = std::make_shared<v0::Unsqueeze>(cache, v0::Constant::create(element::i32, Shape{}, {2}));
auto shapeof = std::make_shared<v3::ShapeOf>(cache, element::i32);
auto unsqueeze = make_shared<Unsqueeze>(cache, Constant::create(element::i32, Shape{}, {2}));
auto shapeof = make_shared<ShapeOf>(cache, element::i32);

auto broadcast_shape = std::make_shared<v0::Concat>(OutputVector{
auto broadcast_shape = make_shared<Concat>(OutputVector{
get_elements(shapeof, {0, 1}),
v0::Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
Constant::create(element::i32, Shape{1}, {broadcast_multiplier}),
get_elements(shapeof, {2, 3})
}, 0);

auto broadcast = std::make_shared<v3::Broadcast>(unsqueeze, broadcast_shape);
auto broadcast = make_shared<Broadcast>(unsqueeze, broadcast_shape);

auto reshape_shape = std::make_shared<v0::Concat>(OutputVector{
v0::Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
auto reshape_shape = make_shared<Concat>(OutputVector{
Constant::create(element::i32, Shape{3}, {0, num_heads, -1}),
get_elements(shapeof, {3})
}, 0);

auto reshape = std::make_shared<v1::Reshape>(broadcast, reshape_shape, true);
auto reshape = make_shared<Reshape>(broadcast, reshape_shape, true);

return reshape;
}
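// Illustrative shapes (assumed values): with num_kv_heads = 2 and num_heads = 8, a cache
// of shape [batch, 2, seq_len, head_size] is unsqueezed to [batch, 2, 1, seq_len, head_size],
// broadcast to [batch, 2, 4, seq_len, head_size], and reshaped back to
// [batch, 8, seq_len, head_size], so every KV head is repeated 4 times across the Q heads.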

Output concat_cache(const Output& past, const Output& current) {
return make_shared<Concat>(ov::OutputVector{past, current}, 2); // 2 is the dimension index that corresponds to sequence len
}
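// Dynamic-cache path (illustrative shapes): past [batch, kv_heads, past_len, head_size]
// and current [batch, kv_heads, cur_len, head_size] are concatenated along axis 2 into
// [batch, kv_heads, past_len + cur_len, head_size], so the cache grows with each step.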

} // namespace
} // namespace detail
Output squeeze_1d(const Output& x) {
return make_shared<Squeeze>(x, Constant::create(element::i32, Shape{0}, {1}));
}

namespace opset_1 {
ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
auto nodes = node.get_ov_inputs();
const auto& input = nodes[0];
Output update_cache(const Output& past, const Output& current, const Output& past_len) {
Output update_len = get_dimensions(current, {2});
Output update_end = make_shared<Add>(past_len, update_len);
Output update_indices = make_shared<Range>(squeeze_1d(past_len), squeeze_1d(update_end), Constant::create(element::i32, Shape{}, {1}), element::i32);
return make_shared<ScatterUpdate>(past, update_indices, current, Constant::create(element::i32, Shape{1}, {2})); // 2 is the dimension index that corresponds to sequence len
}
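// Static-cache path (illustrative values, not from the sources): for a preallocated past
// of shape [batch, kv_heads, max_seq_len, head_size], past_len = [7] and a 3-token
// current chunk give update_indices = Range(7, 10, 1) = [7, 8, 9], and ScatterUpdate
// writes current into those positions along axis 2, keeping the cache shape static.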

ov::Output<ov::Node> Q, K, V, head_size;
ov::OutputVector group_query_attention_decomposition(
const ov::OutputVector& inputs,
int num_heads,
bool rotary_interleaved,
int kv_num_heads
) {
const auto& input = inputs[0];

const auto num_heads = static_cast<int32_t>(node.get_attribute_value<int64_t>("num_heads"));
Output Q, K, V, head_size;

if(ov::op::util::is_null(nodes[1]) || ov::op::util::is_null(nodes[2])) {
if(ov::op::util::is_null(inputs[1]) || ov::op::util::is_null(inputs[2])) {
const auto split_result = detail::split_to_QKV(input, num_heads, {});
Q = split_result[0];
K = split_result[1];
V = split_result[2];
head_size = split_result[3];
} else {
Q = input;
K = nodes[1];
V = nodes[2];
K = inputs[1];
V = inputs[2];
head_size = detail::get_dimensions(Q, {-1});
}

const auto& past_K = nodes[3];
const auto& past_V = nodes[4];
const auto& seqlens_k = nodes[5];
const auto& total_sequence_length = nodes[6];
const auto& cos = nodes[7];
const auto& sin = nodes[8];
const bool rope_interleaved = node.get_attribute_value<int64_t>("rotary_interleaved", 0);
const auto& past_K = inputs[3];
const auto& past_V = inputs[4];
const auto& seqlens_k = inputs[5];
const auto& total_sequence_length = inputs[6];
const auto& cos = inputs[7];
const auto& sin = inputs[8];

// FIXME: It works only when KV cache is dynamically growing and doesn't have unused space inside. So it is not compatible with statically-shaped KV cache.
// const auto past_seq_len = detail::get_dimensions(past_K, {0});
// TODO: The GQA spec is not compatible with the test model: the spec assumes a 1D tensor, while the test model provides a 2D tensor; flatten so that both cases work.

// FIXME: Unaligned elements in KV cache are not supported.
// We just get one of the seq lens as a common value for all past sequences
// We just take the first of the seq lens as a common value for all past sequences, ignoring the others, under the assumption that they are all the same
const auto& past_seq_len = detail::get_elements(std::make_shared<v1::Reshape>(seqlens_k, v0::Constant::create(element::i32, Shape{1}, {-1}), false), {0});

Q = detail::rope(Q, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
K = detail::rope(K, cos, sin, rope_interleaved, head_size, past_seq_len, total_sequence_length);
Q = rope(Q, cos, sin, rotary_interleaved, head_size, past_seq_len, total_sequence_length);
K = rope(K, cos, sin, rotary_interleaved, head_size, past_seq_len, total_sequence_length);

K = std::make_shared<v0::Concat>(ov::OutputVector{past_K, K}, 2);
V = std::make_shared<v0::Concat>(ov::OutputVector{past_V, V}, 2);
if(past_K.get_partial_shape()[2].is_dynamic()) {
K = concat_cache(past_K, K);
} else {
K = update_cache(past_K, K, past_seq_len);
}

const auto num_kv_heads = node.get_attribute_value<int64_t>("kv_num_heads");
if(past_V.get_partial_shape()[2].is_dynamic()) {
V = concat_cache(past_V, V);
} else {
V = update_cache(past_V, V, past_seq_len);
}
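// The cache kind is detected from the sequence dimension (axis 2) of past_K/past_V:
// a dynamic dimension means a growing cache extended by Concat, a static one means a
// preallocated buffer updated in place via ScatterUpdate at positions derived from
// past_seq_len.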

K = detail::broadcast_groups(K, num_kv_heads, num_heads);
V = detail::broadcast_groups(V, num_kv_heads, num_heads);
K = broadcast_groups(K, kv_num_heads, num_heads);
V = broadcast_groups(V, kv_num_heads, num_heads);

// FIXME: Unaligned batches of sequences are not supported. All past key-values are assumed to have the same length.
// That means all input sequence lengths should be the same and match input.shape[2]
@@ -204,11 +231,27 @@ ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
// It would require more complex tensor manipulations that introduce overhead into the pure tensor-value data flow and should be implemented only if there is real demand for it.
// Also, in-place KV-cache modification logic is not supported efficiently in any of the plugins (CPU, GPU and NPU).

auto output = std::make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);
auto output = make_shared<v13::ScaledDotProductAttention>(Q, K, V, true);

return {output, K, V};
}
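// The three outputs correspond (assumed mapping) to the ONNX GQA outputs: the attention
// result plus the updated K and V caches, which the runtime feeds back as
// past_key/past_value on the next decoding step.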


} // namespace
} // namespace detail

namespace opset_1 {
ov::OutputVector group_query_attention(const ov::frontend::onnx::Node& node) {
return detail::group_query_attention_decomposition(
node.get_ov_inputs(),
static_cast<int>(node.get_attribute_value<int64_t>("num_heads")),
static_cast<bool>(node.get_attribute_value<int64_t>("rotary_interleaved", 0)),
static_cast<int>(node.get_attribute_value<int64_t>("kv_num_heads"))
);
}

ONNX_OP("GroupQueryAttention", OPSET_SINCE(1), com_microsoft::opset_1::group_query_attention, MICROSOFT_DOMAIN);
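// ONNX_OP registers this translator for GroupQueryAttention from the com.microsoft
// domain for every opset since 1; the thin opset_1 wrapper above only reads the node
// attributes and delegates to the decomposition in namespace detail.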

} // namespace opset_1

namespace detail {
@@ -225,7 +268,7 @@ std::shared_ptr<ov::Node> get_hidden_size(const std::shared_ptr<v3::ShapeOf>& no
return hidden_size;
}

ov::NodeVector split_to_QKV(const Output<ov::Node>& node,
ov::NodeVector split_to_QKV(const Output& node,
int64_t num_heads,
const std::vector<int64_t>& qkv_hidden_sizes) {
ov::OutputVector split;
