Skip to content

Commit

Permalink
Refactor NormalizedConfigs for GQA (#1539)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelbenayoun authored Nov 15, 2023
1 parent e3b7efb commit ea4349d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions optimum/utils/normalized_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class NormalizedTextConfig(NormalizedConfig):
EOS_TOKEN_ID = "eos_token_id"


class NormalizedTextConfigWithGQA(NormalizedTextConfig):
    """Text config that additionally declares ``NUM_KEY_VALUE_HEADS``.

    Everything else is inherited unchanged from ``NormalizedTextConfig``.
    The "GQA" suffix presumably stands for grouped-query attention, where
    the key/value head count is tracked separately -- confirm against the
    model configs this class is registered for.
    """

    # Attribute name used for the key/value head count.
    # NOTE(review): presumably resolved against the wrapped model config in
    # the same way as the other upper-case mappings of the parent class --
    # verify in the base class.
    NUM_KEY_VALUE_HEADS = "num_key_value_heads"


class NormalizedSeq2SeqConfig(NormalizedTextConfig):
ENCODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
DECODER_NUM_LAYERS = NormalizedTextConfig.NUM_LAYERS
Expand Down Expand Up @@ -166,8 +170,6 @@ def __getattr__(self, attr_name):
allow_new=True,
)

# Mistral-specific variant of NormalizedTextConfig built with `with_args`,
# passing `num_key_value_heads` together with `allow_new=True`.
# NOTE(review): `allow_new=True` presumably permits registering an attribute
# that the base class does not declare -- confirm against `with_args`.
MistralNormalizedTextConfig = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True)


class NormalizedConfigManager:
"""
Expand Down Expand Up @@ -227,13 +229,13 @@ class NormalizedConfigManager:
"gpt-bigcode": GPTBigCodeNormalizedTextConfig,
"gpt-neo": NormalizedTextConfig.with_args(num_attention_heads="num_heads"),
"gpt-neox": NormalizedTextConfig,
"llama": NormalizedTextConfig,
"llama": NormalizedTextConfigWithGQA,
"gptj": GPT2LikeNormalizedTextConfig,
"imagegpt": GPT2LikeNormalizedTextConfig,
"longt5": T5LikeNormalizedTextConfig,
"marian": BartLikeNormalizedTextConfig,
"mbart": BartLikeNormalizedTextConfig,
"mistral": MistralNormalizedTextConfig,
"mistral": NormalizedTextConfigWithGQA,
"mt5": T5LikeNormalizedTextConfig,
"m2m-100": BartLikeNormalizedTextConfig,
"nystromformer": NormalizedTextConfig,
Expand Down

0 comments on commit ea4349d

Please sign in to comment.