diff --git a/calc/README.md b/calc/README.md
index af2c8c0..b4e2381 100644
--- a/calc/README.md
+++ b/calc/README.md
@@ -56,13 +56,14 @@ options:
 ```
 Example with Fairseq-MoE 15B: python calc_transformer_params.py -l 12 -hs 768 --moe -e 512
 Example with GPT-3 175B: python calc_transformer_params.py -l 96 -hs 12288
-usage: calc_transformer_params.py [-h] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL]
-                                   [--topk TOPK] [--ffn-expansion-factor FFN_EXPANSION_FACTOR]
+usage: calc_transformer_params.py [-h] [--vocab-size VOCAB_SIZE] [--tied-embeddings] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS]
+                                   [--expert-interval EXPERT_INTERVAL] [--topk TOPK] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--kv-size-ratio KV_SIZE_RATIO]
 
 options:
   -h, --help            show this help message and exit
   --vocab-size VOCAB_SIZE, -v VOCAB_SIZE
                         Size of the vocab
+  --tied-embeddings     Whether embeddings are tied (shared between input and output)
   --hidden-size HIDDEN_SIZE, -hs HIDDEN_SIZE
                         Dimension of the model's hidden size
   --sequence-length SEQUENCE_LENGTH, -s SEQUENCE_LENGTH
@@ -77,6 +78,8 @@ options:
   --topk TOPK, -t TOPK  Top k routing for MoE
   --ffn-expansion-factor FFN_EXPANSION_FACTOR, -ff FFN_EXPANSION_FACTOR
                         How much the MLP hidden size expands
+  --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO
+                        What fraction of num. query heads is num. key/value heads
 ```
diff --git a/calc/calc_transformer_params.py b/calc/calc_transformer_params.py
index c4c1e30..6490ed3 100644
--- a/calc/calc_transformer_params.py
+++ b/calc/calc_transformer_params.py
@@ -19,6 +19,9 @@ def config_parser():
                         type=int,
                         default=51200,
                         help='Size of the vocab')
+    parser.add_argument("--tied-embeddings",
+                        action="store_true",
+                        help='Whether embeddings are tied (shared between input and output)')
     parser.add_argument("--hidden-size", "-hs",
                         type=int,
                         default=6144,
@@ -58,8 +61,11 @@ def config_parser():
 
 # calculates the params of a model given their hparams
 def calc_params(args):
-    # Assumes that the embedding and unembedding are tied
-    embedding_params = args.hidden_size * args.vocab_size
+    # Calculate embedding and unembedding params. If tied, re-use the same params
+    if args.tied_embeddings:
+        embedding_params = args.hidden_size * args.vocab_size
+    else:
+        embedding_params = 2 * args.hidden_size * args.vocab_size
     position_embedding_params = args.hidden_size * args.sequence_length
     # Each QKVO matrix is (hxh)
     # Unless using GQA/MQA which makes K/V smaller
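
For reference, a minimal standalone sketch (not part of the patch) of what the new tied/untied branch computes, using the script's defaults of `--vocab-size 51200` and `--hidden-size 6144`:

```
# Sanity check of the embedding-parameter arithmetic introduced by this change.
# Values are the script defaults: vocab_size=51200, hidden_size=6144.
vocab_size = 51200
hidden_size = 6144

tied_embedding_params = hidden_size * vocab_size        # input and output embeddings share one matrix
untied_embedding_params = 2 * hidden_size * vocab_size  # separate embedding and unembedding matrices

print(f"tied:   {tied_embedding_params / 1e9:.3f}B params")   # ~0.315B
print(f"untied: {untied_embedding_params / 1e9:.3f}B params") # ~0.629B
```

With untied embeddings the embedding/unembedding contribution simply doubles, which is why the old code's hard-coded tying assumption could noticeably undercount small-vocab-dominated models.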