support KV-Compress paged KV cache #27

Status: Open · wants to merge 1 commit into base: main
csrc/flash_attn/flash_api.cpp: 102 changes (80 additions, 22 deletions)
@@ -46,7 +46,8 @@ void set_params_fprop(Flash_fwd_params &params,
int window_size_right,
const float softcap,
bool seqlenq_ngroups_swapped=false,
const bool unpadded_lse=false) {
const bool unpadded_lse=false,
const bool is_kvc=false) {

// Reset the parameters
params = {};
@@ -59,11 +60,19 @@
params.v_ptr = v.data_ptr();
// All strides are in elements, not bytes.
params.q_row_stride = q.stride(-3);
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
params.q_head_stride = q.stride(-2);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
params.is_kvc_cache = is_kvc;
if (!is_kvc) {
params.k_row_stride = k.stride(-3);
params.v_row_stride = v.stride(-3);
params.k_head_stride = k.stride(-2);
params.v_head_stride = v.stride(-2);
} else {
params.k_row_stride = k.stride(1);
params.v_row_stride = v.stride(1);
// head stride not used: each KVC page holds a single KV head
}

params.o_ptr = out.data_ptr();
params.o_row_stride = out.stride(-3);
params.o_head_stride = out.stride(-2);
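
For context on the branch above: a KV-Compress cache stores pages per head, so the cache tensors are `[num_blocks, page_block_size, head_size]` with no head dimension, and the row stride moves from `stride(-3)` to `stride(1)`. A minimal sketch of the resulting strides (hypothetical sizes, assuming contiguous tensors):

```cpp
#include <array>
#include <cstdio>

// Element-unit strides of a contiguous N-D tensor, mirroring the
// torch strides the code above reads (sizes here are hypothetical).
template <size_t N>
std::array<long, N> contiguous_strides(const std::array<long, N>& sizes) {
    std::array<long, N> strides{};
    long s = 1;
    for (int i = static_cast<int>(N) - 1; i >= 0; --i) { strides[i] = s; s *= sizes[i]; }
    return strides;
}

int main() {
    // Standard paged KV cache: [num_blocks, page_block_size, num_heads_k, head_size]
    auto std_s = contiguous_strides<4>({1024, 16, 8, 128});
    // KV-Compress cache: the head dimension is gone: [num_blocks, page_block_size, head_size]
    auto kvc_s = contiguous_strides<3>({1024, 16, 128});

    std::printf("standard: row=%ld head=%ld\n", std_s[1], std_s[2]); // stride(-3), stride(-2)
    std::printf("kvc:      row=%ld (no head stride)\n", kvc_s[1]);   // stride(1)
}
```
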
@@ -502,6 +511,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}
const bool is_KVC = paged_KV && (block_table.dim() > 2);


TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -514,11 +525,11 @@
const int batch_size = cu_seqlens_q.numel() - 1;
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
const int num_heads_k = paged_KV ? (!is_KVC ? k.size(2) : block_table.size(1)) : k.size(1);

if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
const int num_blocks = !paged_KV ? 0 : k.size(0);
const int page_block_size = !paged_KV ? 1 : k.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
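
The `is_KVC` flag keys off the block-table rank. A sketch of the lookup difference it implies (hypothetical sizes; the kernel does this via the `block_table_*` strides set later in this file): a standard table maps `(batch, logical_block)` to a physical block shared by all heads, while KV-Compress keeps one row per `(batch, head)` so each head can retain a different set of blocks.

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int batch_size = 2, num_heads_k = 4, max_blocks = 8;  // hypothetical

    // Standard block table: [batch_size, max_num_blocks_per_seq]
    std::vector<int> table_2d(batch_size * max_blocks, 0);
    // KV-Compress block table: [batch_size, num_heads_k, max_num_blocks_per_seq]
    std::vector<int> table_3d(batch_size * num_heads_k * max_blocks, 0);

    const int bidb = 1, bidkh = 2, logical_block = 3;

    // 2-D: one row per sequence, shared by every KV head.
    int phys_2d = table_2d[bidb * max_blocks + logical_block];
    // 3-D: one row per (sequence, head) pair -- the point of KV-Compress,
    // since compression can evict different blocks from different heads.
    int phys_3d = table_3d[(bidb * num_heads_k + bidkh) * max_blocks + logical_block];

    std::printf("%d %d\n", phys_2d, phys_3d);
}
```
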
@@ -554,13 +565,29 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
if (!is_KVC) {
CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
} else {
CHECK_SHAPE(k, num_blocks, page_block_size, head_size_og);
CHECK_SHAPE(v, num_blocks, page_block_size, head_size_og);
// block_table: [ batch_size, kv_heads, blocks ]
CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
}
}

bool seqlen_by_head = false;
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
if (!is_KVC) {
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
} else {
// Per-head K lengths: cu_seqlens_k carries one entry per (batch, head) boundary.
seqlen_by_head = cu_seqlens_k.size(0) > batch_size + 1;
}
if (seqused_k.has_value()){
auto seqused_k_ = seqused_k.value();
TORCH_CHECK(seqused_k_.dtype() == torch::kInt32, "seqused_k must have dtype int32");
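
To make `seqlen_by_head` concrete, a sketch of the per-head `cu_seqlens_k` layout (sizes hypothetical; the layout is inferred from the `bidb * kseqlen_batch_stride + bidkh` indexing added to `block_info.h` below):

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int batch_size = 2, num_heads_k = 3;  // hypothetical

    // One cumulative offset per (batch, head) boundary: entry b*num_heads_k + h
    // is where head h of sequence b starts; a final entry closes the last head.
    std::vector<int> cu_seqlens_k = {0, 5, 9, 12, 20, 23, 30};

    bool seqlen_by_head = (int)cu_seqlens_k.size() > batch_size + 1;  // true here

    const int bidb = 1, bidkh = 2, kseqlen_batch_stride = num_heads_k;
    int start = cu_seqlens_k[bidb * kseqlen_batch_stride + bidkh];             // 23
    int len   = cu_seqlens_k[bidb * kseqlen_batch_stride + bidkh + 1] - start; // 7

    std::printf("seqlen_by_head=%d start=%d len=%d\n", (int)seqlen_by_head, start, len);
}
```
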
@@ -639,12 +666,22 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
window_size_right,
softcap,
seqlenq_ngroups_swapped,
/*unpadded_lse*/true);
/*unpadded_lse*/true,
/*is_kvc*/is_KVC);
params.total_q = total_q;
params.seqlen_by_head = seqlen_by_head;

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
if (!is_KVC) {
params.block_table_batch_stride = block_table.stride(0);
} else {
params.kseqlen_batch_stride = num_heads_k;
params.block_table_batch_stride = block_table.stride(0);
params.block_table_head_stride = block_table.stride(1);
}
params.k_batch_stride = k_padded.stride(0);
params.v_batch_stride = v_padded.stride(0);
}
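
A sketch of how the strides set above cooperate to address one K row in a KV-Compress cache (pointer arithmetic only, hypothetical sizes; the real gather lives in the CUDA kernels, which this diff does not show):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical sizes, element-unit strides as in set_params_fprop.
    const int page_block_size = 16, head_size = 128;
    const int64_t k_batch_stride = page_block_size * head_size; // stride(0) of [blocks, page, hdim]
    const int64_t k_row_stride   = head_size;                   // stride(1)

    const int num_heads_k = 4, max_blocks = 8;
    const int64_t block_table_batch_stride = num_heads_k * max_blocks; // stride(0) of the 3-D table
    const int64_t block_table_head_stride  = max_blocks;               // stride(1)

    int block_table[2 * 4 * 8] = {0};  // [batch, kv_heads, blocks]
    block_table[1 * block_table_batch_stride + 2 * block_table_head_stride + 3] = 42;

    // Locate token `t` of head `bidkh` in sequence `bidb`:
    const int bidb = 1, bidkh = 2, t = 53;
    int physical_block = block_table[bidb * block_table_batch_stride
                                     + bidkh * block_table_head_stride
                                     + t / page_block_size];
    int64_t k_offset = physical_block * k_batch_stride
                     + (t % page_block_size) * k_row_stride;  // no head stride needed
    std::printf("block=%d offset=%lld\n", physical_block, (long long)k_offset);
}
```
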
@@ -759,6 +796,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
}
const bool is_KVC = paged_KV && (block_table.dim() > 2);

const auto sizes = q.sizes();

@@ -769,12 +807,12 @@
const int num_heads_og = num_heads;
const int head_size_og = sizes[3];

const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
const int max_num_blocks_per_seq = !paged_KV ? 0 : (!is_KVC ? block_table.size(1) : block_table.size(2));
const int num_blocks = !paged_KV ? 0 : kcache.size(0);
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
TORCH_CHECK(!paged_KV || page_block_size % 16 == 0, "Paged KV cache block size must be divisible by 16");
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
const int num_heads_k = kcache.size(2);
const int num_heads_k = !is_KVC ? kcache.size(2) : block_table.size(1);
const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
@@ -802,9 +840,16 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
} else {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
if (!is_KVC) {
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
} else {
CHECK_SHAPE(kcache, num_blocks, page_block_size, head_size_og);
CHECK_SHAPE(vcache, num_blocks, page_block_size, head_size_og);
// block_table: [ batch_size, kv_heads, blocks ]
CHECK_SHAPE(block_table, batch_size, num_heads_k, max_num_blocks_per_seq);
}
}

at::Tensor q_padded, kcache_padded, vcache_padded;
@@ -865,8 +910,10 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
softmax_scale,
window_size_left,
window_size_right,
softcap
);
softcap,
/*seqlenq_ngroups_swapped=*/false,
/*unpadded_lse=*/false,
/*is_kvc=*/is_KVC);

at::Tensor k, v, k_padded, v_padded;
if (k_.has_value()) {
@@ -907,8 +954,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32");
CHECK_DEVICE(seqlens_k);
CHECK_CONTIGUOUS(seqlens_k);
CHECK_SHAPE(seqlens_k, batch_size);
if (!is_KVC) {
CHECK_SHAPE(seqlens_k, batch_size);
} else {
CHECK_SHAPE(seqlens_k, batch_size * num_heads_k);
}
params.cu_seqlens_k = static_cast<int *>(seqlens_k.data_ptr());
params.seqlen_by_head = is_KVC;
}
params.is_seqlens_k_cumulative = !(seqlens_k_.has_value());

@@ -954,7 +1006,13 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he

if (paged_KV) {
params.block_table = block_table.data_ptr<int>();
params.block_table_batch_stride = block_table.stride(0);
if (!is_KVC) {
params.block_table_batch_stride = block_table.stride(0);
} else {
params.kseqlen_batch_stride = num_heads_k;
params.block_table_batch_stride = block_table.stride(0);
params.block_table_head_stride = block_table.stride(1);
}
}
params.page_block_size = page_block_size;
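
Putting the flash_api.cpp changes together, a sketch of the layouts a caller supplies (shapes only, hypothetical sizes; libtorch is used here just to show the expected tensor shapes, not the extension's actual entry point):

```cpp
#include <torch/torch.h>
#include <cstdio>

int main() {
    const int num_blocks = 1024, page = 16, heads_k = 8, hdim = 128;
    const int batch = 2, max_blocks_per_seq = 64;

    // Standard paged KV: pages keep the head dimension; 2-D block table.
    auto k_std  = torch::zeros({num_blocks, page, heads_k, hdim}, torch::kHalf);
    auto bt_std = torch::zeros({batch, max_blocks_per_seq}, torch::kInt32);

    // KV-Compress: per-head pages (no head dim), a 3-D block table,
    // and K lengths tracked per (sequence, head).
    auto k_kvc     = torch::zeros({num_blocks, page, hdim}, torch::kHalf);
    auto bt_kvc    = torch::zeros({batch, heads_k, max_blocks_per_seq}, torch::kInt32);
    auto seqlens_k = torch::zeros({batch * heads_k}, torch::kInt32);

    // flash_api.cpp keys the mode off the table rank.
    std::printf("is_KVC=%d\n", (int)(bt_kvc.dim() > 2));
}
```
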

csrc/flash_attn/src/block_info.h: 19 changes (16 additions, 3 deletions)
@@ -12,13 +12,26 @@ template<bool Varlen=true>
struct BlockInfo {

template<typename Params>
__device__ BlockInfo(const Params &params, const int bidb)
__device__ BlockInfo(const Params &params, const int bidb, const int bidkh)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : (
((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb]) :
(params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh])
))
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (
params.is_seqlens_k_cumulative ?
(
((bidkh < 0) || !params.seqlen_by_head) ? (params.cu_seqlens_k[bidb + 1] - sum_s_k) :
(params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh + 1] - sum_s_k)
) :
(
((bidkh < 0) || !params.seqlen_by_head) ? params.cu_seqlens_k[bidb] :
params.cu_seqlens_k[bidb * params.kseqlen_batch_stride + bidkh]
)
))
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
{
}
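
The nested conditionals above are dense; here is the same resolution unrolled into plain control flow (a reading aid only: `Varlen` and the null checks are elided, and the `Params` struct is a hypothetical stand-in for `Flash_fwd_params`):

```cpp
#include <cstdio>

// Hypothetical stand-in for Flash_fwd_params; only the fields used here.
struct Params {
    const int* cu_seqlens_k;
    bool is_seqlens_k_cumulative;
    bool seqlen_by_head;
    int kseqlen_batch_stride;  // set to num_heads_k in flash_api.cpp
};

// sum_s_k, as resolved above (Varlen assumed, null checks elided).
int sum_s_k(const Params& p, int bidb, int bidkh) {
    if (!p.is_seqlens_k_cumulative) return -1;
    if (bidkh < 0 || !p.seqlen_by_head) return p.cu_seqlens_k[bidb];
    return p.cu_seqlens_k[bidb * p.kseqlen_batch_stride + bidkh];
}

// seqlen_k_cache, as resolved above.
int seqlen_k_cache(const Params& p, int bidb, int bidkh, int s) {
    int idx = (bidkh < 0 || !p.seqlen_by_head)
                  ? bidb
                  : bidb * p.kseqlen_batch_stride + bidkh;
    return p.is_seqlens_k_cumulative ? p.cu_seqlens_k[idx + 1] - s  // cumulative offsets
                                     : p.cu_seqlens_k[idx];         // raw lengths
}

int main() {
    const int cu[] = {0, 5, 9, 12, 20, 23, 30};  // 2 sequences x 3 KV heads
    Params p{cu, /*cumulative=*/true, /*by_head=*/true, /*stride=*/3};
    int s = sum_s_k(p, /*bidb=*/1, /*bidkh=*/2);                      // 23
    std::printf("start=%d len=%d\n", s, seqlen_k_cache(p, 1, 2, s));  // len=7
}
```
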
csrc/flash_attn/src/flash.h: 4 changes (4 additions, 0 deletions)
@@ -103,7 +103,11 @@ struct Flash_fwd_params : public Qkv_params {

// Paged KV cache
int * __restrict__ block_table;
bool is_kvc_cache;               // KV-Compress: per-head pages, 3-D block table
bool seqlen_by_head;             // cu_seqlens_k is indexed per (batch, head) pair
index_t kseqlen_batch_stride;    // equals num_heads_k when seqlen_by_head is set
index_t block_table_batch_stride;
index_t block_table_head_stride; // stride over the head dim of a 3-D block table
int page_block_size;

// The dropout probability (probability of keeping an activation).