GPT3 MoE #4

Open
wants to merge 53 commits into base: main
53 commits
7aa3940
XGLM work in progress: Causal Attention and Positional Embeddings work
AleHD Jun 26, 2024
78dd53c
WIP: GPT arch almost done, hf->nt converters working perfectly for no…
AleHD Jun 26, 2024
a74c71a
Added hf2nt frontend + tested training
AleHD Jul 9, 2024
04eaef9
Added nt2hf conversion + tests :)
AleHD Jul 11, 2024
138da5f
precommit
AleHD Jul 11, 2024
35c43f7
Merge pull request #1 from swiss-ai/gpt
negar-foroutan Jul 15, 2024
0485fd6
Added MultilingualNanoset Config
TJ-Solergibert Jul 16, 2024
539832a
Added MultilingualNanoset
TJ-Solergibert Jul 16, 2024
d9f0670
Added Language token
TJ-Solergibert Jul 16, 2024
efe8720
Forgot the trainer ups
TJ-Solergibert Jul 16, 2024
25ad39b
Fix minor errors. Everything works
TJ-Solergibert Jul 16, 2024
d91f9e1
Updated config file with GPT2 tokenized datasets in RCP
TJ-Solergibert Jul 16, 2024
d0c14e3
Before lunch
TJ-Solergibert Jul 17, 2024
9cfc5ea
After lunch
TJ-Solergibert Jul 17, 2024
eed7bce
Ready
TJ-Solergibert Jul 18, 2024
da50231
Merge pull request #2 from TJ-Solergibert/multilingual_nanoset
negar-foroutan Jul 18, 2024
7a932f8
start documenting moe setup
Aug 2, 2024
f08a05e
base moe file
haeggee Aug 2, 2024
fa06c0d
add todo
haeggee Aug 5, 2024
a9dba53
gpt3_moe basis
haeggee Aug 5, 2024
2efffb8
add nn.linear to init for moe router
haeggee Aug 5, 2024
57c58b7
changes to pipeline for backward through aux losses
haeggee Aug 5, 2024
3967bee
correct block costs and flops
haeggee Aug 5, 2024
bcb94cc
case of dict in pipelineblock
haeggee Aug 5, 2024
91acdc0
option for GLU or normal MLP
haeggee Aug 5, 2024
df3befc
init of linear layer in starcoder
haeggee Aug 6, 2024
6edce83
potential bug in pipeline block
haeggee Aug 8, 2024
7425167
XGLM work in progress: Causal Attention and Positional Embeddings work
AleHD Jun 26, 2024
42695ba
WIP: GPT arch almost done, hf->nt converters working perfectly for no…
AleHD Jun 26, 2024
6294aad
Added hf2nt frontend + tested training
AleHD Jul 9, 2024
b469ee9
Added nt2hf conversion + tests :)
AleHD Jul 11, 2024
1b19ca2
precommit
AleHD Jul 11, 2024
c1fabac
Added MultilingualNanoset Config
TJ-Solergibert Jul 16, 2024
086b50d
Added MultilingualNanoset
TJ-Solergibert Jul 16, 2024
1fe7445
Added Language token
TJ-Solergibert Jul 16, 2024
fb6631a
Forgot the trainer ups
TJ-Solergibert Jul 16, 2024
a6eb1bd
Fix minor errors. Everything works
TJ-Solergibert Jul 16, 2024
ef3fac4
Updated config file with GPT2 tokenized datasets in RCP
TJ-Solergibert Jul 16, 2024
49294f1
Before lunch
TJ-Solergibert Jul 17, 2024
8a80e5a
After lunch
TJ-Solergibert Jul 17, 2024
0fa1971
Ready
TJ-Solergibert Jul 18, 2024
8b68126
Add multilingual validation (#3)
TJ-Solergibert Aug 15, 2024
d08c949
correct logging of all losses
haeggee Aug 15, 2024
d14315f
minor bug fix when using bias
haeggee Aug 16, 2024
5dc67fe
bias init in case of use for moe
haeggee Aug 16, 2024
fad3497
sparse upcycling converter
haeggee Aug 16, 2024
100ebf4
merge main into moe; also, adapt gpt3(-moe) for val logging
haeggee Aug 16, 2024
f31a1a3
add example config
haeggee Aug 16, 2024
a45dc35
small fixes
haeggee Sep 4, 2024
bb768bb
fix for eval
haeggee Sep 6, 2024
f8e30b4
lighteval fix for multiling
haeggee Sep 6, 2024
ac27ada
Update moe.md
haeggee Sep 6, 2024
328b8c2
Update moe.md
haeggee Sep 6, 2024
135 changes: 135 additions & 0 deletions examples/config_multilingual_nanoset.yaml
@@ -0,0 +1,135 @@
checkpoints:
  checkpoint_interval: 1000000
  checkpoints_path: checkpoints/
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      training_folder:
      - datasets/c4-es/train
      - datasets/c4-en/train
      - datasets/c4-fr/train
      validation_folder:
      - datasets/c4-es/validation
      - datasets/c4-en/validation
      - datasets/c4-fr/validation
      languages:
      - es
      - en
      - fr
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Blended dataset)
  start_training_step: 1
- data:
    dataset:
      training_folder:
      - datasets/c4-es/train
      validation_folder:
      - datasets/c4-es/validation
      languages:
      - es
    num_loading_workers: 1
    seed: 42
  name: Second purpose training (Single dataset)
  start_training_step: 1000
- data:
    dataset:
      training_folder:
      - datasets/c4-es/train
      - datasets/c4-en/train
      - datasets/c4-fr/train
      validation_folder:
      - datasets/c4-es/validation
      - datasets/c4-en/validation
      - datasets/c4-fr/validation
      languages:
      - es
      - en
      - fr
    num_loading_workers: 1
    seed: 42
  name: Third purpose training (>1 dataset)
  start_training_step: 2000
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: MultilingualV2
  run: llama
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 4096
    num_hidden_layers: 32
    num_attention_heads: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rope_interleaved: false
    rope_theta: 500000.0
    rms_norm_eps: 1.0e-06
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 98
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 2
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 4
  tp_linear_async_communication: false
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 10
  micro_batch_size: 3
  sequence_length: 4096
  train_steps: 500
  val_check_interval: 100
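# Note (illustrative comment, not part of the original file): each `data_stages`
# entry becomes the active data source at its `start_training_step` (steps 1, 1000
# and 2000 above). As configured, `train_steps: 500` ends training before the later
# stages are reached; increase it if you want to exercise the stage switching.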
27 changes: 27 additions & 0 deletions examples/xglm/README.md
@@ -0,0 +1,27 @@
# How to use XGLM?

1. First, convert the weights from HuggingFace, for instance:
```bash
torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M
```

2. Now you are ready to use XGLM.
Make sure you use a `.yaml` configuration with a proper GPT3 config, then run, for instance:
```bash
torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml
```
If you use this configuration file, make sure to modify at least the loading path in `model.init_method.path` (see the fragment after this list).

3. If you want to convert your finetuned checkpoint back to HuggingFace, use:
```bash
torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
```
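For reference, the loading path mentioned in step 2 sits under `model.init_method` in the YAML config. A minimal fragment (the path shown is just a placeholder for your converted checkpoint):
```yaml
model:
  init_method:
    path: checkpoints/xglm-564M  # nanotron checkpoint produced in step 1
```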

## Sparse Upcycling

To create a sparse model from a dense model, you can use the `convert_dense2moe.py` script, which converts a GPT3 Nanotron model into a GPT3 MoE Nanotron model. For instance:
```bash
cd examples/xglm
torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=checkpoints/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-8x564M --num-experts=8
```
Note that this upcycling _drops_ the bias parameters of the MLP, because the MegaBlocks implementation does not support them. While this is a limitation of the current implementation, performance is quickly recovered after a few training steps.
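At its core, the conversion duplicates the dense MLP weights once per expert and freshly initializes the router. A simplified sketch of the weight handling, mirroring what `convert_dense2moe.py` does (the helper name here is illustrative only):
```python
import torch


def upcycle_mlp_weights(c_fc_weight: torch.Tensor, c_proj_weight: torch.Tensor, num_experts: int):
    """Replicate a dense MLP's weights for every expert; biases are dropped."""
    # c_fc: [intermediate, hidden] -> stacked experts: [hidden, intermediate * num_experts]
    w1 = c_fc_weight.T.repeat(1, num_experts)
    # c_proj: [hidden, intermediate] -> stacked experts: [intermediate * num_experts, hidden]
    w2 = c_proj_weight.T.repeat(num_experts, 1)
    return w1, w2


# The router/gate has no dense counterpart, so it starts from a small random init, e.g.
# torch.nn.init.normal_(gate.weight, mean=0.0, std=0.02)
```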
Empty file added examples/xglm/__init__.py
Empty file.
179 changes: 179 additions & 0 deletions examples/xglm/convert_dense2moe.py
@@ -0,0 +1,179 @@
"""
Converts a dense nanotron GPT3 checkpoint to a GPT3 MoE nanotron checkpoint (sparse upcycling)
Command:
torchrun --nproc-per-node=1 convert_dense2moe.py --checkpoint-path=nanotron_weights --save-path=nanotron_moe_weights
"""

import dataclasses
import json
import warnings
from argparse import ArgumentParser
from pathlib import Path
from typing import Optional

from torch import nn
import torch
import nanotron
from nanotron.config.models_config import GPT3Config, GPT3MoEConfig
from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock
from nanotron.models.gpt3_moe import GPT3MoEForTraining, GPT3MoEBlock
from nanotron.trainer import mark_tied_parameters

from convert_utils import convert_generic, create_nt_model


def convert_config(config: GPT3Config, num_experts=8) -> GPT3MoEConfig:
    return GPT3MoEConfig(
        **config.__dict__,
        is_moe=True,
        moe_num_experts=num_experts,
        num_experts_per_tok=min(2, num_experts),  # arbitrarily chosen
        moe_loss_weight=0.01,  # arbitrarily chosen
        moe_z_loss_weight=0.001,  # arbitrarily chosen
        moe_glu=False,
    )


def convert_dense_to_moe(ff_moe: nn.Module, dense_ff: nn.Module, num_experts: int):
    with torch.no_grad():
        # only copy the weight matrix and repeat it n_expert times
        weight_1 = dense_ff.c_fc.weight.clone()
        if num_experts == 1:
            ff_moe.experts.mlp.w1.module.weight.data = weight_1.contiguous()
        else:
            # [intermediate_size, hidden_size] -> [hidden_size, intermediate_size * n_experts]
            weight_1 = weight_1.T
            ff_moe.experts.mlp.w1.module.weight.data = weight_1.repeat(1, num_experts)

        weight_2 = dense_ff.c_proj.weight.clone()
        if num_experts == 1:  # just a specific case for 1 expert
            ff_moe.experts.mlp.w2.module.weight.data = weight_2.contiguous()
        else:
            # [hidden_size, intermediate_size] -> [intermediate_size * n_experts, hidden_size]
            weight_2 = weight_2.T
            ff_moe.experts.mlp.w2.module.weight.data = weight_2.repeat(num_experts, 1)

        # # -- could add bias only for 2nd layer, because that works with the MegaBlocks MoE implementation
        # # -- but won't make a big difference?
        # ff_moe.experts.bias.copy_(dense_ff.c_proj.bias)

        # init gating randomly
        nn.init.normal_(ff_moe.gate.layer.weight, mean=0.0, std=0.02)


def convert_decoder(block_moe: GPT3MoEBlock, block_nt: GPTBlock, num_experts: int):
    convert_generic(block_moe.ln_1, block_nt.ln_1)
    convert_generic(block_moe.attn, block_nt.attn)
    convert_generic(block_moe.ln_2, block_nt.ln_2)
    convert_dense_to_moe(block_moe.ff, block_nt.ff, num_experts)


def convert(
    model_moe: GPT3MoEForTraining, model_dense: GPT3ForTraining, num_experts: int
):
    convert_generic(
        model_moe.model.token_embeddings.pp_block.token_embedding,
        model_dense.model.token_embeddings.pp_block.token_embedding,
    )
    for layer_moe, layer_nt in zip(model_moe.model.decoder, model_dense.model.decoder):
        convert_decoder(layer_moe.pp_block, layer_nt.pp_block, num_experts)
    convert_generic(
        model_moe.model.final_layer_norm.pp_block,
        model_dense.model.final_layer_norm.pp_block,
    )
    convert_generic(
        model_moe.model.lm_head.pp_block, model_dense.model.lm_head.pp_block
    )


def create_nt_moe_model(
    model_config: Optional[GPT3Config] = None,
    device: torch.device = torch.device("cuda"),
    dtype: torch.dtype = torch.bfloat16,
    checkpoint_path: Optional[Path] = None,
):
    if model_config is None:
        assert checkpoint_path is not None
        with open(checkpoint_path / "model_config.json") as f:
            model_config = GPT3MoEConfig(**json.load(f))

    parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1)
    parallel_context = nanotron.parallel.ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )
    model_nt = nanotron.models.build_model(
        model_builder=lambda: GPT3MoEForTraining(
            config=model_config,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        parallel_context=parallel_context,
        dtype=dtype,
        device=device,
    )
    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)

    if checkpoint_path is not None:
        nanotron.serialize.load_weights(
            model=model_nt,
            parallel_context=parallel_context,
            root_folder=checkpoint_path,
        )

    return model_nt


def main(
    checkpoint_path: Path,
    save_path: Path,
    num_experts: int,
):
    # Load nanotron model.
    model_dense = create_nt_model(checkpoint_path=checkpoint_path)

    # Init moe model.
    model_config_moe = convert_config(model_dense.config, num_experts)
    model_moe = create_nt_moe_model(model_config=model_config_moe)

    convert(model_moe, model_dense, num_experts)
    nanotron.serialize.save_weights(
        model=model_moe,
        parallel_context=model_moe.parallel_context,
        root_folder=save_path,
    )
    with open(save_path / "model_config.json", "w+") as f:
        json.dump(dataclasses.asdict(model_config_moe), f)
    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    # fix all random seeds
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.deterministic = True
    parser = ArgumentParser(description="Convert dense weights to moe format")
    parser.add_argument(
        "--checkpoint-path",
        type=Path,
        default="checkpoints/xglm-7.5B",
        help="Path to the nanotron dense checkpoint",
    )
    parser.add_argument(
        "--save-path",
        type=Path,
        default="checkpoints/xglm-moe-7.5B",
        help="Path to save the nanotron moe model",
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        default=8,
        help="Number of experts in the MoE model (duplicates of MLP layer)",
    )
    args = parser.parse_args()
    main(args.checkpoint_path, args.save_path, args.num_experts)