Feat fsdp diloco #29

Open · wants to merge 12 commits into base: main
67 changes: 67 additions & 0 deletions open_diloco/simulate_multi_node.sh
@@ -0,0 +1,67 @@
#!/bin/bash

#
# Simulate multiple nodes on one machine: start N torchrun launchers with X GPUs each.
# Example: ./open_diloco/simulate_multi_node.sh 2 1 open_diloco/train_pure_fsdp.py

# Function to get CUDA devices based on the number of GPUs and index
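# e.g. get_cuda_devices 2 1 -> "2,3" (the second simulated node gets GPUs 2 and 3)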
function get_cuda_devices() {
local num_gpu=$1
local index=$2
local start_gpu=$((num_gpu * index))
local end_gpu=$((start_gpu + num_gpu - 1))

if [ "$num_gpu" -eq 1 ]; then
echo $start_gpu
else
echo $(seq -s ',' $start_gpu $end_gpu)
fi
}

# Array to store PIDs of child processes
child_pids=()

# Function to kill all child processes
cleanup() {
echo "Cleaning up child processes..."
local killed=0
for pid in "${child_pids[@]}"; do
if kill -TERM "$pid" 2>/dev/null; then
((killed++))
fi
done
wait
echo "All child processes terminated. Killed $killed processes."
exit
}

# Check that at least three arguments were passed (N, num_gpu, and the training script)
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <N> <num_gpu> <train_script> [additional_python_args]"
exit 1
fi


N=$1 # Number of simulated nodes (torchrun launchers)
NUM_GPU=$2 # GPUs per simulated node
shift 2 # Remove the first two arguments so $@ contains only the training script and its arguments

# Register the cleanup function to be called on SIGINT (Ctrl+C)
trap cleanup SIGINT


mkdir -p logs



for i in $(seq 0 $(($N - 1 )))
do
> logs/log$i
CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
child_pids+=($!)
done

tail -f logs/log0 &
child_pids+=($!)

wait
273 changes: 273 additions & 0 deletions open_diloco/train_pure_fsdp.py
@@ -0,0 +1,273 @@
import os
from contextlib import nullcontext
import datetime
from typing import Literal

import torch
import torch.distributed as dist
from pydantic_config import parse_argv, BaseConfig
from torch.distributed import destroy_process_group, init_process_group

from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import (
AutoTokenizer,
DataCollatorForLanguageModeling,
LlamaConfig,
LlamaForCausalLM,
get_cosine_schedule_with_warmup,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
MixedPrecision,
)
from torch.distributed.device_mesh import init_device_mesh
from hivemind.optim.optimizer import logger
Member

I think we might not need this logger from hivemind?

Collaborator Author

We do, otherwise we don't have a nice logger.

from open_diloco.utils import (
FakeTokenizedDataset,
get_sharding_strategy,
WandbLogger,
DummyLogger,
)
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

TIMEOUT_NCCL_MINUTES = int(os.environ.get("TIMEOUT_NCCL_MINUTES", 120))  # env vars are strings, so cast
TEST_VOCAB_SIZE = 1024


# Function to initialize the distributed process group
def ddp_setup():
init_process_group(timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES))

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def log(message):
logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")


class DilocoConfig(BaseConfig):
outer_lr: float = 0.7
local_steps: int = 10


class Config(BaseConfig):
diloco: DilocoConfig = DilocoConfig()
path_model: str = "PrimeIntellect/llama-150m-fresh"
torch_compile: bool = True
attn_implementation: str = "flash_attention_2"
# Data
seq_length: int = 1024
num_workers: int = 4
# Optimization
lr: float = 4e-4
total_batch_size: int = 512
per_device_train_batch_size: int = 32
warmup_steps: int = 1000
total_steps: int = 88_000
sharding_strategy: str = "FULL_SHARD"
project: str = "debug"
metric_logger_type: Literal["wandb", "dummy"] = "wandb"
fake_data: bool = False
dataset_name_or_path: str = "allenai/c4"


def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader:
if config.fake_data:
train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE)
else:
ds = load_dataset(config.dataset_name_or_path, "en", streaming=True)

def tokenize_function(data):
outputs = tokenizer(
data["text"],
truncation=True,
max_length=config.seq_length,
padding="max_length",
)
return outputs

tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[
"train"
]

train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

return StatefulDataLoader(
train_dataset,
collate_fn=data_collator,
batch_size=config.per_device_train_batch_size,
num_workers=config.num_workers,
)


def get_model(config: Config) -> LlamaForCausalLM:
# Load model
config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation)
return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)


def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
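# CPU copies of the trainable parameters: they stay fixed during the inner steps and
# act as the DiLoCo outer parameters that the outer SGD optimizer updates.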
offloaded_params = []
for param in model.parameters():
if param.requires_grad:
offloaded_param = param.data.detach().clone().to("cpu")
offloaded_param.requires_grad = True
offloaded_params.append(offloaded_param)

return offloaded_params


def train(config: Config):
sharding_strategy = get_sharding_strategy(config.sharding_strategy)
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
rank = int(os.environ["RANK"])

# batch_size is the total batch size for all GPUs
assert config.total_batch_size % local_world_size == 0
batch_size = config.total_batch_size // local_world_size

assert batch_size % config.per_device_train_batch_size == 0
gradient_accumulation_steps = batch_size // config.per_device_train_batch_size
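# e.g. with 4 GPUs per node: batch_size = 512 // 4 = 128 and gradient_accumulation_steps = 128 // 32 = 4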

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
tokenizer.pad_token = "</s>" # Ensure pad token is set for models that need it

train_dataloader = get_dataloader(tokenizer, world_size, rank, config)

model = get_model(config)
model = model.to(local_rank)

nnodes = world_size // local_world_size

# right now device mesh does not support two backends, so we create two identical meshes that differ only in their backend
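# "local" (cuda/NCCL) groups the ranks within one node and is used as the FSDP process group;
# "global" (cpu/Gloo) groups the same local rank across nodes and carries the CPU all-reduce
# of the DiLoCo pseudo-gradients.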
device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
device_mesh_cpu = init_device_mesh("gloo", (nnodes, local_world_size), mesh_dim_names=("global", "local"))

global_pg = device_mesh_cpu.get_group("global")
local_pg = device_mesh.get_group("local")
log(f"global pg world : {global_pg.size()}, local pg: {local_pg.size()}")

model = FSDP(
model,
sharding_strategy=sharding_strategy,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
use_orig_params=True,
process_group=local_pg,
)
if config.torch_compile:
model = torch.compile(model)

# Setup optimizers
inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95))

cpu_model = get_offloaded_param(
model
) # todo: in case of sharded grad op we need to offload the cpu model only once per node

outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)

# for param in outer_optimizer.param_groups[0]["params"]:
# log(param.device)

scheduler = get_cosine_schedule_with_warmup(
inner_optimizer,
num_warmup_steps=config.warmup_steps,
num_training_steps=config.total_steps,
)

model.train()

if rank == 0:
logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)

loss_batch = 0

train_dataloader_iterator = iter(train_dataloader)

outer_step = 0
while True:
if rank == 0:
log(f"outer_step step: {outer_step}")
# if "momentum_buffer" in outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]:
# momentum_buffer = outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]['momentum_buffer']
# log(f"momentum buffer device: {momentum_buffer.device}, shape: {momentum_buffer.shape}")
# else:
# log("no momentum buffer")
for inner_step in range(config.diloco.local_steps):
for grad_acc_step in range(gradient_accumulation_steps):
is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
batch = next(train_dataloader_iterator)

for key in batch.keys():
batch[key] = batch[key].to("cuda")
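# no_sync() below skips FSDP's gradient synchronization on all but the last accumulation micro-step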

with model.no_sync() if is_accumulating else nullcontext():
outputs = model(**batch)
loss = outputs.loss / gradient_accumulation_steps
loss.backward()
loss_batch += loss.detach()

model.clip_grad_norm_(1.0) # gradient clipping
inner_optimizer.step()
scheduler.step()
inner_optimizer.zero_grad()

if rank == 0:
real_step = outer_step * config.diloco.local_steps + inner_step + 1
inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]

metrics = {
"Loss": loss_batch.item(),
"step": real_step,
"inner_lr": inner_lr,
}

metric_logger.log(metrics)

log(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}")

loss_batch = 0

### the whole section below is just a PoC. We need to benchmark and optimize what is most efficient:
## do the all reduce on cpu or on gpu
## do the outer optimizer step on cpu or on gpu
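# DiLoCo outer step: the pseudo-gradient is the drift of the sharded parameters away from the
# CPU copy taken before the inner loop; it is stored in .grad so the outer SGD with Nesterov
# momentum can apply it to the offloaded copy.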

for param_offloaded, param in zip(cpu_model, model.parameters()):
# todo check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

if param_offloaded.grad.device == torch.device("cpu"):
# gloo does not support AVG
param_offloaded.grad = param_offloaded.grad / global_pg.size()
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
else:
dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)

outer_optimizer.step()
outer_optimizer.zero_grad()

# todo for the SHARD_GRAD_OP strategy we need to do one cpu -> gpu 0 copy and then
# gpu 0 -> gpu 1,2,... copies, as that would be faster
for param_offloaded, param in zip(cpu_model, model.parameters()):
param.data = param_offloaded.data.to("cuda")

outer_step += 1

if rank == 0:
metric_logger.finish()


if __name__ == "__main__":
# Allow eager fallback during production so that the training runs don't die
# However, in development, we want to know that we broke torch compile
torch._dynamo.config.suppress_errors = "PRIME_INTELLECT_DEV" not in os.environ
torch.set_float32_matmul_precision("high")
ddp_setup()
config = Config(**parse_argv())
train(config)
destroy_process_group()
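
For intuition, the snippet below sketches the outer update that the training loop above performs, using a single plain tensor and no FSDP or collective communication; all names in it are illustrative and not part of the diff.

import torch

# Editor's sketch of the DiLoCo outer update (illustrative only, not part of the diff).
outer_lr = 0.7

param = torch.randn(4)                       # live model parameter (on GPU in the PR)
offloaded = param.detach().clone()           # CPU copy taken before the inner loop
offloaded.requires_grad = True

outer_opt = torch.optim.SGD([offloaded], lr=outer_lr, momentum=0.9, nesterov=True)

# pretend the inner optimizer moved the parameter during the local steps
with torch.no_grad():
    param += 0.01 * torch.randn_like(param)

# pseudo-gradient: drift of the live parameter away from the offloaded copy
# (in the PR this tensor is all-reduced across the global group and divided by its size)
offloaded.grad = offloaded.data - param.data

outer_opt.step()
outer_opt.zero_grad()

# copy the outer-updated weights back into the live model
with torch.no_grad():
    param.copy_(offloaded)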