[Example] add an example of running open Mixtral 8x7B in 4D using veScale #24

Merged: 2 commits, Apr 10, 2024
34 changes: 34 additions & 0 deletions python/example/mixtral_4D_benchmark/README.md
@@ -0,0 +1,34 @@
# veScale Mixtral Example

## Overview

In this directory, we provide a 4D parallelism example of using veScale to run
a [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) model that is directly imported
from HuggingFace without any model code modifications.
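For orientation, here is a condensed sketch of what `mixtral_train.py` in this example does; it is not a standalone script and still has to be launched with `torchrun` as shown in the Run section below (the 1 x 8 mesh and the hyperparameters mirror this example's defaults):

```
import torch
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralModel

from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dmodule import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.initialize.deferred_init import deferred_init
from sharding_plan import mixtral_plan

# 1 x 8 device mesh: one data-parallel group of eight tensor-parallel ranks (dp=1, tp=8)
device_mesh = DeviceMesh("cuda", [list(range(8))], mesh_dim_names=("DP", "TP"))
config = MixtralConfig(num_hidden_layers=16)

model = deferred_init(MixtralModel, config)  # unmodified HuggingFace model, materialization deferred
model = parallelize_module(model, device_mesh["TP"], mixtral_plan, factory=True)
model = DDP(model, device_mesh["DP"],
            accumulate_allreduce_grads_in_fp32=True,
            overlap_grad_reduce=False,
            use_distributed_optimizer=True)
optimizer = DistributedOptimizer(torch.optim.Adam(model.parameters(), lr=0.01),
                                 models=[model], overlap_param_gather=True)
```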


## Run

### Single Machine (8 cards)
```
torchrun --nproc-per-node=8 --nnodes=1 --master-port=42516 -- python/example/mixtral_4D_benchmark/mixtral_train.py --num_hidden_layers=16
```
This starts an 8-card MFU benchmark for Mixtral with veScale, using dp=1 and tp=8.

### Distributed Environment (2 machines, 16 cards)
```
# You may need to bring up a suitable distributed cluster environment first
torchrun --nproc-per-node=8 --nnodes=1 python/example/mixtral_4D_benchmark/mixtral_train.py --tp 8 --dp 2
```
This starts a 16-card MFU benchmark for Mixtral with veScale, using dp=2 and tp=8.
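Depending on your cluster tooling, the command above may only launch the eight local processes. In a plain two-machine setup, each node typically needs its own `torchrun` launcher pointed at a shared rendezvous endpoint. A sketch of a per-node launch (the master address and port are placeholders, not part of this example):
```
# run on every node; MASTER_ADDR is a placeholder for a host reachable by both machines
torchrun --nnodes=2 --nproc-per-node=8 \
    --rdzv-backend=c10d --rdzv-endpoint=$MASTER_ADDR:29500 \
    python/example/mixtral_4D_benchmark/mixtral_train.py --tp 8 --dp 2
```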

### Options
1. `--bsz`: the total batch size for one iteration. The default is 16.
2. `--seqlen`: the sequence length of the input. The default is 256.
3. `--dp`: the degree of data parallelism (DDP). The default is 1.
4. `--tp`: the degree of tensor parallelism. The default is 8 (see the combined example below).
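These options can be combined in a single launch. For example, a hypothetical 8-card run with a larger batch and sequence length, split as dp=2 and tp=4, would look like:
```
torchrun --nproc-per-node=8 --nnodes=1 python/example/mixtral_4D_benchmark/mixtral_train.py --bsz 32 --seqlen 512 --dp 2 --tp 4
```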


## Caveats
1. The scripts are purely for demonstration purposes and MFU calculation. You need to write your own training script
in order to fine-tune Mixtral with your data.
150 changes: 150 additions & 0 deletions python/example/mixtral_4D_benchmark/mixtral_train.py
@@ -0,0 +1,150 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

import argparse
import os

import torch
import torch.distributed as dist

from vescale.dtensor.device_mesh import DeviceMesh
from vescale.dmodule import parallelize_module
from vescale.ddp.distributed_data_parallel import DistributedDataParallel as DDP
from vescale.optim.distributed_optimizer import DistributedOptimizer
from vescale.initialize.deferred_init import deferred_init, is_deferred

from transformers.models.mixtral.modeling_mixtral import MixtralModel
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from sharding_plan import mixtral_plan

local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])


def estimate_mixtral(config, bsz, sequence_length):
    # Estimate the forward FLOPs of one iteration for a single data-parallel replica.
    embed = 4 * bsz * sequence_length * config.hidden_size
    # MixtralMoE consists of 3 linear layers.
    ff = 3 * 2 * config.num_experts_per_tok * config.hidden_size * config.intermediate_size * bsz * sequence_length
    attn_qkv = 2 * bsz * sequence_length * config.hidden_size * 3 * config.hidden_size
    attn_mask = 2 * sequence_length * config.hidden_size
    attn_proj = 2 * config.hidden_size * config.intermediate_size * bsz * sequence_length
    attn = attn_qkv + attn_mask + attn_proj
    return embed + (ff + attn) * config.num_hidden_layers


def run_mixtral(args):
torch.random.manual_seed(777)
device_list = [
list(range(i * args.tp, min((i + 1) * args.tp, world_size))) for i in range(max(world_size // args.tp, 1))
]
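    # device_list is a (world_size // tp) x tp grid, e.g. [[0, ..., 7], [8, ..., 15]] for
    # world_size=16 and tp=8; rows form the "DP" mesh dimension and columns the "TP" dimension.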
device_mesh = DeviceMesh("cuda", device_list, mesh_dim_names=("DP", "TP"))
torch.cuda.set_device(local_rank)

mixtral_config = MixtralConfig(
vocab_size=args.vocab_size,
hidden_size=args.hidden_size,
intermediate_size=args.intermediate_size,
num_hidden_layers=args.num_hidden_layers,
num_attention_heads=args.num_attention_heads,
num_key_value_heads=args.num_key_value_heads,
)

model_deferred = deferred_init(MixtralModel, mixtral_config)

mixtral_model = parallelize_module(
model_deferred,
device_mesh["TP"],
mixtral_plan,
factory=True,
)

assert not is_deferred(mixtral_model)

ddp_mixtral_model = DDP(
mixtral_model,
device_mesh["DP"],
accumulate_allreduce_grads_in_fp32=True,
overlap_grad_reduce=False,
use_distributed_optimizer=True,
)

doptim = DistributedOptimizer(
torch.optim.Adam(mixtral_model.parameters(), lr=0.01),
models=[ddp_mixtral_model],
overlap_param_gather=True,
)
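    # At this point the model is parallelized along several dimensions: TP/SP come from
    # parallelize_module with the sharding plan, DP from the DDP wrapper, and the
    # DistributedOptimizer shards optimizer states across the DP ranks (ZeRO-style).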

    dataloader = []
    for _ in range(args.iter):
        data = torch.randint(0, args.vocab_size, (args.bsz, args.seqlen)).cuda()
        dist.all_reduce(data, op=dist.ReduceOp.MAX)
        dataloader.append(data)

# =----- warmup -----= #
for _ in range(args.warmup):
data = torch.randint(0, args.vocab_size, (args.bsz, args.seqlen)).cuda()
doptim.zero_grad()
ddp_mixtral_model(data).last_hidden_state.to_local().sum().backward()
ddp_mixtral_model.finish_grad_sync()
doptim.step()

# =----- training ----= #
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
    for i in range(args.iter):
        doptim.zero_grad()
        x = dataloader[i]
ddp_mixtral_model(x).last_hidden_state.to_local().sum().backward()
ddp_mixtral_model.finish_grad_sync()
doptim.step()
end.record()
torch.cuda.synchronize()
exec_t = start.elapsed_time(end) / 1000 / args.iter
    # measure MFU
    if local_rank == 0:
        # Note: FP32 training; 59 TFLOPS is taken as the peak per-device throughput (H100).
total_flops = 59 * (10**12) * device_mesh.ndevice
print(f"1 iter time: {exec_t}")
mixtral_flops = estimate_mixtral(mixtral_config, args.bsz, args.seqlen)
print(f"fwd mixtral flops: {mixtral_flops}")
        # bwd ~= 2 * fwd, so one iteration ~= 3 * fwd FLOPs; multiply by dp to count all replicas
print("mfu:", mixtral_flops * 3 * args.dp * 100 / exec_t / total_flops)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--warmup", type=int, default=2)
parser.add_argument("--iter", type=int, default=10)
parser.add_argument("--vocab_size", type=int, default=32000)
parser.add_argument("--hidden_size", type=int, default=4096)
parser.add_argument("--intermediate_size", type=int, default=14336)
parser.add_argument("--num_hidden_layers", type=int, default=16)
parser.add_argument("--num_attention_heads", type=int, default=32)
parser.add_argument("--num_key_value_heads", type=int, default=8)
parser.add_argument("--bsz", type=int, default=16)
parser.add_argument("--seqlen", type=int, default=256)
parser.add_argument("--dp", type=int, default=1)
parser.add_argument("--tp", type=int, default=8)
return parser


if __name__ == "__main__":
parser = parse_args()
args = parser.parse_args()
run_mixtral(args)
69 changes: 69 additions & 0 deletions python/example/mixtral_4D_benchmark/sharding_plan.py
@@ -0,0 +1,69 @@
################################################################################
#
# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

"""This file contain TP/SP sharding plans for Mixtral example code."""

from vescale.dtensor.placement_types import Replicate, Shard
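
# Placement semantics (brief note): Replicate() keeps a full copy of a tensor on every rank of
# the TP mesh dimension, while Shard(d) splits it along tensor dimension d across those ranks.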


param_sharding_plan = {
"embed_tokens.weight": [Replicate()],
r"layers.\d+.input_layernorm.weight": [Replicate()], # MixtralRMSNorm
r"layers.\d+.self_attn.q_proj.weight": [Shard(0)],
r"layers.\d+.self_attn.k_proj.weight": [Shard(0)],
r"layers.\d+.self_attn.v_proj.weight": [Shard(0)],
# TODO: buggy, cos_cached or sin_cached can be updated or recreated if seqlen exceeds the max seqlen.
r"layers.\d+.self_attn.rotary_emb.layers.\d+.cos_cached": [Replicate()],
r"layers.\d+.self_attn.rotary_emb.layers.\d+.sin_cached": [Replicate()],
r"layers.\d+.self_attn.o_proj.weight": [Shard(1)],
r"layers.\d+.post_attention_layernorm.weight": [Replicate()],
r"layers.\d+.block_sparse_moe.gate.weight": [Replicate()],
r"layers.\d+.block_sparse_moe.experts.\d+.w1.weight": [Shard(0)],
r"layers.\d+.block_sparse_moe.experts.\d+.w3.weight": [Shard(0)],
r"layers.\d+.block_sparse_moe.experts.\d+.w2.weight": [Shard(1)],
"norm.weight": [Replicate()],
}

fwd_resharding_plan = {
    # TODO: buggy: attention_mask can be a torch.Tensor, but in training it is None
r".input": {"input_ids": [Replicate()], "attention_mask": [Replicate()]},
"embed_tokens.input": [[Replicate()]],
# No SP
# r"layers.\d+.input_layernorm.input": [[Replicate()]],
# r"layers.\d+.input_layernorm.output": [[Replicate()]],
# SP
r"layers.\d+.input_layernorm.input": [[Shard(1)]],
r"layers.\d+.input_layernorm.output": [[Shard(1)]],
r"layers.\d+.self_attn.input": [[Replicate()]],
r"layers.\d+.self_attn.output": {"attn_output": [Replicate()], "attn_weights": None, "past_key_value": None},
r"layers.\d+.self_attn.o_proj.output": [[Replicate()]],
# No SP
# r"layers.\d+.post_attention_layernorm.input": [[Replicate()]],
# r"layers.\d+.post_attention_layernorm.output": [[Replicate()]],
# SP
r"layers.\d+.post_attention_layernorm.input": [[Shard(1)]],
r"layers.\d+.post_attention_layernorm.output": [[Shard(1)]],
r"layers.\d+.block_sparse_moe.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.gate.output": [[Replicate()]],
r"layers.\d+.block_sparse_moe.output": {"final_hidden_states": [Replicate()], "router_logits": [Replicate()]},
r"layers.\d+.block_sparse_moe.experts.\d+.w1.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.experts.\d+.w3.input": [[Replicate()]],
r"layers.\d+.block_sparse_moe.experts.\d+.w2.output": [[Replicate()]],
"norm.input": [[Replicate()]],
}

mixtral_plan = {"parameter": param_sharding_plan, "forward": fwd_resharding_plan}
22 changes: 7 additions & 15 deletions python/example/nanogpt_4D_finetune/base_train.py
Expand Up @@ -44,7 +44,6 @@
import time
import math
import pickle
from contextlib import nullcontext

import numpy as np
import torch
@@ -138,7 +137,6 @@
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
# note: float16 data type will automatically use a GradScaler
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# poor man's data loader
data_dir = os.path.join("data", dataset)
@@ -227,9 +225,7 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
model.crop_block_size(block_size)
model_args["block_size"] = block_size # so that the checkpoint will have the right value
model.to(device)

# initialize a GradScaler. If enabled=False scaler is a no-op
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float16"))
model.to(ptdtype)

# optimizer
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
@@ -258,8 +254,7 @@ def estimate_loss():
losses = torch.zeros(eval_iters // factor).to(device)
for k in range(eval_iters // factor):
X, Y = get_batch(split, batch_size * factor, local_batch_size * factor)
with ctx:
logits, loss = model(X, Y)
logits, loss = model(X, Y)
losses[k] = loss.item() / ddp_world_size
if ddp:
all_reduce(losses)
@@ -345,20 +340,17 @@ def get_lr(it):
# I really dislike that this bloats the code and forces us to repeat code
# looking at the source of that context manager, it just toggles this variable
model.require_backward_grad_sync = micro_step == gradient_accumulation_steps - 1
with ctx:
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
logits, loss = model(X, Y)
loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation
# immediately async prefetch next batch while model is doing the forward pass on the GPU
X, Y = get_batch("train")
# backward pass, with gradient scaling if training in fp16
scaler.scale(loss).backward()
loss.backward()
# clip the gradient
if grad_clip != 0.0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
# step the optimizer and scaler if training in fp16
scaler.step(optimizer)
scaler.update()
# step the optimizer
optimizer.step()
# flush the gradients as soon as we can, no need for this memory anymore
optimizer.zero_grad(set_to_none=True)

@@ -29,7 +29,7 @@
dataset = "shakespeare"
init_from = "gpt2" # this is the smallest GPT-2 model

wandb_log = True # feel free to turn on
wandb_log = False # feel free to turn on
wandb_project = f"{init_from}_finetune_{dataset}"

# only save checkpoints if the validation loss improves
38 changes: 23 additions & 15 deletions python/example/nanogpt_4D_finetune/exp.py
@@ -47,42 +47,50 @@ def parse(log_fn, name=None):
print(f'"{name}": {val_losses},')


GPU_CNT = 4
DP_SIZES = [4, 2, 1]
SINGLE_GPU_RUN = "python3"
MULTI_GPU_RUN = "torchrun --standalone --nproc_per_node=4"
CONFIG = "config/finetune_shakespeare.py"
LOG_PREFIX = ""


def run_exps(max_iters, dtypes, run=True):
os.makedirs("logs", exist_ok=True)
if run:
for dtype in dtypes:
dt = "bfloat16" if dtype == "bf16" else "float32"
cmd = f"python3 base_train.py config/finetune_shakespeare.py --compile=False --wandb_log=False --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/1gpu_{dtype}_max_iters_{max_iters}.log"
cmd = f"{SINGLE_GPU_RUN} base_train.py {CONFIG} --compile=False --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log"
print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")
for dp_size in [1, 2, 4]:
tp_size = 4 // dp_size
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
for dtype in dtypes:
dt = "bfloat16" if dtype == "bf16" else "float32"
cmd = f"torchrun --standalone --nproc_per_node=4 finetune_4D.py config/finetune_shakespeare.py --compile=False --use_DO=True --wandb_log=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
cmd = f"{MULTI_GPU_RUN} finetune_4D.py {CONFIG} --compile=False --DDP_grads_in_fp32=False --dp_size={dp_size} --tp_size={tp_size} --max_iters={max_iters} --dtype='{dt}'"
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
print(f"run {cmd} > {log_fn} 2> {log_fn}.err")
os.system(f"{cmd} > {log_fn} 2> {log_fn}.err")

print("train_loss = {")
for dtype in dtypes:
parse_train_loss(f"logs/1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in [1, 2, 4]:
tp_size = 4 // dp_size
log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse_train_loss(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse_train_loss(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}")
print("}")

print("val_loss = {")
for dtype in dtypes:
parse(f"logs/1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in [1, 2, 4]:
tp_size = 4 // dp_size
log_fn = f"logs/4gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse(f"logs/{LOG_PREFIX}_1gpu_{dtype}_max_iters_{max_iters}.log", f"1GPU_{dtype}")
for dp_size in DP_SIZES:
tp_size = GPU_CNT // dp_size
log_fn = f"logs/{LOG_PREFIX}_{GPU_CNT}gpu_dp{dp_size}_tp{tp_size}_{dtype}_max_iters_{max_iters}.log"
parse(log_fn, f"4GPU_DP{dp_size}_TP{tp_size}_{dtype}")
print("}")


if __name__ == "__main__":
run_exps(200, ["bf16", "fp32"], run=True)
run_exps(200, ["bf16"], run=True)