Skip to content

Commit

Permalink
add paged-attetionv2: support seq length split across thread block
Browse files Browse the repository at this point in the history
  • Loading branch information
SunflowerAries committed May 10, 2024
1 parent bfad393 commit 3fad8bf
Show file tree
Hide file tree
Showing 8 changed files with 606 additions and 131 deletions.
18 changes: 18 additions & 0 deletions colossalai/inference/flash_decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def _reset(self):
self._tensors_initialized = False
del self._mid_output
del self._mid_output_lse
del self._exp_sums
del self._max_logits

@property
def is_initialized(self):
Expand All @@ -31,6 +33,16 @@ def mid_output_lse(self):
assert self.is_initialized, "Intermediate tensors not initialized yet"
return self._mid_output_lse

@property
def exp_sums(self):
assert self.is_initialized, "Intermediate tensors not initialized yet"
return self._exp_sums

@property
def max_logits(self):
assert self.is_initialized, "Intermediate tensors not initialized yet"
return self._max_logits

def initialize(
self,
max_batch_size: int,
Expand Down Expand Up @@ -60,5 +72,11 @@ def initialize(
self._mid_output_lse = torch.empty(
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
)
self._exp_sums = torch.empty(
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
)
self._max_logits = torch.empty(
size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
)

self._tensors_initialized = True
3 changes: 2 additions & 1 deletion colossalai/inference/modeling/models/nopadding_baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ def forward(
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
self.alibi_slopes,
sm_scale,
)
Expand Down
3 changes: 2 additions & 1 deletion colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ def forward(
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
fd_inter_tensor.exp_sums,
fd_inter_tensor.max_logits,
None,
sm_scale,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def benchmark_flash_decoding_attention(
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
max_logits = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)

if provider == "vllm_paged_decoding_attention":
alibi_slopes = None
Expand Down Expand Up @@ -166,7 +168,8 @@ def benchmark_flash_decoding_attention(
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
exp_sums,
max_logits,
alibi_slopes,
sm_scale,
)
Expand Down
552 changes: 497 additions & 55 deletions extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions extensions/csrc/kernel/cuda/fused_rotary_emb_and_cache_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ __device__ void apply_emb_rotary_compute(
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMul> mul;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kMinus> sub;
BinaryOpFunctor<MT, MT, MT, BinaryOpType::kAdd> add;
CastFunctor<T, MT> t2mt;
CastFunctor<MT, T> mt2t;

T x[VecSize];
T y[VecSize];
Expand All @@ -44,10 +46,10 @@ __device__ void apply_emb_rotary_compute(

#pragma unroll
for (int j = 0; j < VecSize; j++) {
out_x[j] = CastFunctor<MT, T>()(sub(mul(CastFunctor<T, MT>()(x[j]), cos_ptr[j * 32 + shard_offset]),
mul(CastFunctor<T, MT>()(y[j]), sin_ptr[j * 32 + shard_offset])));
out_y[j] = CastFunctor<MT, T>()(add(mul(CastFunctor<T, MT>()(y[j]), cos_ptr[j * 32 + shard_offset]),
mul(CastFunctor<T, MT>()(x[j]), sin_ptr[j * 32 + shard_offset])));
out_x[j] = mt2t(sub(mul(t2mt(x[j]), cos_ptr[j * 32 + shard_offset]),
mul(t2mt(y[j]), sin_ptr[j * 32 + shard_offset])));
out_y[j] = mt2t(add(mul(t2mt(y[j]), cos_ptr[j * 32 + shard_offset]),
mul(t2mt(x[j]), sin_ptr[j * 32 + shard_offset])));
}

copy<T, VecSize>(out_x, src + addr_offset);
Expand Down
3 changes: 2 additions & 1 deletion extensions/pybind/inference/inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ void flash_decoding_attention(
int block_size, int max_context_len,
torch::Tensor&
tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
torch::Tensor& exp_sums, // [num_tokens, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_tokens, num_heads, max_num_partitions]
const c10::optional<torch::Tensor>& alibi_slopes, float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand Down
143 changes: 75 additions & 68 deletions tests/test_infer/test_kernels/cuda/test_flash_decoding_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)

q_len = 1
PARTITION_SIZE = 512


def prepare_data(
Expand Down Expand Up @@ -57,7 +58,7 @@ def numpy_allclose(x, y, rtol, atol):

@pytest.mark.parametrize("BATCH_SIZE", [1, 4, 7, 32])
@pytest.mark.parametrize("BLOCK_SIZE", [8, 16, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32])
@pytest.mark.parametrize("MAX_NUM_BLOCKS_PER_SEQ", [1, 8, 32, 256, 512])
@pytest.mark.parametrize("HEAD_SIZE", [64, 128])
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
Expand All @@ -76,81 +77,86 @@ def test_flash_decoding_attention(
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()

if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
else:
alibi_slopes = None

q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)

k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)
try:
if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
else:
alibi_slopes = None

block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)

k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)
k_cache, v_cache, block_tables = generate_caches_and_block_tables_v3(
k_unpad, v_unpad, kv_seq_lengths, BATCH_SIZE, MAX_NUM_BLOCKS_PER_SEQ, BLOCK_SIZE, dtype, device
)

if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask
block_tables = block_tables.to(device=device)
max_seq_len_across_batch = kv_seq_lengths.max().item()
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)

if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]
k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)

mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
mid_output_lse = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)
if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask

if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3
if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]

high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)
mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
exp_sums = torch.empty(size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device)
max_logits = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num), dtype=torch.float32, device=device
)

else:
rtol = 1e-5
atol = 1e-7
if dtype == torch.float16:
rtol = 1e-3
atol = 1e-3

high_precision_q = q.to(torch.float32)
high_precision_k_torch = k_torch.to(torch.float32)
high_precision_v_torch = v_torch.to(torch.float32)
out_ref = torch_attn_ref(
high_precision_q,
high_precision_k_torch,
high_precision_v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
).to(torch.float16)

out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)
else:
rtol = 1e-5
atol = 1e-7

out_ref = torch_attn_ref(
q,
k_torch,
v_torch,
torch_padding_mask,
BATCH_SIZE,
q_len,
max_seq_len_across_batch,
NUM_ATTN_HEADS,
NUM_KV_HEADS,
HEAD_SIZE,
)

except torch.cuda.OutOfMemoryError:
pytest.skip("Required GPU memory is larger than capacity.")

inference_ops.flash_decoding_attention(
output,
Expand All @@ -162,7 +168,8 @@ def test_flash_decoding_attention(
BLOCK_SIZE,
max_seq_len_across_batch,
mid_output,
mid_output_lse,
exp_sums,
max_logits,
alibi_slopes,
sm_scale,
)
Expand Down

0 comments on commit 3fad8bf

Please sign in to comment.