diff --git a/optimum/habana/transformers/models/llama/modeling_llama.py b/optimum/habana/transformers/models/llama/modeling_llama.py
index 4630678a97..88796f0922 100755
--- a/optimum/habana/transformers/models/llama/modeling_llama.py
+++ b/optimum/habana/transformers/models/llama/modeling_llama.py
@@ -515,7 +515,7 @@ def pre_attn_forward(
                 use_recompute = True if os.getenv("QUANT_CONFIG", "") else False
                 with ht.sdp_kernel(enable_recompute=use_recompute):
                     attn_output = self.fused_scaled_dot_product_attention(
-                        query_states, key_states, value_states, attention_mask, 0.0, False, None, softmax_mode
+                        query_states, key_states, value_states, attention_mask, 0.0, False, None, 'None'
                     )
             else:
                 # first token