Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolved alibi bias issue due to porting flat PA pr #437

Open
wants to merge 1 commit into
base: habana_main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
self.max_seq_len = max_seq_len
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

Expand Down Expand Up @@ -198,7 +199,7 @@ def forward(
if self.alibi_slopes is not None:
position_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads,
attn_bias.dtype, attn_bias.shape[-1])
self.alibi_slopes.dtype, attn_bias.shape[-1])
attn_bias = attn_bias.tile(
(1, self.num_kv_heads, 1, 1))
attn_bias.add_(position_bias)
Expand Down Expand Up @@ -235,6 +236,17 @@ def forward(
output = out.reshape(batch_size, seq_len, hidden_size)
else:
# Decoding run.
self.position_bias = None
attn_bias = attn_metadata.attn_bias
if self.alibi_slopes is not None:
self.position_bias = _make_alibi_bias(
self.alibi_slopes,
self.num_kv_heads,
self.alibi_slopes.dtype,
self.max_seq_len if self.max_seq_len is not None else
attn_bias.shape[-1] * attn_metadata.block_tables.shape[-1]
)

output = HPUPagedAttention.forward_decode(
query=query,
key_cache=key_cache,
Expand All @@ -245,10 +257,12 @@ def forward(
block_scales=attn_metadata.block_scales,
block_groups=attn_metadata.block_groups,
scale=self.scale,
alibi_slopes=self.position_bias,
matmul_qk_op=self.matmul_qk,
matmul_av_op=self.matmul_av,
keys_fetch_func=self.k_cache.fetch_from_cache,
values_fetch_func=self.v_cache.fetch_from_cache)
values_fetch_func=self.v_cache.fetch_from_cache,
)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

Expand All @@ -271,15 +285,16 @@ def _make_alibi_bias(

padded_len = (seq_len + 7) // 8 * 8
num_heads = alibi_slopes.shape[0]
bias = torch.empty(
per_head_bias = torch.empty(
1, # batch size
num_heads,
seq_len,
padded_len,
device=alibi_slopes.device,
dtype=dtype,
)[:, :, :, :seq_len].copy_(bias)
bias.mul_(alibi_slopes[:, None, None])
)[:, :, :, :seq_len]
per_head_bias[:, :] = bias
per_head_bias.mul_(alibi_slopes[:, None, None])
if num_heads != num_kv_heads:
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return bias
per_head_bias = per_head_bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
return per_head_bias
1 change: 1 addition & 0 deletions vllm/attention/ops/hpu_paged_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class HPUPagedAttentionMetadata:
block_offsets: Optional[torch.Tensor]
block_scales: Optional[torch.Tensor]
block_groups: Optional[torch.Tensor]
block_tables: Optional[torch.Tensor]


class HPUPagedAttention:
Expand Down
43 changes: 38 additions & 5 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,14 @@ def _prepare_prompt(

block_indices, block_offsets = precompute_indices_and_offsets(
self.block_size, slot_mapping, True)

max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = make_tensor_with_pad(prefix_block_tables,
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device)

attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
block_list=prefix_block_list_tensor,
Expand All @@ -995,6 +1003,7 @@ def _prepare_prompt(
num_prefill_tokens=sum_query_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
block_tables=block_tables,
)
multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)

Expand Down Expand Up @@ -1187,6 +1196,20 @@ def _prepare_decode(
dtype=self.model_config.dtype,
device=self.device)

seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)

max_block_table_len = max(
len(block_table) for block_table in block_tables)
block_tables = make_tensor_with_pad(
block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)

attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
block_list=block_list,
Expand All @@ -1197,12 +1220,13 @@ def _prepare_decode(
block_scales=block_scales,
block_groups=block_groups,
attn_bias=None,
seq_lens_tensor=None,
seq_lens_tensor=seq_lens_tensor,
context_lens_tensor=None,
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=num_decode_tokens,
slot_mapping=slot_mapping,
block_tables=block_tables,
)
return PrepareDecodeMetadata(input_tokens=input_tokens,
input_positions=input_positions,
Expand Down Expand Up @@ -1406,10 +1430,19 @@ def trim_attn_metadata(self, metadata: AttentionMetadata) -> object:
# input_hash(123) != input_hash(321)
# input_hash("abc") != input_hash("cba")
attention_metadata = subtuple(metadata, 'TrimmedAttentionMetadata', [
'attn_bias', 'seq_lens_tensor', 'context_lens_tensor',
'block_list', 'block_mapping', 'block_usage', 'slot_mapping',
'is_prompt', 'block_indices', 'block_offsets', 'block_scales',
'block_groups'
'attn_bias',
'seq_lens_tensor',
'context_lens_tensor',
'block_list',
'block_mapping',
'block_usage',
'slot_mapping',
'is_prompt',
'block_indices',
'block_offsets',
'block_scales',
'block_groups',
'block_tables',
])
return attention_metadata

Expand Down