From c37d1cbab0d13b0351831fdc093eb89fbeaa16ce Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 27 Aug 2023 21:30:25 -0400 Subject: [PATCH] clean up andromeda --- Andromeda/configs.py | 2 +- train.py | 3 +-- train_simple.py | 27 ++------------------------- 3 files changed, 4 insertions(+), 28 deletions(-) diff --git a/Andromeda/configs.py b/Andromeda/configs.py index fcd3968..04348d9 100644 --- a/Andromeda/configs.py +++ b/Andromeda/configs.py @@ -1,4 +1,4 @@ -from Andromeda.model import AndromedaEmbedding, Andromeda +from Andromeda.model import Andromeda Andromeda1Billion = Andromeda( diff --git a/train.py b/train.py index e533aa3..dca5681 100644 --- a/train.py +++ b/train.py @@ -40,9 +40,8 @@ from Andromeda.utils.stable_adamw import StableAdamWUnfused -from Andromeda.core.transformer import Transformer, AndromedaEmbedding +from Andromeda.core.transformer import Transformer # from Andromeda.model import Andromeda -from Andromeda.model import AndromedaEmbedding #, Andromeda from Andromeda.configs import Andromeda1Billion ########### SETUP CONFIG diff --git a/train_simple.py b/train_simple.py index aaf6c6c..f35746b 100644 --- a/train_simple.py +++ b/train_simple.py @@ -9,7 +9,7 @@ from Andromeda.model import Andromeda -from Andromeda.core.transformer import Decoder, AndromedaEmbedding, Transformer +from Andromeda.core.transformer import Decoder, Transformer from Andromeda.core.autoregressive_wrapper import AutoregressiveWrapper # constants @@ -37,30 +37,7 @@ def decode_tokens(tokens): # instantiate GPT-like decoder model -model = Transformer( - num_tokens=50432, - max_seq_len=8192, - use_abs_pos_emb=False, - embedding_provider=AndromedaEmbedding(), - attn_layers=Decoder( - dim=2560, - depth=32, - dim_head=128, - heads=24, - alibi_pos_bias=True, - alibi_num_heads=12, - rotary_xpos=True, - attn_flash=True, - # deepnorm=deepnorm, - # shift_tokens=shift_tokens, - attn_one_kv_head=True, - qk_norm=True, - attn_qk_norm=True, - attn_qk_norm_dim_scale=True - ) -) - -model = AutoregressiveWrapper(model) +model = Andromeda() model.cuda()