From 1affc60889ddb77e07ba330a620bd5fb98e888aa Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Mon, 16 Sep 2024 10:28:57 -0700 Subject: [PATCH 1/2] use eager attention for wav2vec2 disable sdpa for wav2vec2 for now. --- optimum/habana/transformers/models/modeling_all_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index 808e0f012a..01d70ad2ab 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -115,7 +115,7 @@ def gaudi_conv1d_forward(self, x): @classmethod def gaudi_check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: # This model doesn't support SDPA in Gaudi yet, fallback to original code. - MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral"] + MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral", "wav2vec2'] if config.model_type in MODELS_ATTN_IMPLEMENTATION_EAGER: config._attn_implementation = "eager" From 83e80a77415304f8f69ea97f0e5a2ea326fb98e1 Mon Sep 17 00:00:00 2001 From: Shiv Kaul Date: Mon, 16 Sep 2024 10:32:06 -0700 Subject: [PATCH 2/2] Update modeling_all_models.py --- optimum/habana/transformers/models/modeling_all_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/models/modeling_all_models.py b/optimum/habana/transformers/models/modeling_all_models.py index 01d70ad2ab..c9eb95524e 100644 --- a/optimum/habana/transformers/models/modeling_all_models.py +++ b/optimum/habana/transformers/models/modeling_all_models.py @@ -115,7 +115,7 @@ def gaudi_conv1d_forward(self, x): @classmethod def gaudi_check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: # This model doesn't support SDPA in Gaudi yet, fallback to original code. - MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral", "wav2vec2'] + MODELS_ATTN_IMPLEMENTATION_EAGER = ["bart", "gpt_bigcode", "mistral", "mixtral", "wav2vec2"] if config.model_type in MODELS_ATTN_IMPLEMENTATION_EAGER: config._attn_implementation = "eager"