From 07441b66e1c38d9618b150cad543d8e42df82dfd Mon Sep 17 00:00:00 2001
From: Fares Obeid
Date: Tue, 24 Sep 2024 16:43:08 +0000
Subject: [PATCH 1/2] Drop a worker

---
 open_diloco/train_fsdp.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 4d5ef3e..0cf51fe 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
     world_rank: int
     galaxy_size: int
     fail_rank_drop: bool = False  # fail if we lose a diloco worker
+
     @model_validator(mode="before")
     def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -449,12 +450,12 @@ def scheduler_fn(opt):
                 metrics.update(log_activations)
                 log_activations = {}
 
-            if world_messenger_hv and num_peers < max_num_peers:
-                log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
-                if config.hv.fail_rank_drop:
-                    raise ValueError(
-                        f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
-                    )
+            # if world_messenger_hv and num_peers < max_num_peers:
+            #     log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
+            #     if config.hv.fail_rank_drop:
+            #         raise ValueError(
+            #             f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
+            #         )
 
             current_time = time.time()
@@ -510,7 +511,9 @@ def scheduler_fn(opt):
             if config.max_steps is not None and real_step >= config.max_steps:
                 break
-
+            if real_step >= 50:
+                if config.hv is not None and config.hv.world_rank == 1:
+                    break
     log("Training completed.")
     if rank == 0:
         metric_logger.finish()

From 04ff6113e6c2e7573c5dca039b98cd7e867424e8 Mon Sep 17 00:00:00 2001
From: Fares Obeid
Date: Thu, 26 Sep 2024 22:01:15 +0000
Subject: [PATCH 2/2] Drop half of the workers

---
 open_diloco/train_fsdp.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/open_diloco/train_fsdp.py b/open_diloco/train_fsdp.py
index 0cf51fe..310ee32 100644
--- a/open_diloco/train_fsdp.py
+++ b/open_diloco/train_fsdp.py
@@ -180,7 +180,7 @@ def train(config: Config):
     local_rank = int(os.environ["LOCAL_RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["RANK"])
-
+    world_rank_list = list(range(config.hv.galaxy_size)) if config.hv is not None else []
     world_messenger_hv = config.hv is not None and local_rank == 0
 
     # batch_size is the total batch size for all GPUs
@@ -358,7 +358,7 @@ def scheduler_fn(opt):
     max_num_peers = 0
 
     log_activations = {}
-
+    log_drop = True
     for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
         real_step = (step + 1) // gradient_accumulation_steps
         is_accumulating = bool((step + 1) % gradient_accumulation_steps)
@@ -511,8 +511,12 @@ def scheduler_fn(opt):
             if config.max_steps is not None and real_step >= config.max_steps:
                 break
-            if real_step >= 50:
-                if config.hv is not None and config.hv.world_rank == 1:
+
+            if config.hv is not None and real_step >= int(config.total_steps) // 2:
+                if log_drop:
+                    log(f"Dropping worker world ranks {world_rank_list[config.hv.galaxy_size // 2:]}")
+                    log_drop = False
+                if config.hv.world_rank in world_rank_list[config.hv.galaxy_size // 2:]:
                     break
     log("Training completed.")
     if rank == 0:
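
Illustrative note (not part of the patch series): the drop schedule that
[PATCH 2/2] adds to the training loop reduces to the standalone sketch
below. `should_drop` is a hypothetical helper name written for this note;
its parameters mirror the config fields used above (total_steps,
galaxy_size, world_rank).

def should_drop(world_rank: int, real_step: int, total_steps: int, galaxy_size: int) -> bool:
    # The upper half of the world ranks, [galaxy_size // 2, galaxy_size),
    # exits the training loop once half of the total steps is reached.
    dropped_ranks = list(range(galaxy_size))[galaxy_size // 2:]
    return real_step >= total_steps // 2 and world_rank in dropped_ranks

# Example: with galaxy_size=4 and total_steps=100, ranks 2 and 3 drop at step 50.
assert should_drop(world_rank=2, real_step=50, total_steps=100, galaxy_size=4)
assert not should_drop(world_rank=1, real_step=50, total_steps=100, galaxy_size=4)
assert not should_drop(world_rank=2, real_step=49, total_steps=100, galaxy_size=4)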