From 3ff4f8865d1ed0f35eb4b51b4d81b1e55fe36090 Mon Sep 17 00:00:00 2001
From: "Wu, Gangsheng"
Date: Fri, 12 Jul 2024 09:46:12 +0000
Subject: [PATCH] Fix unassigned attn_softmax_bf16 and simplify HPU
 generation_config checks

---
 llm_on_ray/finetune/finetune.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index 4d742700..6c9aa71f 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -309,11 +309,9 @@ def load_model(config: Dict):
     model.generation_config.pad_token_id = 0
     model.generation_config.bos_token_id = 1
     model.generation_config.eos_token_id = 2
-    attn_softmax_bf16 = config["General"]["attn_softmax_bf16"]
-    if attn_softmax_bf16 and device == "hpu":
-        model.generation_config.attn_softmax_bf16
-    use_flash_attention = config["General"]["use_flash_attention"]
-    if use_flash_attention and device == "hpu":
+    if device == "hpu" and config["General"]["attn_softmax_bf16"]:
+        model.generation_config.attn_softmax_bf16 = True
+    if device == "hpu" and config["General"]["use_flash_attention"]:
         model.generation_config.use_flash_attention = True
         model.generation_config.flash_attention_recompute = False
         model.generation_config.flash_attention_causal_mask = False
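
The substantive fix here is the removed line `model.generation_config.attn_softmax_bf16`: a bare attribute access is a no-op, so the bf16-softmax flag was never actually set; the patch assigns True and folds the two temporaries into the if conditions. A minimal sketch of the resulting behavior, pulled out of load_model for illustration (the helper name is hypothetical; `config` is the parsed finetune config dict, and `device` is assumed to come from elsewhere in load_model, e.g. config["Training"]["device"]):

    def configure_hpu_generation(model, config, device):
        # Run attention softmax in bf16 on Gaudi (HPU). The pre-patch code
        # evaluated this attribute without assigning it, so the flag stayed
        # at its default.
        if device == "hpu" and config["General"]["attn_softmax_bf16"]:
            model.generation_config.attn_softmax_bf16 = True
        # Flash attention flags read by the HPU model path (as in
        # optimum-habana generation configs).
        if device == "hpu" and config["General"]["use_flash_attention"]:
            model.generation_config.use_flash_attention = True
            model.generation_config.flash_attention_recompute = False
            model.generation_config.flash_attention_causal_mask = False

Checking `device == "hpu"` before indexing also means non-HPU configs no longer need the "attn_softmax_bf16" and "use_flash_attention" keys at all; the old code looked both keys up unconditionally and would raise KeyError if either was absent.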