From d42650f83b90acf3c75c4bcf751c4fa22137eb94 Mon Sep 17 00:00:00 2001
From: Tanner Voas
Date: Thu, 24 Oct 2024 09:47:23 +0000
Subject: [PATCH] Resolved alibi bias issue caused by porting the flat PA PR

Signed-off-by: Tanner Voas
---
 vllm/attention/backends/hpu_attn.py  | 29 ++++++++++++++-----
 vllm/attention/ops/hpu_paged_attn.py |  1 +
 vllm/worker/hpu_model_runner.py      | 43 ++++++++++++++++++++++++----
 3 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py
index 8f16081e2e2b5..7510604a0d657 100644
--- a/vllm/attention/backends/hpu_attn.py
+++ b/vllm/attention/backends/hpu_attn.py
@@ -118,6 +118,7 @@ def __init__(
             alibi_slopes_tensor = torch.tensor(alibi_slopes,
                                                dtype=torch.bfloat16)
             self.alibi_slopes = alibi_slopes_tensor
+        self.max_seq_len = max_seq_len
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
@@ -198,7 +199,7 @@ def forward(
                 if self.alibi_slopes is not None:
                     position_bias = _make_alibi_bias(
                         self.alibi_slopes, self.num_kv_heads,
-                        attn_bias.dtype, attn_bias.shape[-1])
+                        self.alibi_slopes.dtype, attn_bias.shape[-1])
                     attn_bias = attn_bias.tile(
                         (1, self.num_kv_heads, 1, 1))
                     attn_bias.add_(position_bias)
@@ -235,6 +236,17 @@
             output = out.reshape(batch_size, seq_len, hidden_size)
         else:
             # Decoding run.
+            self.position_bias = None
+            attn_bias = attn_metadata.attn_bias
+            if self.alibi_slopes is not None:
+                self.position_bias = _make_alibi_bias(
+                    self.alibi_slopes,
+                    self.num_kv_heads,
+                    self.alibi_slopes.dtype,
+                    self.max_seq_len if self.max_seq_len is not None else
+                    attn_bias.shape[-1] * attn_metadata.block_tables.shape[-1]
+                )
+
             output = HPUPagedAttention.forward_decode(
                 query=query,
                 key_cache=key_cache,
@@ -245,10 +257,12 @@
                 block_scales=attn_metadata.block_scales,
                 block_groups=attn_metadata.block_groups,
                 scale=self.scale,
+                alibi_slopes=self.position_bias,
                 matmul_qk_op=self.matmul_qk,
                 matmul_av_op=self.matmul_av,
                 keys_fetch_func=self.k_cache.fetch_from_cache,
-                values_fetch_func=self.v_cache.fetch_from_cache)
+                values_fetch_func=self.v_cache.fetch_from_cache,
+            )
 
         # Reshape the output tensor.
         return output.view(batch_size, seq_len, hidden_size)
@@ -271,15 +285,16 @@ def _make_alibi_bias(
 
     padded_len = (seq_len + 7) // 8 * 8
     num_heads = alibi_slopes.shape[0]
-    bias = torch.empty(
+    per_head_bias = torch.empty(
         1,  # batch size
         num_heads,
         seq_len,
         padded_len,
         device=alibi_slopes.device,
         dtype=dtype,
-    )[:, :, :, :seq_len].copy_(bias)
-    bias.mul_(alibi_slopes[:, None, None])
+    )[:, :, :, :seq_len]
+    per_head_bias[:, :] = bias
+    per_head_bias.mul_(alibi_slopes[:, None, None])
     if num_heads != num_kv_heads:
-        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
-    return bias
+        per_head_bias = per_head_bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
+    return per_head_bias
diff --git a/vllm/attention/ops/hpu_paged_attn.py b/vllm/attention/ops/hpu_paged_attn.py
index e55a4de11fd6c..1678bf45adb71 100644
--- a/vllm/attention/ops/hpu_paged_attn.py
+++ b/vllm/attention/ops/hpu_paged_attn.py
@@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
     block_offsets: Optional[torch.Tensor]
     block_scales: Optional[torch.Tensor]
     block_groups: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
 
 
 class HPUPagedAttention:
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index fec5f3d01cff8..35099008bd25e 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -979,6 +979,14 @@ def _prepare_prompt(
 
         block_indices, block_offsets = precompute_indices_and_offsets(
             self.block_size, slot_mapping, True)
+
+        max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
+        block_tables = make_tensor_with_pad(prefix_block_tables,
+                                            max_len=max_prompt_block_table_len,
+                                            pad=0,
+                                            dtype=torch.int,
+                                            device=self.device)
+
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=True,
             block_list=prefix_block_list_tensor,
@@ -995,6 +1003,7 @@
             num_prefill_tokens=sum_query_len,
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
+            block_tables=block_tables,
         )
 
         multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
@@ -1187,6 +1196,20 @@ def _prepare_decode(
                                     dtype=self.model_config.dtype,
                                     device=self.device)
 
+        seq_lens_tensor = torch.tensor(seq_lens,
+                                       dtype=torch.int,
+                                       device=self.device)
+
+        max_block_table_len = max(
+            len(block_table) for block_table in block_tables)
+        block_tables = make_tensor_with_pad(
+            block_tables,
+            max_len=max_block_table_len,
+            pad=0,
+            dtype=torch.int,
+            device=self.device,
+        )
+
         attn_metadata = self.attn_backend.make_metadata(
             is_prompt=False,
             block_list=block_list,
@@ -1197,12 +1220,13 @@
             block_scales=block_scales,
             block_groups=block_groups,
             attn_bias=None,
-            seq_lens_tensor=None,
+            seq_lens_tensor=seq_lens_tensor,
             context_lens_tensor=None,
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
+            block_tables=block_tables,
         )
         return PrepareDecodeMetadata(input_tokens=input_tokens,
                                      input_positions=input_positions,
@@ -1406,10 +1430,19 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
         # input_hash(123) != input_hash(321)
         # input_hash("abc") != input_hash("cba")
         attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
-            'attn_bias', 'seq_lens_tensor', 'context_lens_tensor',
-            'block_list', 'block_mapping', 'block_usage', 'slot_mapping',
-            'is_prompt', 'block_indices', 'block_offsets', 'block_scales',
-            'block_groups'
+            'attn_bias',
+            'seq_lens_tensor',
+            'context_lens_tensor',
+            'block_list',
+            'block_mapping',
+            'block_usage',
+            'slot_mapping',
+            'is_prompt',
+            'block_indices',
+            'block_offsets',
+            'block_scales',
+            'block_groups',
+            'block_tables',
         ])
         return attention_metadata
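
For reference, a minimal standalone sketch of the per-head ALiBi bias construction that the patched _make_alibi_bias performs; the helper name make_alibi_bias and the demo sizes in the __main__ block below are illustrative and not part of the patch.

import torch


def make_alibi_bias(alibi_slopes: torch.Tensor, num_kv_heads: int,
                    dtype: torch.dtype, seq_len: int) -> torch.Tensor:
    # Relative positions: bias[i, j] = j - i.
    bias = torch.arange(seq_len, dtype=dtype)
    bias = bias[None, :] - bias[:, None]

    # Allocate the last dim padded to a multiple of 8, slice back to seq_len,
    # broadcast the relative positions into every head, and scale each head
    # by its ALiBi slope.
    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    per_head_bias = torch.empty(
        1,  # batch size
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len]
    per_head_bias[:, :] = bias
    per_head_bias.mul_(alibi_slopes[:, None, None])

    # For GQA/MQA, group the query heads under their shared KV head.
    if num_heads != num_kv_heads:
        per_head_bias = per_head_bias.unflatten(
            1, (num_kv_heads, num_heads // num_kv_heads))
    return per_head_bias


if __name__ == "__main__":
    slopes = torch.tensor([0.5, 0.25, 0.125, 0.0625])  # 4 query heads
    bias = make_alibi_bias(slopes, num_kv_heads=2,
                           dtype=torch.float32, seq_len=16)
    print(bias.shape)  # torch.Size([1, 2, 2, 16, 16])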