remove kv cache wa
ZhaiFeiyue committed Jul 20, 2023
1 parent 9ad7180 commit 1e7d3a7
Showing 4 changed files with 8 additions and 16 deletions.
5 changes: 2 additions & 3 deletions optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -211,9 +211,8 @@ def forward(
         if layer_past is not None:
             past_key, past_value = layer_past
             if token_idx is not None:
-                # HPU bug WA
-                past_key.index_add_(2, token_idx - 1, key - torch.index_select(past_key, 2, token_idx - 1))
-                past_value.index_add_(2, token_idx - 1, value - torch.index_select(past_value, 2, token_idx - 1))
+                past_key.index_copy_(2, token_idx - 1, key)
+                past_value.index_copy_(2, token_idx - 1, value)
                 key = past_key
                 value = past_value
             else:
5 changes: 2 additions & 3 deletions optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -60,9 +60,8 @@ def gaudi_gpt_neox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if token_idx is not None:
-            # HPU bug WA
-            past_key.index_add_(2, token_idx - 1, key - torch.index_select(past_key, 2, token_idx - 1))
-            past_value.index_add_(2, token_idx - 1, value - torch.index_select(past_value, 2, token_idx - 1))
+            past_key.index_copy_(2, token_idx - 1, key)
+            past_value.index_copy_(2, token_idx - 1, value)
             key = past_key
             value = past_value
         else:
5 changes: 2 additions & 3 deletions optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -65,9 +65,8 @@ def gaudi_gptj_attention_forward(
         past_value = layer_past[1]
 
         if token_idx is not None:
-            # HPU bug WA
-            past_key.index_add_(2, token_idx - 1, key - torch.index_select(past_key, 2, token_idx - 1))
-            past_value.index_add_(2, token_idx - 1, value - torch.index_select(past_value, 2, token_idx - 1))
+            past_key.index_copy_(2, token_idx - 1, key)
+            past_value.index_copy_(2, token_idx - 1, value)
             key = past_key
             value = past_value
         else:
9 changes: 2 additions & 7 deletions optimum/habana/transformers/models/opt/modeling_opt.py
@@ -72,13 +72,8 @@ def gaudi_opt_attention_forward(
             key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
             value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
             if token_idx is not None:
-                # HPU bug WA
-                past_key_value[0].index_add_(
-                    2, token_idx - 1, key_states - torch.index_select(past_key_value[0], 2, token_idx - 1)
-                )
-                past_key_value[1].index_add_(
-                    2, token_idx - 1, value_states - torch.index_select(past_key_value[1], 2, token_idx - 1)
-                )
+                past_key_value[0].index_copy_(2, token_idx - 1, key_states)
+                past_key_value[1].index_copy_(2, token_idx - 1, value_states)
                 key_states = past_key_value[0]
                 value_states = past_key_value[1]
             else:
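The removed lines were a workaround ("WA") for an HPU bug: instead of writing the current token's key/value into the KV cache with index_copy_, they reached the same state by index_add_-ing the difference between the new value and the cached slot, since cache[idx] + (new - cache[idx]) == new. This commit restores the direct index_copy_. A minimal sketch, not part of the commit, verifying the equivalence on CPU (shapes and values are illustrative only):

import torch

# Illustrative KV-cache shape: [batch, heads, max_seq_len, head_dim]
past_key = torch.randn(1, 4, 8, 16)
key = torch.randn(1, 4, 1, 16)   # key for the token being decoded
token_idx = torch.tensor([3])    # 1-based position of that token

# Former workaround: add the delta so the slot ends up equal to `key`.
wa = past_key.clone()
wa.index_add_(2, token_idx - 1, key - torch.index_select(wa, 2, token_idx - 1))

# Direct update used after this commit.
direct = past_key.clone()
direct.index_copy_(2, token_idx - 1, key)

# Equal up to floating-point rounding: the workaround computes
# past + (key - past), which need not be bitwise identical to key.
assert torch.allclose(wa, direct)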
