Skip to content

Commit

Permalink
[V1] Optimize the CPU overheads in FlashAttention custom op (vllm-pro…
Browse files Browse the repository at this point in the history
…ject#10733)

Signed-off-by: Woosuk Kwon <[email protected]>
  • Loading branch information
WoosukKwon authored and BKitor committed Dec 30, 2024
1 parent f5aa6cf commit 252ba9e
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the CPU
# overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

output = torch.empty_like(query)
torch.ops.vllm.unified_v1_flash_attention(
output,
Expand All @@ -153,7 +160,7 @@ def forward(
self.alibi_slopes,
self.logits_soft_cap,
)
return output
return output.view(-1, self.num_heads * self.head_size)


def unified_v1_flash_attention(
Expand Down Expand Up @@ -184,11 +191,6 @@ def unified_v1_flash_attention(
attn_metadata: FlashAttentionMetadata = current_metadata
num_actual_tokens = attn_metadata.num_actual_tokens

# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)

# Reshape the input keys and values and store them in the cache.
key_cache = kv_cache[0]
value_cache = kv_cache[1]
Expand Down Expand Up @@ -218,8 +220,7 @@ def unified_v1_flash_attention(
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
attn_output = attn_output.view(num_actual_tokens, -1)
# TODO(woosuk): Optimize this.
# TODO(woosuk): Remove this unnecessary copy.
output[:num_actual_tokens].copy_(attn_output)


Expand Down

0 comments on commit 252ba9e

Please sign in to comment.