TGI: fix bloom concatenate and filter issue (#401)
Fix the BLOOM concatenate and filter issue: BLOOM's past_key_values tensors have 3 dims ([batch_size * num_heads, seq_length, head_dim]) instead of the usual 4, so they must be viewed as 4-dim before per-sequence indexing and fused back afterwards.

Signed-off-by: Wang, Yi A <[email protected]>
sywangyi authored Sep 14, 2023
1 parent 45752d6 commit af3da93
Showing 1 changed file with 19 additions and 5 deletions.
@@ -233,23 +233,27 @@ def filter(self, request_ids: List[int], is_optimized_for_gaudi: bool = False) -
         past_kv_length = max_input_length - 1
         for layer in self.past_key_values:
             past_keys, past_values = layer
+            past_keys_dims = len(past_keys.shape)
+            if past_keys_dims == 3:
+                # Force past to be of dim [self_size, num_heads, ...] for easy indexing
+                past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
+                past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
             if is_optimized_for_gaudi:
                 layer[0] = past_keys[keep_indices]
                 del past_keys
                 layer[1] = past_values[keep_indices]
                 del past_values
             else:
-                if len(past_keys.shape) == 3:
-                    # Force past to be of dim [self_size, num_heads, ...] for easy indexing
-                    past_keys = past_keys.view(len(self), -1, *past_keys.shape[-2:])
-                    past_values = past_values.view(len(self), -1, *past_values.shape[-2:])
                 if self.keys_head_dim_last:
                     layer[0] = past_keys[keep_indices, :, -past_kv_length:, :]
                 else:
                     layer[0] = past_keys[keep_indices, :, :, -past_kv_length:]
                 del past_keys
                 layer[1] = past_values[keep_indices, :, -past_kv_length:, :]
                 del past_values
+            if past_keys_dims == 3:
+                layer[0] = layer[0].view(layer[0].shape[0] * layer[0].shape[1], *layer[0].shape[-2:])
+                layer[1] = layer[1].view(layer[1].shape[0] * layer[1].shape[1], *layer[1].shape[-2:])
 
         top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices]
         max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens
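
Why the round-trip in filter is needed: BLOOM stores keys and values as [batch_size * num_heads, seq_length, head_dim], so indexing the 3-dim cache with keep_indices would select along the fused batch*heads axis rather than the batch axis. A minimal standalone sketch of the view round-trip, with illustrative shapes and names (not the TGI code itself):

    import torch

    batch_size, num_heads, seq_length, head_dim = 4, 8, 16, 64
    keep_indices = [0, 2]  # sequences that survive filtering

    # BLOOM layout: batch and heads fused into one leading axis
    past_keys = torch.randn(batch_size * num_heads, seq_length, head_dim)

    # Unfuse so that keep_indices selects whole sequences
    unfused = past_keys.view(batch_size, -1, *past_keys.shape[-2:])  # [4, 8, 16, 64]
    kept = unfused[keep_indices]                                     # [2, 8, 16, 64]

    # Fuse back to BLOOM's 3-dim layout, as the patch does at the end of the loop
    kept = kept.view(kept.shape[0] * kept.shape[1], *kept.shape[-2:])
    assert kept.shape == (len(keep_indices) * num_heads, seq_length, head_dim)

The advanced indexing returns a contiguous tensor, so the final .view back to the fused layout is safe.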
@@ -378,12 +382,13 @@ def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: boo
             # BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
             # And ensure that we can update tensors in-place
             kv_tuple = False
+            past_key_values_dims = len(batch.past_key_values[0][0].shape)
             if type(batch.past_key_values[0]) == tuple:
                 batch.past_key_values = [
                     [t.view(len(batch), -1, *t.shape[-2:]) for t in layer] for layer in batch.past_key_values
                 ]
                 kv_tuple = True
-            elif len(batch.past_key_values[0][0].shape) == 3:
+            elif past_key_values_dims == 3:
                 for layer in batch.past_key_values:
                     for k, t in enumerate(layer):
                         layer[k] = t.view(len(batch), -1, *t.shape[-2:])
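
In concatenate, the dim count has to be captured before the in-place normalization: once every layer has been viewed as 4-dim, the original 3-dim BLOOM layout is no longer detectable, and the reshape after padding (next hunk) would not know to fuse the axes back. A toy illustration of that ordering, using stand-in tensors rather than a real batch:

    import torch

    len_batch, num_heads, seq_length, head_dim = 2, 8, 5, 64
    layer = [torch.randn(len_batch * num_heads, seq_length, head_dim) for _ in range(2)]

    # Must be read BEFORE the views below overwrite the layout
    past_key_values_dims = len(layer[0].shape)  # 3 for BLOOM

    for k, t in enumerate(layer):
        layer[k] = t.view(len_batch, -1, *t.shape[-2:])  # normalized to 4-dim

    assert past_key_values_dims == 3 and layer[0].dim() == 4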
@@ -469,6 +474,15 @@ def concatenate(cls, batches: List["CausalLMBatch"], is_optimized_for_gaudi: boo
                 # Update values
                 start_index = end_index
+
+            if past_key_values_dims == 3:
+                padded_past_keys = padded_past_keys.view(
+                    padded_past_keys.shape[0] * padded_past_keys.shape[1], *padded_past_keys.shape[-2:]
+                )
+                padded_past_values = padded_past_values.view(
+                    padded_past_values.shape[0] * padded_past_values.shape[1], *padded_past_values.shape[-2:]
+                )
+
             if kv_tuple:
                 past_key_values.append((padded_past_keys, padded_past_values))
             else:
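
After padding and copying, the merged tensors are 4-dim; for BLOOM they must be fused back to 3-dim before being appended, mirroring the end of filter. A short sketch with assumed shapes (total_batch, padded_seq and so on are illustrative names, not from the patch):

    import torch

    total_batch, num_heads, padded_seq, head_dim = 3, 8, 20, 64
    padded_past_keys = torch.zeros(total_batch, num_heads, padded_seq, head_dim)

    past_key_values_dims = 3  # captured earlier, before normalization
    if past_key_values_dims == 3:
        padded_past_keys = padded_past_keys.view(
            padded_past_keys.shape[0] * padded_past_keys.shape[1], *padded_past_keys.shape[-2:]
        )
    assert padded_past_keys.shape == (total_batch * num_heads, padded_seq, head_dim)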
