Feat fsdp diloco #29
Open
samsja wants to merge 12 commits into main from feat-fsdp-diloco
Commits:
75ede2a  samsja  wip add pure torch fsdp
ee89d33  samsja  add something somewhat working
c217dd4  samsja  add wandb and real data
94bc0ae  samsja  fix data
5b701d7  samsja  do all reduce on cpu
04dd585  samsja  do all reduce on cpu
e926e6c  samsja  fix opt step on cpu
9f680f2  samsja  fix opt step on cpu
8576f1c  samsja  fix batch size stuff
71a9f89  samsja  add helper script
88b5250  samsja  fix it
5957cae  samsja  fix it
@@ -0,0 +1,67 @@
#!/bin/bash

#
# simulate multiple nodes on one machine: start N torchrun instances with X GPUs each.
# example how to run ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/debug.toml

# Function to get CUDA devices based on the number of GPUs and index
function get_cuda_devices() {
    local num_gpu=$1
    local index=$2
    local start_gpu=$((num_gpu * index))
    local end_gpu=$((start_gpu + num_gpu - 1))

    if [ "$num_gpu" -eq 1 ]; then
        echo $start_gpu
    else
        echo $(seq -s ',' $start_gpu $end_gpu)
    fi
}

# Array to store PIDs of child processes
child_pids=()

# Function to kill all child processes
cleanup() {
    echo "Cleaning up child processes..."
    local killed=0
    for pid in "${child_pids[@]}"; do
        if kill -TERM "$pid" 2>/dev/null; then
            ((killed++))
        fi
    done
    wait
    echo "All child processes terminated. Killed $killed processes."
    exit
}

# Check that at least three arguments were passed (N, num_gpu, and the script to run)
if [ "$#" -lt 3 ]; then
    echo "Usage: $0 <N> <num_gpu> <python_script> [additional_python_args]"
    exit 1
fi

N=$1        # Number of simulated nodes
NUM_GPU=$2  # GPUs per simulated node
shift 2     # Remove the first two arguments so $@ contains only the script and its arguments

# Register the cleanup function to be called on SIGINT (Ctrl+C)
trap cleanup SIGINT

mkdir -p logs

for i in $(seq 0 $(($N - 1 )))
do
    > logs/log$i
    CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N "$@" > logs/log$i 2>&1 &
    child_pids+=($!)
done

tail -f logs/log0 &
child_pids+=($!)

wait
@@ -0,0 +1,273 @@
import os
from contextlib import nullcontext
import datetime
from typing import Literal

import torch
import torch.distributed as dist
from pydantic_config import parse_argv, BaseConfig
from torch.distributed import destroy_process_group, init_process_group

from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    LlamaConfig,
    LlamaForCausalLM,
    get_cosine_schedule_with_warmup,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.distributed.device_mesh import init_device_mesh
from hivemind.optim.optimizer import logger
from open_diloco.utils import (
    FakeTokenizedDataset,
    get_sharding_strategy,
    WandbLogger,
    DummyLogger,
)
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

TIMEOUT_NCCL_MINUTES = int(os.environ.get("TIMEOUT_NCCL_MINUTES", 120))
TEST_VOCAB_SIZE = 1024


# Function to initialize the distributed process group
def ddp_setup():
    init_process_group(timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES))

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))


def log(message):
    logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}")


class DilocoConfig(BaseConfig):
    outer_lr: float = 0.7
    local_steps: int = 10


class Config(BaseConfig):
    diloco: DilocoConfig = DilocoConfig()
    path_model: str = "PrimeIntellect/llama-150m-fresh"
    torch_compile: bool = True
    attn_implementation: str = "flash_attention_2"
    # Data
    seq_length: int = 1024
    num_workers: int = 4
    # Optimization
    lr: float = 4e-4
    total_batch_size: int = 512
    per_device_train_batch_size: int = 32
    warmup_steps: int = 1000
    total_steps: int = 88_000
    sharding_strategy: str = "FULL_SHARD"
    project: str = "debug"
    metric_logger_type: Literal["wandb", "dummy"] = "wandb"
    fake_data: bool = False
    dataset_name_or_path: str = "allenai/c4"


def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader:
    if config.fake_data:
        train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE)
    else:
        ds = load_dataset(config.dataset_name_or_path, "en", streaming=True)

        def tokenize_function(data):
            outputs = tokenizer(
                data["text"],
                truncation=True,
                max_length=config.seq_length,
                padding="max_length",
            )
            return outputs

        tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[
            "train"
        ]

        train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    return StatefulDataLoader(
        train_dataset,
        collate_fn=data_collator,
        batch_size=config.per_device_train_batch_size,
        num_workers=config.num_workers,
    )


def get_model(config: Config) -> LlamaForCausalLM:
    # Load model
    config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation)
    return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)


def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]:
    offloaded_params = []
    for param in model.parameters():
        if param.requires_grad:
            offloaded_param = param.data.detach().clone().to("cpu")
            offloaded_param.requires_grad = True
            offloaded_params.append(offloaded_param)

    return offloaded_params


def train(config: Config):
    sharding_strategy = get_sharding_strategy(config.sharding_strategy)
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    rank = int(os.environ["RANK"])

    # batch_size is the total batch size for all GPUs on one node
    assert config.total_batch_size % local_world_size == 0
    batch_size = config.total_batch_size // local_world_size

    assert batch_size % config.per_device_train_batch_size == 0
    gradient_accumulation_steps = batch_size // config.per_device_train_batch_size

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True)
    tokenizer.pad_token = "</s>"  # Ensure pad token is set for models that need it

    train_dataloader = get_dataloader(tokenizer, world_size, rank, config)

    model = get_model(config)
    model = model.to(local_rank)

    nnodes = world_size // local_world_size

    # right now device mesh does not support two backends, so we create two identical meshes that differ only in backend
    device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
    device_mesh_cpu = init_device_mesh("gloo", (nnodes, local_world_size), mesh_dim_names=("global", "local"))

    global_pg = device_mesh_cpu.get_group("global")
    local_pg = device_mesh.get_group("local")
    log(f"global pg world : {global_pg.size()}, local pg: {local_pg.size()}")

    model = FSDP(
        model,
        sharding_strategy=sharding_strategy,
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16),
        use_orig_params=True,
        process_group=local_pg,
    )
    if config.torch_compile:
        model = torch.compile(model)

    # Setup optimizers
    inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95))

    cpu_model = get_offloaded_param(
        model
    )  # todo: in case of SHARD_GRAD_OP we only need to offload the cpu model once per node
    outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True)

    # for param in outer_optimizer.param_groups[0]["params"]:
    #     log(param.device)

    scheduler = get_cosine_schedule_with_warmup(
        inner_optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=config.total_steps,
    )

    model.train()

    if rank == 0:
        logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger
        metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False)

    loss_batch = 0

    train_dataloader_iterator = iter(train_dataloader)

    outer_step = 0
    while True:
        if rank == 0:
            log(f"outer_step step: {outer_step}")
            # if "momentum_buffer" in outer_optimizer.state[outer_optimizer.param_groups[0]["params"][0]]:
            #     momentum_buffer = outer_optimizer.state[outer_optimizer.param_groups[0]["params"][0]]["momentum_buffer"]
            #     log(f"momentum buffer device: {momentum_buffer.device}, shape: {momentum_buffer.shape}")
            # else:
            #     log("no momentum buffer")
        for inner_step in range(config.diloco.local_steps):
            for grad_acc_step in range(gradient_accumulation_steps):
                is_accumulating = grad_acc_step < gradient_accumulation_steps - 1
                batch = next(train_dataloader_iterator)

                for key in batch.keys():
                    batch[key] = batch[key].to("cuda")

                with model.no_sync() if is_accumulating else nullcontext():
                    outputs = model(**batch)
                    loss = outputs.loss / gradient_accumulation_steps
                    loss.backward()
                    loss_batch += loss.detach()

            model.clip_grad_norm_(1.0)  # gradient clipping
            inner_optimizer.step()
            scheduler.step()
            inner_optimizer.zero_grad()

            if rank == 0:
                real_step = outer_step * config.diloco.local_steps + inner_step + 1
                inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0]

                metrics = {
                    "Loss": loss_batch.item(),
                    "step": real_step,
                    "inner_lr": inner_lr,
                }

                metric_logger.log(metrics)

                log(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}")

            loss_batch = 0

        ### the whole section below is just a PoC. We need to benchmark and optimize what is most efficient:
        ## do the all reduce on cpu or on gpu
        ## do the outer optimizer step on cpu or on gpu

        for param_offloaded, param in zip(cpu_model, model.parameters()):
            # todo: check how to handle the SHARD_GRAD_OP strategy where the weights are replicated across the local devices
            param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device)

            if param_offloaded.grad.device == torch.device("cpu"):
                # gloo does not support AVG
                param_offloaded.grad = param_offloaded.grad / global_pg.size()
                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg)
            else:
                dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg)

        outer_optimizer.step()
        outer_optimizer.zero_grad()

        # todo: for the SHARD_GRAD_OP strategy we need to do one cpu -> gpu 0 copy and then a
        # gpu 0 -> gpu 1,2.. copy, as it would be faster
        for param_offloaded, param in zip(cpu_model, model.parameters()):
            param.data = param_offloaded.data.to("cuda")

        outer_step += 1

    if rank == 0:
        metric_logger.finish()


if __name__ == "__main__":
    # Allow eager fallback in production so that the training runs don't die.
    # However, in development, we want to know when we broke torch compile.
    torch._dynamo.config.suppress_errors = "PRIME_INTELLECT_DEV" not in os.environ
    torch.set_float32_matmul_precision("high")
    ddp_setup()
    config = Config(**parse_argv())
    train(config)
    destroy_process_group()
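The PoC comment near the end of the training loop leaves open whether the outer all-reduce and the outer optimizer step are faster on CPU (gloo) or on GPU (NCCL). Below is a minimal micro-benchmark sketch for the all-reduce half of that question; it is not part of this PR, and the function name, tensor size, iteration count, and process-group setup are illustrative assumptions.

```python
# Hypothetical micro-benchmark for the CPU-vs-GPU all-reduce question raised
# in the PoC comment; launch with torchrun (one process per GPU). Tensor size
# and iteration count are illustrative, not tuned.
import os
import time

import torch
import torch.distributed as dist


def bench_all_reduce(tensor: torch.Tensor, group, iters: int = 20) -> float:
    # One warm-up call, then average the wall-clock time of `iters` all-reduces.
    dist.all_reduce(tensor, group=group)
    if tensor.is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        dist.all_reduce(tensor, group=group)
    if tensor.is_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    # NCCL as the default backend for the GPU tensor, plus a gloo group for the CPU tensor.
    dist.init_process_group("nccl")
    gloo_pg = dist.new_group(backend="gloo")

    numel = 150_000_000  # roughly a 150M-parameter model in fp32
    cpu_tensor = torch.randn(numel)
    gpu_tensor = cpu_tensor.to("cuda")

    cpu_time = bench_all_reduce(cpu_tensor, gloo_pg)
    gpu_time = bench_all_reduce(gpu_tensor, dist.group.WORLD)
    if dist.get_rank() == 0:
        print(f"cpu/gloo all-reduce: {cpu_time:.4f}s per call")
        print(f"gpu/nccl all-reduce: {gpu_time:.4f}s per call")
    dist.destroy_process_group()
```

Timing `outer_optimizer.step()` on the offloaded CPU parameters versus a GPU copy of them would cover the other half of the question.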
I think we might not need this logger from hivemind?
we do, otherwise we don't have a nice logger
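For context on what a replacement might involve: a minimal, hypothetical stdlib-only stand-in for the hivemind logger import is sketched below (not part of this PR). It only reproduces the rank-prefixed info logging that `log()` relies on, without hivemind's nicer formatting, which is the trade-off the reply points at.

```python
# Hypothetical stdlib-only logger that could stand in for
# `from hivemind.optim.optimizer import logger`; it only provides
# rank-prefixed info logging, nothing hivemind-specific.
import logging
import os

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger("open_diloco")


def log(message: str) -> None:
    # Prefix every message with the local rank, like the log() helper in the training script.
    logger.info(f"[rank {os.environ.get('LOCAL_RANK', '0')}] {message}")
```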