add apc
Signed-off-by: MengqingCao <[email protected]>
MengqingCao committed Dec 16, 2024
1 parent 2b92b5c commit 93bb53c
Showing 1 changed file with 95 additions and 32 deletions.
127 changes: 95 additions & 32 deletions vllm/attention/backends/ascend.py
@@ -181,6 +181,11 @@ class AscendMetadata(AttentionMetadata, PagedAttentionMetadata):
pse_shift: Optional[torch.Tensor] = None
sparse_mode: int = 0

# Cross-attention memory-mapping data structures: slot mapping
# and block tables
cross_slot_mapping: Optional[torch.Tensor] = None
cross_block_tables: Optional[torch.Tensor] = None

# slot_mapping: Optional[torch.Tensor] = None

@property
@@ -207,6 +212,8 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
@@ -224,6 +231,7 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
@@ -232,7 +240,9 @@ def prefill_metadata(self) -> Optional["AscendMetadata"]:
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps)
multi_modal_placeholder_index_maps,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_prefill_metadata

@property
@@ -264,14 +274,25 @@ def decode_metadata(self) -> Optional["AscendMetadata"]:
seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
# Begin encoder & cross attn fields below...
encoder_seq_lens=self.encoder_seq_lens,
encoder_seq_lens_tensor=self.encoder_seq_lens_tensor,
max_encoder_seq_len=self.max_encoder_seq_len,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps)
multi_modal_placeholder_index_maps,
cross_slot_mapping=self.cross_slot_mapping,
cross_block_tables=self.cross_block_tables)
return self._cached_decode_metadata


@@ -332,7 +353,6 @@ def _add_seq_group(

is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
computed_block_nums = inter_data.computed_block_nums

for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block) in zip(
@@ -357,11 +377,21 @@ def _add_seq_group(
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if inter_data.prefix_cache_hit:
block_table = computed_block_nums
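# APC: prefix-cache handling is decided per batch. A hit on any sequence
# in inter_data_list switches every sequence to its full block table so
# the cached KV blocks can be reused.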
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
block_table = block_tables[seq_id][-curr_sliding_window_block:]
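# curr_sliding_window_block == 0 means no sliding window is in effect,
# so keep the whole block table for this sequence.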
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)

# Compute slot mapping.
@@ -468,33 +498,66 @@ def forward(
seq_len=attn_metadata.max_prefill_seq_len,
batch_size=num_tokens,
)

# shape of q/k/v [B,S*H] --> [B,S,N,D]
query = query.view(-1, attn_metadata.max_prefill_seq_len,
self.num_heads, self.head_size).transpose(1, 2)
key = key.view(-1, attn_metadata.max_prefill_seq_len,
self.num_kv_heads, self.head_size).transpose(1, 2)
value = value.view(-1, attn_metadata.max_prefill_seq_len,
self.num_kv_heads,
if (len(kv_cache) == 0 or attn_metadata.block_tables is None
or attn_metadata.block_tables.numel() == 0):
max_seq_len = attn_metadata.max_prefill_seq_len

# shape of q/k/v [B,S*H] --> [B,S,N,D]
query = query.view(-1, max_seq_len, self.num_heads,
self.head_size).transpose(1, 2)
key = key.view(-1, max_seq_len, self.num_kv_heads,
self.head_size).transpose(1, 2)

# FA for prefill phase
output = torch_npu.npu_prompt_flash_attention(
query,
key,
value,
pse_shift=attn_metadata.pse_shift,
atten_mask=attn_metadata.attn_mask,
num_heads=self.num_heads,
scale_value=1 / math.sqrt(self.head_size),
input_layout="BNSD",
num_key_value_heads=self.num_kv_heads,
pre_tokens=65535,
next_tokens=0,
sparse_mode=attn_metadata.sparse_mode,
)
output = output.transpose(1, 2).reshape(
num_tokens, self.num_heads * self.head_size)
value = value.view(-1, max_seq_len, self.num_kv_heads,
self.head_size).transpose(1, 2)
# FA for prefill phase
output = torch_npu.npu_prompt_flash_attention(
query,
key,
value,
pse_shift=attn_metadata.pse_shift,
atten_mask=attn_metadata.attn_mask,
num_heads=self.num_heads,
scale_value=1 / math.sqrt(self.head_size),
input_layout="BNSD",
num_key_value_heads=self.num_kv_heads,
pre_tokens=65535,
next_tokens=0,
sparse_mode=attn_metadata.sparse_mode,
)
# [B,S,N,D] --> [num_tokens, N*D]
output = output.transpose(1, 2).reshape(
num_tokens, self.num_heads * self.head_size)
else:
# prefix-enabled attention
assert attn_type == AttentionType.DECODER, (
"Only decoder-only models support prefix caching")
assert attn_metadata.seq_lens is not None
assert kv_cache is not None
query = query.view(query.shape[0], -1,
self.num_heads * self.head_size)
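# [num_tokens, N*D] --> [num_tokens, 1, N*D]: each row becomes a
# single-token (S == 1) query in BSH layout for the incremental FA op.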
output = torch.zeros(query.shape,
device="npu",
dtype=query.dtype)
# TODO (Mengqing Cao): torch_npu.npu_incre_flash_attention only
# supports `S == 1`; OPTIMIZE ME once prefix caching is supported
# in torch-npu ops.
for i in range(query.shape[0]):
# incremental FA, called once per query token since the op requires S == 1
output[i] = torch_npu.npu_incre_flash_attention(
query[i].unsqueeze(0),
key_cache,
value_cache,
num_heads=self.num_heads,
num_key_value_heads=self.num_kv_heads,
scale_value=self.scale,
input_layout="BSH",
block_table=attn_metadata.block_tables,
block_size=key_cache.
shape[1], # max val of block_size == 512
actual_seq_lengths=attn_metadata.seq_lens,
)
# [B,S,H] --> [B,H]
output = output.squeeze(1)

elif attn_metadata.decode_metadata:
# FA for decoding phase
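
The "apc" in the commit title is automatic prefix caching. A minimal sketch of how the new prefix-enabled code path would typically be exercised, assuming vLLM's standard enable_prefix_caching engine argument is what drives it on an Ascend device; the model name and prompts below are placeholders:

from vllm import LLM, SamplingParams

# Placeholder model; any decoder-only model supported on the Ascend backend should work.
llm = LLM(model="Qwen/Qwen2-7B-Instruct", enable_prefix_caching=True)

# Two prompts sharing a long common prefix: with prefix caching enabled,
# the KV blocks computed for the shared prefix are reused by the second
# request, which is what the prefix-enabled branch added to forward() serves.
shared_prefix = "You are a helpful assistant. " * 32
prompts = [
    shared_prefix + "Summarize how block tables are used.",
    shared_prefix + "Explain what slot mapping is.",
]
outputs = llm.generate(prompts, SamplingParams(max_tokens=64))
for out in outputs:
    print(out.outputs[0].text)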
