From 8fcddf677b05155c7f7c4ff529a1661316b83b5a Mon Sep 17 00:00:00 2001
From: kkoryun
Date: Tue, 24 Sep 2024 23:13:51 +0100
Subject: [PATCH] added flash attn args

---
 optimum/habana/transformers/models/gemma/modeling_gemma.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/optimum/habana/transformers/models/gemma/modeling_gemma.py b/optimum/habana/transformers/models/gemma/modeling_gemma.py
index 6f40c65ea..a0171775e 100644
--- a/optimum/habana/transformers/models/gemma/modeling_gemma.py
+++ b/optimum/habana/transformers/models/gemma/modeling_gemma.py
@@ -321,6 +321,9 @@ def forward(
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         token_idx: Optional[torch.Tensor] = None,
+        use_flash_attention: Optional[bool] = False,
+        flash_attention_recompute: Optional[bool] = False,
+        flash_attention_causal_mask: Optional[bool] = False,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         """
         Inherits from GemmaForCausalLM: https://github.com/huggingface/transformers/blob/v4.38.1/src/transformers/models/gemma/modeling_gemma.py
@@ -346,6 +349,9 @@ def forward(
             return_dict=return_dict,
             cache_position=cache_position,
             token_idx=token_idx,
+            use_flash_attention=use_flash_attention,
+            flash_attention_recompute=flash_attention_recompute,
+            flash_attention_causal_mask=flash_attention_causal_mask,
         )

         hidden_states = outputs[0]
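
For reference, below is a minimal sketch of how the three new keyword arguments might be exercised once the patch is applied. Only the flag names come from the diff above; the class name GaudiGemmaForCausalLM, the import path, the checkpoint, the prompt, and the HPU setup are illustrative assumptions, not part of this patch.

# Minimal usage sketch (not part of the patch). Assumes optimum-habana with the
# patched modeling_gemma.py, an available HPU device, and that the patched class
# is named GaudiGemmaForCausalLM (assumption); checkpoint and prompt are placeholders.
import torch
from transformers import AutoTokenizer
from optimum.habana.transformers.models.gemma.modeling_gemma import GaudiGemmaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = GaudiGemmaForCausalLM.from_pretrained("google/gemma-2b").to("hpu").eval()

inputs = tokenizer("Hello, Gemma!", return_tensors="pt").to("hpu")

# The three new flags are simply forwarded to the underlying GemmaModel call,
# which is expected to select the Habana fused SDPA path when they are enabled.
with torch.no_grad():
    outputs = model(
        **inputs,
        use_flash_attention=True,
        flash_attention_recompute=True,
        flash_attention_causal_mask=True,
    )

print(outputs.logits.shape)

The defaults (False) keep the existing eager-attention behavior, so callers that do not pass the new arguments are unaffected.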