From f2a515e170494b5b43539b7485889546ae4256bb Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 17 Jul 2024 08:57:58 +0000
Subject: [PATCH] add custom attn impl

---
 open_diloco/train_fsdp.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 5d209ef..782dbe7 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -116,7 +116,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
 class Config(BaseConfig):
     llama_config: str | ModelConfig = "open_diloco/configs/config_1b.json"
     torch_compile: bool = True
-    attn_implementation: str = "sdpa"
+    attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa"
     # Data
     dataset_name_or_path: str = "allenai/c4"
     seq_length: int = 1024
@@ -184,11 +184,13 @@ def tokenize_function(data):
 def get_model(config: Config) -> GPT:
     # Load model
     if isinstance(config.llama_config, ModelConfig):
-        return GPT(config.llama_config)
+        llama_config = config.llama_config
     else:
         with open(config.llama_config) as f:
             llama_config = ModelConfig(**json.load(f))
-        return GPT(llama_config)
+
+    llama_config.attention_impl = config.attention_impl
+    return GPT(llama_config)
 
 
 def train(config: Config):
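
Note (not part of the patch): the diff only threads the new attention_impl field from the training Config into the ModelConfig; the model code that consumes it is not shown. Below is a minimal sketch, assuming the three Literal values map to PyTorch SDPA, flash-attn, and xformers respectively. The helper name, its signature, and the tensor layout are assumptions for illustration, not the actual OpenDiLoCo implementation.

import torch
import torch.nn.functional as F


def attention_forward(q, k, v, attention_impl: str = "sdpa"):
    """Causal attention with the backend selected by `attention_impl`.

    q, k, v are assumed to be shaped (batch, heads, seq, head_dim), the layout
    expected by torch.nn.functional.scaled_dot_product_attention.
    This is a hypothetical dispatch helper, not code from the patch.
    """
    if attention_impl == "sdpa":
        # PyTorch built-in scaled dot-product attention.
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)
    elif attention_impl == "fa":
        # flash-attn expects (batch, seq, heads, head_dim) and fp16/bf16 CUDA
        # tensors, hence the transposes around the call.
        from flash_attn import flash_attn_func

        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=True
        )
        return out.transpose(1, 2)
    elif attention_impl == "xformers":
        # xformers also uses (batch, seq, heads, head_dim); causality is passed
        # as an explicit lower-triangular attention bias.
        import xformers.ops as xops

        out = xops.memory_efficient_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
            attn_bias=xops.LowerTriangularMask(),
        )
        return out.transpose(1, 2)
    else:
        raise ValueError(f"unknown attention_impl: {attention_impl}")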