add custom attn impl
samsja committed Jul 17, 2024
1 parent 5da6c57 commit f2a515e
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions open_diloco/train_fsdp.py
@@ -116,7 +116,7 @@ def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
 class Config(BaseConfig):
     llama_config: str | ModelConfig = "open_diloco/configs/config_1b.json"
     torch_compile: bool = True
-    attn_implementation: str = "sdpa"
+    attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa"
     # Data
     dataset_name_or_path: str = "allenai/c4"
     seq_length: int = 1024
@@ -184,11 +184,13 @@ def tokenize_function(data):
 def get_model(config: Config) -> GPT:
     # Load model
     if isinstance(config.llama_config, ModelConfig):
-        return GPT(config.llama_config)
+        llama_config = config.llama_config
     else:
         with open(config.llama_config) as f:
             llama_config = ModelConfig(**json.load(f))
-        return GPT(llama_config)
+
+    llama_config.attention_impl = config.attention_impl
+    return GPT(llama_config)
 
 
 def train(config: Config):
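The first hunk replaces the free-form attn_implementation: str field with a Literal, so an unsupported backend name now fails when the config is parsed instead of surfacing later inside the model. Below is a minimal sketch of that behavior, assuming BaseConfig is a pydantic BaseModel subclass; the Demo model is hypothetical and only illustrates the validation:

from typing import Literal

from pydantic import BaseModel, ValidationError


class Demo(BaseModel):
    # Same field declaration as the new Config field.
    attention_impl: Literal["sdpa", "fa", "xformers"] = "sdpa"


Demo(attention_impl="fa")  # accepted: one of the allowed literals

try:
    Demo(attention_impl="flash")  # rejected at parse time, not deep in training
except ValidationError as err:
    print(err)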
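The second hunk only copies the flag onto the model config; how GPT dispatches on attention_impl is not part of this commit. A minimal sketch of what such a dispatch could look like, assuming tensors in (batch, heads, seq, head_dim) layout and that the flash-attn and xformers packages are installed; the scaled_attention helper is hypothetical, not code from this repository:

import torch.nn.functional as F


def scaled_attention(q, k, v, attention_impl: str, causal: bool = True):
    if attention_impl == "sdpa":
        # PyTorch's built-in scaled dot-product attention, which
        # operates directly on (batch, heads, seq, head_dim).
        return F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    elif attention_impl == "fa":
        # flash-attn expects (batch, seq, heads, head_dim).
        from flash_attn import flash_attn_func

        out = flash_attn_func(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal
        )
        return out.transpose(1, 2)
    elif attention_impl == "xformers":
        # xformers also expects (batch, seq, heads, head_dim).
        from xformers.ops import LowerTriangularMask, memory_efficient_attention

        bias = LowerTriangularMask() if causal else None
        out = memory_efficient_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), attn_bias=bias
        )
        return out.transpose(1, 2)
    else:
        raise ValueError(f"unknown attention_impl: {attention_impl}")

With the new field in place, the backend could then be picked at launch time, e.g. torchrun --nproc_per_node=8 open_diloco/train_fsdp.py --attention_impl fa (the exact flag spelling depends on how the CLI maps arguments onto the pydantic config).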
