From 75ede2a11dd34fc68340ed9757cc84d28844cee9 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 10 Sep 2024 13:48:04 +0000 Subject: [PATCH 01/12] wip add pure torch fsdp --- open_diloco/train_pure_fsdp.py | 183 +++++++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 open_diloco/train_pure_fsdp.py diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py new file mode 100644 index 0000000..b328f2c --- /dev/null +++ b/open_diloco/train_pure_fsdp.py @@ -0,0 +1,183 @@ +import os +from contextlib import nullcontext +import datetime + +import torch +import torch.distributed as dist +from pydantic_config import parse_argv, BaseConfig +from torch.distributed import destroy_process_group, init_process_group + +from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import ( + AutoTokenizer, + DataCollatorForLanguageModeling, + LlamaConfig, + LlamaForCausalLM, + get_cosine_schedule_with_warmup, +) +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + MixedPrecision, +) +from torch.distributed.device_mesh import init_device_mesh +from hivemind.optim.optimizer import logger + +from open_diloco.utils import ( + FakeTokenizedDataset, + get_sharding_strategy, +) + +TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120) +TEST_VOCAB_SIZE = 1024 + + +# Function to initialize the distributed process group +def ddp_setup(): + init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES)) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + + +def log(message): + logger.info(f"[rank {os.environ['LOCAL_RANK']}] {message}") + + +class DilocoConfig(BaseConfig): + outer_lr: float = 0.7 + local_steps: int = 10 + + +class Config(BaseConfig): + diloco: DilocoConfig = DilocoConfig() + path_model: str = "PrimeIntellect/llama-150m-fresh" + torch_compile: bool = True + attn_implementation: str = "flash_attention_2" + # Data + seq_length: int = 1024 + num_workers: int = 4 + # Optimization + lr: float = 4e-4 + total_batch_size: int = 512 + per_device_train_batch_size: int = 32 + warmup_steps: int = 1000 + total_steps: int = 88_000 + sharding_strategy: str = "SHARD_GRAD_OP" + + +def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader: + train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE) + + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) + + return StatefulDataLoader( + train_dataset, + collate_fn=data_collator, + batch_size=config.per_device_train_batch_size, + num_workers=config.num_workers, + ) + + +def get_model(config: Config) -> LlamaForCausalLM: + # Load model + config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation) + return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) + + +def train(config: Config): + sharding_strategy = get_sharding_strategy(config.sharding_strategy) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + + # batch_size is the total batch size for all GPUs + assert config.total_batch_size % world_size == 0 + batch_size = config.total_batch_size // world_size + + assert batch_size % config.per_device_train_batch_size == 0 + gradient_accumulation_steps = batch_size // config.per_device_train_batch_size + + tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", 
use_fast=True) + tokenizer.pad_token = "" # Ensure pad token is set for models that need it + + train_dataloader = get_dataloader(tokenizer, world_size, rank, local_rank, config) + + model = get_model(config) + model = model.to(local_rank) + + local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + nnodes = world_size // local_world_size + device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local")) + + global_pg = device_mesh.get_group("global") + local_pg = device_mesh.get_group("local") + log(f"global pg world : {global_pg.size()}, local pg: {local_pg.size()}") + + model = FSDP( + model, + sharding_strategy=sharding_strategy, + mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), + use_orig_params=config.torch_compile, + process_group=local_pg, + ) + if config.torch_compile: + model = torch.compile(model) + + # Setup optimizers + inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95)) + # outer_optimizer = torch.optim.SGD(model.parameters(), lr=config.diloco.outer_lr, momentum=0.9, nesterov=True) + + scheduler = get_cosine_schedule_with_warmup( + inner_optimizer, + num_warmup_steps=config.warmup_steps, + num_training_steps=config.total_steps, + ) + + model.train() + + loss_batch = 0 + + train_dataloader_iterator = iter(train_dataloader) + + outer_step = 0 + while True: + if rank == 0: + log(f"outer_step step: {outer_step}") + for inner_step in range(config.diloco.local_steps): + for grad_acc_step in range(gradient_accumulation_steps): + is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 + batch = next(train_dataloader_iterator) + + for key in batch.keys(): + batch[key] = batch[key].to("cuda") + + with model.no_sync() if is_accumulating else nullcontext(): + outputs = model(**batch) + loss = outputs.loss / gradient_accumulation_steps + loss_batch += loss.detach() + + model.clip_grad_norm_(1.0) # gradient clipping + inner_optimizer.step() + scheduler.step() + inner_optimizer.zero_grad() + + if rank == 0: + log( + f"step: {outer_step} inner: {inner_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in inner_optimizer.param_groups][0]}" + ) + + loss_batch = 0 + + for param in model.parameters(): # todo make this like hybrid shard is doing + dist.all_reduce(param, op=dist.ReduceOp.AVG, group=global_pg) + + outer_step += 1 + + +if __name__ == "__main__": + # Allow eager fallback during production so that that the training runs dont die + # However, in development, we want to know that we broke torch compile + torch._dynamo.config.suppress_errors = "PRIME_INTELLECT_DEV" not in os.environ + torch.set_float32_matmul_precision("high") + ddp_setup() + config = Config(**parse_argv()) + train(config) + destroy_process_group() From ee89d3353fd1fbaa7c168ac7aff43c5793c16064 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 10 Sep 2024 14:49:18 +0000 Subject: [PATCH 02/12] add something somewhat working --- open_diloco/train_pure_fsdp.py | 31 +++++++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index b328f2c..43df77b 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -60,7 +60,7 @@ class Config(BaseConfig): per_device_train_batch_size: int = 32 warmup_steps: int = 1000 total_steps: int = 88_000 - sharding_strategy: str = "SHARD_GRAD_OP" + sharding_strategy: str = "FULL_SHARD" def get_dataloader(tokenizer, world_size, rank, local_rank, 
config: Config) -> StatefulDataLoader: @@ -82,6 +82,10 @@ def get_model(config: Config) -> LlamaForCausalLM: return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model) +def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]: + return [param.data.detach().clone().to("cuda") for param in model.parameters()] + + def train(config: Config): sharding_strategy = get_sharding_strategy(config.sharding_strategy) local_rank = int(os.environ["LOCAL_RANK"]) @@ -123,7 +127,11 @@ def train(config: Config): # Setup optimizers inner_optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr, weight_decay=0.1, betas=(0.9, 0.95)) - # outer_optimizer = torch.optim.SGD(model.parameters(), lr=config.diloco.outer_lr, momentum=0.9, nesterov=True) + + cpu_model = get_offloaded_param( + model + ) # todo: in case of sharded grap op we need to offload the cpu model only once per nodes + outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True) scheduler = get_cosine_schedule_with_warmup( inner_optimizer, @@ -152,6 +160,7 @@ def train(config: Config): with model.no_sync() if is_accumulating else nullcontext(): outputs = model(**batch) loss = outputs.loss / gradient_accumulation_steps + loss.backward() loss_batch += loss.detach() model.clip_grad_norm_(1.0) # gradient clipping @@ -166,8 +175,22 @@ def train(config: Config): loss_batch = 0 - for param in model.parameters(): # todo make this like hybrid shard is doing - dist.all_reduce(param, op=dist.ReduceOp.AVG, group=global_pg) + ### the whole sectione below is just a PoC. We need to benchmark and optimizer what is the most efficient: + ## do the all reduce on cpu or on gpu + ## do the outer optimizer step on cpu or on gpu + + for param_offloaded, param in zip( + cpu_model, model.parameters() + ): # There is only one big fat tensor in the param because of fsdp 1 bucket stuff + # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices + param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg) + + outer_optimizer.step() + outer_optimizer.zero_grad() + + for param_offloaded, param in zip(cpu_model, model.parameters()): + param.data = param_offloaded.data.to("cuda") outer_step += 1 From c217dd4b1c9042468f841e91aff3f65392a2edc2 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 11 Sep 2024 09:18:57 +0000 Subject: [PATCH 03/12] add wandb and real data --- open_diloco/train_pure_fsdp.py | 53 ++++++++++++++++++++++++++++++---- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index 43df77b..ccba643 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -1,6 +1,7 @@ import os from contextlib import nullcontext import datetime +from typing import Literal import torch import torch.distributed as dist @@ -21,11 +22,14 @@ ) from torch.distributed.device_mesh import init_device_mesh from hivemind.optim.optimizer import logger - from open_diloco.utils import ( FakeTokenizedDataset, get_sharding_strategy, + WandbLogger, + DummyLogger, ) +from datasets import load_dataset +from datasets.distributed import split_dataset_by_node TIMEOUT_NCCL_MINUTES = os.environ.get("TIMEOUT_NCCL_MINUTES", 120) TEST_VOCAB_SIZE = 1024 @@ -61,10 +65,31 @@ class Config(BaseConfig): warmup_steps: int = 1000 total_steps: int = 88_000 
sharding_strategy: str = "FULL_SHARD" + project: str = "debug" + metric_logger_type: Literal["wandb", "dummy"] = "wandb" + fake_data: bool = False + + +def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader: + if config.fake_data: + train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE) + else: + ds = load_dataset(config.dataset_name_or_path, "en", streaming=True) + def tokenize_function(data): + outputs = tokenizer( + data["text"], + truncation=True, + max_length=config.seq_length, + padding="max_length", + ) + return outputs -def get_dataloader(tokenizer, world_size, rank, local_rank, config: Config) -> StatefulDataLoader: - train_dataset = FakeTokenizedDataset(config.seq_length, TEST_VOCAB_SIZE) + tokenized_datasets = ds.map(tokenize_function, batched=True, remove_columns=["text", "timestamp", "url"])[ + "train" + ] + + train_dataset = split_dataset_by_node(tokenized_datasets, world_size=world_size, rank=rank) data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) @@ -141,6 +166,10 @@ def train(config: Config): model.train() + if rank == 0: + logger_cls = WandbLogger if config.metric_logger_type == "wandb" else DummyLogger + metric_logger = logger_cls(project=config.project, config=config.model_dump(), resume=False) + loss_batch = 0 train_dataloader_iterator = iter(train_dataloader) @@ -169,9 +198,18 @@ def train(config: Config): inner_optimizer.zero_grad() if rank == 0: - log( - f"step: {outer_step} inner: {inner_step}, loss: {loss_batch.item()}, lr {[group['lr'] for group in inner_optimizer.param_groups][0]}" - ) + real_step = outer_step * config.diloco.local_steps + inner_step + 1 + inner_lr = [group["lr"] for group in inner_optimizer.param_groups][0] + + metrics = { + "Loss": loss_batch.item(), + "step": real_step, + "inner_lr": inner_lr, + } + + metric_logger.log(metrics) + + log(f"step: {real_step}, loss: {loss_batch.item()}, inner_lr: {inner_lr}") loss_batch = 0 @@ -194,6 +232,9 @@ def train(config: Config): outer_step += 1 + if rank == 0: + metric_logger.finish() + if __name__ == "__main__": # Allow eager fallback during production so that that the training runs dont die From 94bc0aee00aa322a87c5429ab90abc1c0b0546a1 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 11 Sep 2024 09:39:24 +0000 Subject: [PATCH 04/12] fix data --- open_diloco/train_pure_fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index ccba643..c1cff51 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -68,6 +68,7 @@ class Config(BaseConfig): project: str = "debug" metric_logger_type: Literal["wandb", "dummy"] = "wandb" fake_data: bool = False + dataset_name_or_path: str = "allenai/c4" def get_dataloader(tokenizer, world_size, rank, config: Config) -> StatefulDataLoader: @@ -127,7 +128,7 @@ def train(config: Config): tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", use_fast=True) tokenizer.pad_token = "" # Ensure pad token is set for models that need it - train_dataloader = get_dataloader(tokenizer, world_size, rank, local_rank, config) + train_dataloader = get_dataloader(tokenizer, world_size, rank, config) model = get_model(config) model = model.to(local_rank) From 5b701d7182d3e4d70d85a1327088ecba963885cb Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 11 Sep 2024 11:44:25 +0000 Subject: [PATCH 05/12] do all reduce on cpu --- open_diloco/train_pure_fsdp.py | 34 
++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index c1cff51..02ada3e 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -37,7 +37,8 @@ # Function to initialize the distributed process group def ddp_setup(): - init_process_group(backend="nccl", timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES)) + init_process_group(timeout=datetime.timedelta(minutes=TIMEOUT_NCCL_MINUTES)) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) @@ -109,7 +110,7 @@ def get_model(config: Config) -> LlamaForCausalLM: def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]: - return [param.data.detach().clone().to("cuda") for param in model.parameters()] + return [param.data.detach().clone().to("cpu") for param in model.parameters()] def train(config: Config): @@ -135,9 +136,12 @@ def train(config: Config): local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) nnodes = world_size // local_world_size + + # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local")) + device_mesh_cpu = init_device_mesh("gloo", (nnodes, local_world_size), mesh_dim_names=("global", "local")) - global_pg = device_mesh.get_group("global") + global_pg = device_mesh_cpu.get_group("global") local_pg = device_mesh.get_group("local") log(f"global pg world : {global_pg.size()}, local pg: {local_pg.size()}") @@ -145,7 +149,7 @@ def train(config: Config): model, sharding_strategy=sharding_strategy, mixed_precision=MixedPrecision(param_dtype=torch.bfloat16), - use_orig_params=config.torch_compile, + use_orig_params=True, process_group=local_pg, ) if config.torch_compile: @@ -159,6 +163,9 @@ def train(config: Config): ) # todo: in case of sharded grap op we need to offload the cpu model only once per nodes outer_optimizer = torch.optim.SGD(cpu_model, lr=config.diloco.outer_lr, momentum=0.9, nesterov=True) + # for param in outer_optimizer.param_groups[0]["params"]: + # log(param.device) + scheduler = get_cosine_schedule_with_warmup( inner_optimizer, num_warmup_steps=config.warmup_steps, @@ -179,6 +186,11 @@ def train(config: Config): while True: if rank == 0: log(f"outer_step step: {outer_step}") + # if "momentum_buffer" in outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]: + # momentum_buffer = outer_optimizer.state[outer_optimizer.param_groups[0]['params'][0]]['momentum_buffer'] + # log(f"momentum buffer device: {momentum_buffer.device}, shape: {momentum_buffer.shape}") + # else: + # log("no momentum buffer") for inner_step in range(config.diloco.local_steps): for grad_acc_step in range(gradient_accumulation_steps): is_accumulating = grad_acc_step < gradient_accumulation_steps - 1 @@ -218,16 +230,22 @@ def train(config: Config): ## do the all reduce on cpu or on gpu ## do the outer optimizer step on cpu or on gpu - for param_offloaded, param in zip( - cpu_model, model.parameters() - ): # There is only one big fat tensor in the param because of fsdp 1 bucket stuff + for param_offloaded, param in zip(cpu_model, model.parameters()): # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) - dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg) + + if 
param_offloaded.grad.device != torch.device("cpu"): + # gloo does not support AVG + param_offloaded.grad = param_offloaded.grad / global_pg.size() + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg) + else: + dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg) outer_optimizer.step() outer_optimizer.zero_grad() + # todo for the SHARD_GRAD_OP strategy we need to do one cpu -> gpu 0 copy and then do + # gpu 0 -> gpu 1,2.. copy as it would be faster for param_offloaded, param in zip(cpu_model, model.parameters()): param.data = param_offloaded.data.to("cuda") From 04dd5859826b4f1665ee6e265109eb8c7ef39c0f Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 11 Sep 2024 15:02:46 +0000 Subject: [PATCH 06/12] do all reduce on cpu --- open_diloco/train_pure_fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index 02ada3e..42f3757 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -234,7 +234,7 @@ def train(config: Config): # todo check how to handle the SHARD_GRAD_OP strategy where the weight are replicated across the local devices param_offloaded.grad = param_offloaded.data - param.data.to(param_offloaded.device) - if param_offloaded.grad.device != torch.device("cpu"): + if param_offloaded.grad.device == torch.device("cpu"): # gloo does not support AVG param_offloaded.grad = param_offloaded.grad / global_pg.size() dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.SUM, group=global_pg) From e926e6c37c79b623fdcba2efceb93c8d94c9db0c Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 11 Sep 2024 15:27:05 +0000 Subject: [PATCH 07/12] fix opt step on cpu --- open_diloco/train_pure_fsdp.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index 42f3757..24e33ae 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -110,7 +110,14 @@ def get_model(config: Config) -> LlamaForCausalLM: def get_offloaded_param(model: LlamaForCausalLM) -> list[torch.Tensor]: - return [param.data.detach().clone().to("cpu") for param in model.parameters()] + offloaded_params = [] + for param in model.parameters(): + if param.requires_grad: + offloaded_param = param.data.detach().clone().to("cpu") + offloaded_param.requires_grad = True + offloaded_params.append(offloaded_param) + + return offloaded_params def train(config: Config): @@ -241,6 +248,9 @@ def train(config: Config): else: dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg) + for param in outer_optimizer.param_groups[0]["params"]: + print(param.requires_grad) + outer_optimizer.step() outer_optimizer.zero_grad() From 9f680f2792d1f4eefc49c24dcc81c73625cbdbb2 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Wed, 11 Sep 2024 15:31:28 +0000 Subject: [PATCH 08/12] fix opt step on cpu --- open_diloco/train_pure_fsdp.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py index 24e33ae..ebf30eb 100644 --- a/open_diloco/train_pure_fsdp.py +++ b/open_diloco/train_pure_fsdp.py @@ -248,9 +248,6 @@ def train(config: Config): else: dist.all_reduce(param_offloaded.grad, op=dist.ReduceOp.AVG, group=global_pg) - for param in outer_optimizer.param_groups[0]["params"]: - print(param.requires_grad) - outer_optimizer.step() outer_optimizer.zero_grad() From 8576f1c84b4655188fef7fa9d704aaa5db0d510b Mon 
Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 11 Sep 2024 15:40:47 +0000
Subject: [PATCH 09/12] fix batch size stuff

---
 open_diloco/train_pure_fsdp.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/open_diloco/train_pure_fsdp.py b/open_diloco/train_pure_fsdp.py
index ebf30eb..a9a6afc 100644
--- a/open_diloco/train_pure_fsdp.py
+++ b/open_diloco/train_pure_fsdp.py
@@ -124,11 +124,12 @@ def train(config: Config):
     sharding_strategy = get_sharding_strategy(config.sharding_strategy)
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
+    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
     rank = int(os.environ["RANK"])
 
     # batch_size is the total batch size for all GPUs
-    assert config.total_batch_size % world_size == 0
-    batch_size = config.total_batch_size // world_size
+    assert config.total_batch_size % local_world_size == 0
+    batch_size = config.total_batch_size // local_world_size
 
     assert batch_size % config.per_device_train_batch_size == 0
     gradient_accumulation_steps = batch_size // config.per_device_train_batch_size
@@ -141,7 +142,6 @@ def train(config: Config):
     model = get_model(config)
     model = model.to(local_rank)
 
-    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
     nnodes = world_size // local_world_size
 
     # right now device mesh does not support two backend so we just create two identicaly mesh expect the backend

From 71a9f8978859c263c1d7b5dcacfda172729c66df Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Tue, 24 Sep 2024 01:31:32 +0000
Subject: [PATCH 10/12] add helper script

---
 open_diloco/simulate_multi_node.sh | 67 ++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 open_diloco/simulate_multi_node.sh

diff --git a/open_diloco/simulate_multi_node.sh b/open_diloco/simulate_multi_node.sh
new file mode 100644
index 0000000..c5def4a
--- /dev/null
+++ b/open_diloco/simulate_multi_node.sh
@@ -0,0 +1,67 @@
+#!/bin/bash
+
+#
+# Simulate multiple nodes on one machine: starts N torchrun processes with NUM_GPU GPUs each.
+# example how to run ./scripts/simulate_multi_node.sh 2 1 src/zeroband/train.py @configs/debug/debug.toml
+
+# Function to get CUDA devices based on the number of GPUs and index
+function get_cuda_devices() {
+    local num_gpu=$1
+    local index=$2
+    local start_gpu=$((num_gpu * index))
+    local end_gpu=$((start_gpu + num_gpu - 1))
+
+    if [ "$num_gpu" -eq 1 ]; then
+        echo $start_gpu
+    else
+        echo $(seq -s ',' $start_gpu $end_gpu)
+    fi
+}
+
+# Array to store PIDs of child processes
+child_pids=()
+
+# Function to kill all child processes
+cleanup() {
+    echo "Cleaning up child processes..."
+    local killed=0
+    for pid in "${child_pids[@]}"; do
+        if kill -TERM "$pid" 2>/dev/null; then
+            ((killed++))
+        fi
+    done
+    wait
+    echo "All child processes terminated. Killed $killed processes."
+    exit
+}
+
+# Check if at least three arguments were passed
+if [ "$#" -lt 3 ]; then
+    echo "Usage: $0 <N> <NUM_GPU> <python_script> [additional_python_args]"
+    exit 1
+fi
+
+
+N=$1 # Set N from the first argument
+NUM_GPU=$2
+shift 2 # Remove the first two arguments so $@ contains the Python script and its arguments
+
+# Register the cleanup function to be called on SIGINT (Ctrl+C)
+trap cleanup SIGINT
+
+
+mkdir -p logs
+
+
+
+for i in $(seq 0 $(($N - 1 )))
+do
+    > logs/log$i
+    CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 &
+    child_pids+=($!)
done
+ +wait From 88b5250e2a8c7665ad08ba06704bd49151652f44 Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 24 Sep 2024 02:27:56 +0000 Subject: [PATCH 11/12] fix it --- open_diloco/simulate_multi_node.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_diloco/simulate_multi_node.sh b/open_diloco/simulate_multi_node.sh index c5def4a..fde4efa 100644 --- a/open_diloco/simulate_multi_node.sh +++ b/open_diloco/simulate_multi_node.sh @@ -57,7 +57,7 @@ mkdir -p logs for i in $(seq 0 $(($N - 1 ))) do > logs/log$i - CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) uv run torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & + CUDA_VISIBLE_DEVICES=$(get_cuda_devices $NUM_GPU $i) torchrun --nproc_per_node=$NUM_GPU --node-rank $i --rdzv-endpoint localhost:9999 --nnodes=$N $@ > logs/log$i 2>&1 & child_pids+=($!) done From 5957caedd5e57afa15b45fe009061c8fce6de54b Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Tue, 24 Sep 2024 04:32:18 +0000 Subject: [PATCH 12/12] fix it --- open_diloco/simulate_multi_node.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 open_diloco/simulate_multi_node.sh diff --git a/open_diloco/simulate_multi_node.sh b/open_diloco/simulate_multi_node.sh old mode 100644 new mode 100755
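For reference, here is a condensed, self-contained sketch of the outer (DiLoCo) step that patches 02-08 converge on: keep a CPU copy of the weights, treat (cpu_weights - current_weights) as a pseudo-gradient, average it across nodes over the gloo process group (dividing manually because gloo has no ReduceOp.AVG), apply Nesterov SGD to the CPU copy, and push the result back into the training copy. The function names, the plain nn.Linear standing in for the FSDP-wrapped model, and the single-process gloo setup below are illustrative assumptions, not code from the patches.

# diloco_outer_step_sketch.py -- illustrative sketch only, not part of the patch series
import torch
import torch.distributed as dist


def get_offloaded_params(params: list[torch.nn.Parameter]) -> list[torch.Tensor]:
    # CPU copies that the outer optimizer updates (cf. get_offloaded_param in the patches)
    offloaded = []
    for p in params:
        if p.requires_grad:
            t = p.data.detach().clone().to("cpu")
            t.requires_grad = True
            offloaded.append(t)
    return offloaded


def diloco_outer_step(params, cpu_params, outer_optimizer, global_pg) -> None:
    for p_off, p in zip(cpu_params, params):
        # pseudo-gradient: how far this worker drifted from the last globally synced weights
        p_off.grad = p_off.data - p.data.to(p_off.device)
        # gloo has no ReduceOp.AVG, so divide locally and all-reduce with SUM
        p_off.grad /= dist.get_world_size(group=global_pg)
        dist.all_reduce(p_off.grad, op=dist.ReduceOp.SUM, group=global_pg)
    outer_optimizer.step()
    outer_optimizer.zero_grad()
    # copy the outer update back into the copy used for inner training
    for p_off, p in zip(cpu_params, params):
        p.data.copy_(p_off.data.to(p.device))


if __name__ == "__main__":
    # single-process gloo group, only so the sketch runs stand-alone
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1)
    model = torch.nn.Linear(8, 8)
    params = list(model.parameters())
    cpu_params = get_offloaded_params(params)
    outer_opt = torch.optim.SGD(cpu_params, lr=0.7, momentum=0.9, nesterov=True)

    # stand-in for a few inner AdamW steps drifting the local weights
    with torch.no_grad():
        for p in params:
            p.add_(0.01 * torch.randn_like(p))

    diloco_outer_step(params, cpu_params, outer_opt, dist.group.WORLD)
    dist.destroy_process_group()

In the real trainer this loop runs over the FSDP flat parameters sharded across the local ("intra-node") group, while global_pg is the cross-node gloo group built from the second device mesh. A local dry run of the actual script via the patch-10 helper would presumably look like ./open_diloco/simulate_multi_node.sh 2 1 open_diloco/train_pure_fsdp.py (two simulated nodes, one GPU each); that invocation is inferred from the script's argument handling rather than shown anywhere in the patches.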