diff --git a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
index 38ff9389d9..4fda9abec5 100644
--- a/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
+++ b/optimum/habana/transformers/models/gpt2/modeling_gpt2.py
@@ -211,9 +211,8 @@ def forward(
         if layer_past is not None:
             past_key, past_value = layer_past
             if token_idx is not None:
-                # HPU bug WA
-                past_key.index_add_(2, token_idx - 1, key - torch.index_select(past_key, 2, token_idx - 1))
-                past_value.index_add_(2, token_idx - 1, value - torch.index_select(past_value, 2, token_idx - 1))
+                past_key.index_copy_(2, token_idx - 1, key)
+                past_value.index_copy_(2, token_idx - 1, value)
                 key = past_key
                 value = past_value
             else:
diff --git a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
index e41ce7554c..3d9683c678 100644
--- a/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/optimum/habana/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -60,9 +60,8 @@ def gaudi_gpt_neox_attention_forward(
         past_key = layer_past[0]
         past_value = layer_past[1]
         if token_idx is not None:
-            # HPU bug WA
-            past_key.index_add_(2, token_idx - 1, key - torch.index_select(past_key, 2, token_idx - 1))
-            past_value.index_add_(2, token_idx - 1, value - torch.index_select(past_value, 2, token_idx - 1))
+            past_key.index_copy_(2, token_idx - 1, key)
+            past_value.index_copy_(2, token_idx - 1, value)
             key = past_key
             value = past_value
         else:
diff --git a/optimum/habana/transformers/models/gptj/modeling_gptj.py b/optimum/habana/transformers/models/gptj/modeling_gptj.py
index e5a4ccca01..5b14e45b52 100644
--- a/optimum/habana/transformers/models/gptj/modeling_gptj.py
+++ b/optimum/habana/transformers/models/gptj/modeling_gptj.py
@@ -65,9 +65,8 @@ def gaudi_gptj_attention_forward(
         past_value = layer_past[1]
 
         if token_idx is not None:
-            # HPU bug WA
-            past_key.index_add_(2, token_idx - 1, key - torch.index_select(past_key, 2, token_idx - 1))
-            past_value.index_add_(2, token_idx - 1, value - torch.index_select(past_value, 2, token_idx - 1))
+            past_key.index_copy_(2, token_idx - 1, key)
+            past_value.index_copy_(2, token_idx - 1, value)
             key = past_key
             value = past_value
         else:
diff --git a/optimum/habana/transformers/models/opt/modeling_opt.py b/optimum/habana/transformers/models/opt/modeling_opt.py
index c8cfa9ffa0..fee4dcf7e8 100644
--- a/optimum/habana/transformers/models/opt/modeling_opt.py
+++ b/optimum/habana/transformers/models/opt/modeling_opt.py
@@ -72,13 +72,8 @@ def gaudi_opt_attention_forward(
         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
         if token_idx is not None:
-            # HPU bug WA
-            past_key_value[0].index_add_(
-                2, token_idx - 1, key_states - torch.index_select(past_key_value[0], 2, token_idx - 1)
-            )
-            past_key_value[1].index_add_(
-                2, token_idx - 1, value_states - torch.index_select(past_key_value[1], 2, token_idx - 1)
-            )
+            past_key_value[0].index_copy_(2, token_idx - 1, key_states)
+            past_key_value[1].index_copy_(2, token_idx - 1, value_states)
             key_states = past_key_value[0]
             value_states = past_key_value[1]
         else:
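
Note on the change: both update patterns write the new key/value into the cache slot selected by `token_idx - 1`. The removed `index_add_` form (marked "# HPU bug WA") got there indirectly by adding `key - current` to `current`, while `index_copy_` writes the value directly and skips the extra `index_select` and subtraction on every decoding step. A minimal standalone sketch of the equivalence, with illustrative shapes and names that are not part of the patch:

import torch

# Toy (batch, heads, seq_len, head_dim) KV cache; `idx` stands in for
# `token_idx - 1` from the patched code above.
past_key = torch.randn(1, 2, 4, 8, dtype=torch.float64)
key = torch.randn(1, 2, 1, 8, dtype=torch.float64)
idx = torch.tensor([2])

# Old workaround: current + (key - current) == key at the selected slot,
# exact only up to floating-point rounding.
via_add = past_key.clone()
via_add.index_add_(2, idx, key - torch.index_select(via_add, 2, idx))

# New form: copy `key` directly into the selected slot.
via_copy = past_key.clone()
via_copy.index_copy_(2, idx, key)

assert torch.allclose(via_add, via_copy)

Presumably the indirect form existed because a direct indexed copy previously misbehaved on HPU, as the removed comments suggest; with that resolved, the direct `index_copy_` is both simpler and exact.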