[Example] add an example of running open Mixtral 8x7B in 4D using veScale (#24)

This PR adds a 4D parallelism example of using veScale to run a
[Mixtral 8x7B model](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1)
that is directly imported from HuggingFace without any model code modifications.

This PR also adds a debug utility for printing logs and reorganizes some of the existing code structure.
Vremold authored Apr 10, 2024
1 parent 372adcb commit 9d59f8d
Showing 42 changed files with 2,250 additions and 447 deletions.
34 changes: 34 additions & 0 deletions python/example/mixtral_4D_benchmark/README.md
@@ -0,0 +1,34 @@
# veScale Mixtral Example

## Overview

In this directory, we provide a 4D parallelism example of using veScale to run
a [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) model that is directly imported
from HuggingFace without any model code modifications.
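
As a quick orientation, the core pattern looks roughly like the minimal sketch below (condensed from `mixtral_train.py` in this directory and run under `torchrun`, as in the commands that follow; the 8-GPU mesh layout and config values are illustrative, not the only supported setup):

```
# Minimal sketch condensed from mixtral_train.py; mesh shape and config are illustrative.
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dmodule import parallelize_module
from vescale.initialize.deferred_init import deferred_init

from sharding_plan import mixtral_plan  # the TP/SP plan shipped with this example

# A 2-D (DP, TP) mesh over 8 GPUs: dp=1, tp=8.
device_mesh = DeviceMesh("cuda", [[0, 1, 2, 3, 4, 5, 6, 7]], mesh_dim_names=("DP", "TP"))

# Build the unmodified HuggingFace model with deferred initialization, then let
# veScale shard it according to the declarative sharding plan.
config = MixtralConfig(num_hidden_layers=16)
model = deferred_init(MixtralModel, config)
model = parallelize_module(model, device_mesh["TP"], mixtral_plan, factory=True)
```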


## Run

### Single Machine (8 cards)
```
torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- python/example/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
```
This starts an 8-card MFU benchmark for Mixtral with veScale, with dp=1 and tp=8.

### Distributed Environment (2 machines, 16 cards)
```
# Set up a suitable distributed cluster first; adjust --nnodes and the rendezvous/master settings for your environment
torchrun --nproc-per-node=8 --nnodes=1 python/example/mixtral_4D_benchmark/mixtral_train.py --tp 8 --dp 2
```
This starts a 16-card MFU benchmark for Mixtral with veScale, with dp=2 and tp=8.

### Options
1. `--bsz`: the total batch size for one iteration. The default is 16.
2. `--seqlen`: the sequence length of the input. The default is 256.
3. `--dp`: the degree of data parallelism (DDP). The default is 1.
4. `--tp`: the degree of tensor parallelism. The default is 8. (See the device-mesh sketch below for how `--dp` and `--tp` are combined.)
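
For reference, the script turns `--dp` and `--tp` into a 2-D `(DP, TP)` device mesh roughly as in the sketch below (a simplified version of the code in `mixtral_train.py`; it assumes `world_size == dp * tp` and runs under `torchrun`):

```
# Simplified sketch of how mixtral_train.py builds its (DP, TP) device mesh.
# Assumes world_size == dp * tp; with --dp 2 --tp 8 on 16 GPUs you get two mesh rows.
import os

from vescale.dtensor.device_mesh import DeviceMesh

tp = 8
world_size = int(os.environ["WORLD_SIZE"])  # e.g. 16 when dp=2 and tp=8
device_list = [list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)]
# device_list == [[0, ..., 7], [8, ..., 15]]: rows are DP replicas, columns are TP groups.
device_mesh = DeviceMesh("cuda", device_list, mesh_dim_names=("DP", "TP"))
```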


## Caveats
1. The scripts are purely for demonstration purposes and MFU calculation. You need to write your own training script
in order to fine-tune Mixtral with your data.
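
For context on the MFU number: `mixtral_train.py` prints an estimate computed roughly as in the sketch below (the 3x factor approximates forward + backward, and the 59 TFLOPs per-device peak is the FP32 H100 figure the script assumes):

```
# Rough sketch of the MFU formula used at the end of mixtral_train.py.
def mfu_percent(per_replica_fwd_flops, dp, iter_time_s, n_devices, peak_flops=59e12):
    # backward ~= 2x forward, so total ~= 3x forward; every DP replica does the same work.
    achieved_flops_per_s = per_replica_fwd_flops * 3 * dp / iter_time_s
    return achieved_flops_per_s * 100 / (peak_flops * n_devices)
```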
150 changes: 150 additions & 0 deletions python/example/mixtral_4D_benchmark/mixtral_train.py
@@ -0,0 +1,150 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import argparse
import os

import torch
import torch.distributed as dist

from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dmodule import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.initialize.deferred_init import deferred_init, is_deferred

from transformers.models.mixtral.modeling_mixtral import MixtralModel
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from sharding_plan import mixtral_plan

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])


def estimate_mixtral(config, bsz, sequence_length):
    """Rough estimate of the forward-pass FLOPs for a (bsz, sequence_length) batch."""
    embed = 4 * bsz * sequence_length * config.hidden_size
    # Each routed expert in the Mixtral MoE block consists of 3 linear layers (w1, w3, w2).
    ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sequence_length
    attn_qkv = 2 * bsz * sequence_length * config.hidden_size * 3 * config.hidden_size
    attn_mask = 2 * sequence_length * config.hidden_size
    attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
    attn = attn_qkv + attn_mask + attn_proj
    return embed + (ff + attn) * config.num_hidden_layers


def run_mixtral(args):
torch.random.manual_seed(777)
device_list = [
list(range(i * args.tp, min((i + 1) * args.tp, world_size))) for i in range(max(world_size // args.tp, 1))
]
device_mesh = DeviceMesh("cuda", device_list, mesh_dim_names=("DP", "TP"))
torch.cuda.set_device(local_rank)

mixtral_config = MixtralConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
intermediate_size=args.intermediate_size,
num_hidden_layers=args.num_hidden_layers,
num_attention_heads=args.num_attention_heads,
num_key_value_heads=args.num_key_value_heads,
)

model_deferred = deferred_init(MixtralModel, mixtral_config)

mixtral_model = parallelize_module(
model_deferred,
device_mesh["TP"],
mixtral_plan,
factory=True,
)

assert not is_deferred(mixtral_model)

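    # Wrap the TP-sharded module with veScale DDP over the DP mesh dimension; per the
    # flags below, gradients are accumulated and all-reduced in fp32, and the gradient
    # buffer is laid out for the distributed optimizer created right after.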
ddp_mixtral_model = DDP(
mixtral_model,
device_mesh["DP"],
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
)

doptim = DistributedOptimizer(
torch.optim.Adam(mixtral_model.parameters(), lr=0.01),
models=[ddp_mixtral_model],
overlap_param_gather=True,
)

dataloader = []
for iter in range(args.iter):
data = torch.randint(0, args.vocab_size, (args.bsz, args.seqlen)).cuda()
dist.all_reduce(data, op=dist.ReduceOp.MAX)
dataloader.append(data)

# =----- warmup -----= #
for _ in range(args.warmup):
data = torch.randint(0, args.vocab_size, (args.bsz, args.seqlen)).cuda()
doptim.zero_grad()
ddp_mixtral_model(data).last_hidden_state.to_local().sum().backward()
ddp_mixtral_model.finish_grad_sync()
doptim.step()

# =----- training ----= #
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for iter in range(args.iter):
doptim.zero_grad()
x = dataloader[iter]
ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward()
ddp_mixtral_model.finish_grad_sync()
doptim.step()
end.record()
torch.cuda.synchronize()
exec_t = start.elapsed_time(end) / 1000 / args.iter
    # measure MFU
    if local_rank == 0:
        # Note: we are using FP32; 59 TFLOPs is the assumed FP32 peak of an H100.
        total_flops = 59 * (10**12) * device_mesh.ndevice
        print(f"1 iter time: {exec_t}")
        mixtral_flops = estimate_mixtral(mixtral_config, args.bsz, args.seqlen)
        print(f"fwd mixtral flops: {mixtral_flops}")
        # bwd ~= 2 * fwd, so fwd + bwd ~= 3 * fwd; each DP replica processes its own batch.
        print("mfu:", mixtral_flops * 3 * args.dp * 100 / exec_t / total_flops)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--warmup", type=int, default=2)
parser.add_argument("--iter", type=int, default=10)
parser.add_argument("--vocab_size", type=int, default=32000)
parser.add_argument("--hidden_size", type=int, default=4096)
parser.add_argument("--intermediate_size", type=int, default=14336)
parser.add_argument("--num_hidden_layers", type=int, default=16)
parser.add_argument("--num_attention_heads", type=int, default=32)
parser.add_argument("--num_key_value_heads", type=int, default=8)
parser.add_argument("--bsz", type=int, default=16)
parser.add_argument("--seqlen", type=int, default=256)
parser.add_argument("--dp", type=int, default=1)
parser.add_argument("--tp", type=int, default=8)
return parser


if __name__ == "__main__":
parser = parse_args()
args = parser.parse_args()
run_mixtral(args)
69 changes: 69 additions & 0 deletions python/example/mixtral_4D_benchmark/sharding_plan.py
@@ -0,0 +1,69 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

"""This file contain TP/SP sharding plans for Mixtral example code."""

from vescale.dtensor.placement_types import Replicate, Shard


param_sharding_plan = {
"embed_tokens.weight": [Replicate()],
r"layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
r"layers.\d+.self_attn.q_proj.weight": [Shard(0)],
r"layers.\d+.self_attn.k_proj.weight": [Shard(0)],
r"layers.\d+.self_attn.v_proj.weight": [Shard(0)],
# TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
r"layers.\d+.self_attn.o_proj.weight": [Shard(1)],
r"layers.\d+.post_attention_layernorm.weight": [Replicate()],
r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
"norm.weight": [Replicate()],
}

fwd_resharding_plan = {
# TODO: buggy: the attention mask may be a torch.Tensor, but during training it is None.
r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
"embed_tokens.input": [[Replicate()]],
# No SP
# r"layers.\d+.input_layernorm.input": [[Replicate()]],
# r"layers.\d+.input_layernorm.output": [[Replicate()]],
# SP
r"layers.\d+.input_layernorm.input": [[Shard(1)]],
r"layers.\d+.input_layernorm.output": [[Shard(1)]],
r"layers.\d+.self_attn.input": [[Replicate()]],
r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
r"layers.\d+.self_attn.o_proj.output": [[Replicate()]],
# No SP
# r"layers.\d+.post_attention_layernorm.input": [[Replicate()]],
# r"layers.\d+.post_attention_layernorm.output": [[Replicate()]],
# SP
r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
r"layers.\d+.block_sparse_moe.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
"norm.input": [[Replicate()]],
}

mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
22 changes: 7 additions & 15 deletions python/example/nanogpt_4D_finetune/base_train.py
@@ -44,7 +44,6 @@
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
@@ -138,7 +137,6 @@
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join("data", dataset)
@@ -227,9 +225,7 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
model.crop_block_size(block_size)
model_args["block_size"] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
model.to(ptdtype)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
@@ -258,8 +254,7 @@ def estimate_loss():
losses = torch.zeros(eval_iters // factor).to(device)
for k in range(eval_iters // factor):
X, Y = get_batch(split, batch_size * factor, local_batch_size * factor)
with ctx:
logits, loss = model(X, Y)
logits, loss = model(X, Y)
losses[k] = loss.item() / ddp_world_size
if ddp:
all_reduce(losses)
@@ -345,20 +340,17 @@ def get_lr(it):
# I really dislike that this bloats the code and forces us to repeat code
# looking at the source of that context manager, it just toggles this variable
model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
with ctx:
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
# immediately async prefetch next batch while model is doing the forward pass on the GPU
X, Y = get_batch("train")
# backward pass, with gradient scaling if training in fp16
scaler.scale(loss).backward()
loss.backward()
# clip the gradient
if grad_clip != 0.0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# step the optimizer and scaler if training in fp16
scaler.step(optimizer)
scaler.update()
# step the optimizer
optimizer.step()
# flush the gradients as soon as we can, no need for this memory anymore
optimizer.zero_grad(set_to_none=True)

python/example/nanogpt_4D_finetune/config/finetune_shakespeare.py
@@ -29,7 +29,7 @@
dataset = "shakespeare"
init_from = "gpt2" # this is the smallest GPT-2 model

wandb_log = True # feel free to turn on
wandb_log = False # feel free to turn on
wandb_project = f"{init_from}_finetune_{dataset}"

# only save checkpoints if the validation loss improves
38 changes: 23 additions & 15 deletions python/example/nanogpt_4D_finetune/exp.py
@@ -47,42 +47,50 @@ def parse(log_fn, name=None):
print(f'"{name}": {val_losses},')


GPU_CNT = 4
DP_SIZES = [4, 2, 1]
SINGLE_GPU_RUN = "python3"
MULTI_GPU_RUN = "torchrun --standalone --nproc_per_node=4"
CONFIG = "config/finetune_shakespeare.py"
LOG_PREFIX = ""


def run_exps(max_iters, dtypes, run=True):
os.makedirs("logs", exist_ok=True)
if run:
for dtype in dtypes:
dt = "bfloat16" if dtype == "bf16" else "float32"
cmd = f"python3 base_train.py config/finetune_shakespeare.py --compile=False --wandb_log=False --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/1gpu_{dtype}_max_iters_{max_iters}.log"
cmd = f"{SINGLE_GPU_RUN} base_train.py {CONFIG} --compile=False --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log"
print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")
for dp_size in [1, 2, 4]:
tp_size = 4 // dp_size
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
for dtype in dtypes:
dt = "bfloat16" if dtype == "bf16" else "float32"
cmd = f"torchrun --standalone --nproc_per_node=4 finetune_4D.py config/finetune_shakespeare.py --compile=False --use_DO=True --wandb_log=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
cmd = f"{MULTI_GPU_RUN} finetune_4D.py {CONFIG} --compile=False --DDP_grads_in_fp32=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")

print("train_loss = {")
for dtype in dtypes:
parse_train_loss(f"logs/1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in [1, 2, 4]:
tp_size = 4 // dp_size
log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse_train_loss(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}")
print("}")

print("val_loss = {")
for dtype in dtypes:
parse(f"logs/1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in [1, 2, 4]:
tp_size = 4 // dp_size
log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}")
print("}")


if __name__ == "__main__":
run_exps(200, ["bf16", "fp32"], run=True)
run_exps(200, ["bf16"], run=True)