Drop a Worker #32

Open · wants to merge 2 commits into main
open_diloco/train_fsdp.py (25 changes: 16 additions & 9 deletions)
@@ -91,6 +91,7 @@ class HvConfig(BaseConfig):
world_rank: int
galaxy_size: int
fail_rank_drop: bool = False # fail if we lose a diloco worker


@model_validator(mode="before")
def cast_str_to_list(cls, values: dict[str, Any]) -> dict[str, Any]:
@@ -179,7 +180,7 @@ def train(config: Config):
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["RANK"])

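# All diloco worker world ranks; the upper half of this list is dropped halfway through the run (see below).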
world_rank_list = list(range(config.hv.galaxy_size))
world_messenger_hv = config.hv is not None and local_rank == 0

# batch_size is the total batch size for all GPUs
@@ -357,7 +358,7 @@ def scheduler_fn(opt):
max_num_peers = 0

log_activations = {}

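# Log the simulated worker drop only once.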
log_drop = True
for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
real_step = (step + 1) // gradient_accumulation_steps
is_accumulating = bool((step + 1) % gradient_accumulation_steps)
@@ -449,12 +450,12 @@ def scheduler_fn(opt):
metrics.update(log_activations)
log_activations = {}

if world_messenger_hv and num_peers < max_num_peers:
    log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
    if config.hv.fail_rank_drop:
        raise ValueError(
            f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
        )
# if world_messenger_hv and num_peers < max_num_peers:
#     log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
#     if config.hv.fail_rank_drop:
#         raise ValueError(
#             f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}"
#         )

current_time = time.time()

@@ -510,7 +511,13 @@ def scheduler_fn(opt):

if config.max_steps is not None and real_step >= config.max_steps:
    break


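# Simulated failure: once real_step reaches half of total_steps, workers whose world_rank is in the upper half of world_rank_list stop training.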
if real_step >= int(config.total_steps)//2:
    if log_drop:
        log(f"Dropping worker world ranks {world_rank_list[config.hv.galaxy_size//2:]}")
        log_drop = False
    if config.hv is not None and config.hv.world_rank in world_rank_list[config.hv.galaxy_size//2:]:
        break
log("Training completed.")
if rank == 0:
    metric_logger.finish()
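
For reference, the heart of this change is the mid-run drop condition: once real_step reaches half of config.total_steps, every worker whose world_rank falls in the upper half of world_rank_list breaks out of the training loop. A minimal standalone sketch of that condition follows; galaxy_size, total_steps, and should_drop are illustrative stand-ins for the corresponding config fields and helper logic, not code from this PR.

# Minimal sketch of the drop condition introduced in this PR (illustrative values only).
galaxy_size = 4                                       # stands in for config.hv.galaxy_size
total_steps = 100                                     # stands in for config.total_steps
world_rank_list = list(range(galaxy_size))
dropped_ranks = world_rank_list[galaxy_size // 2:]    # upper half of ranks, here [2, 3]

def should_drop(world_rank: int, real_step: int) -> bool:
    # A worker leaves the run once training passes the halfway point
    # and its rank belongs to the dropped half.
    return real_step >= total_steps // 2 and world_rank in dropped_ranks

assert should_drop(world_rank=3, real_step=50)        # dropped worker, past halfway
assert not should_drop(world_rank=0, real_step=50)    # surviving worker keeps training
assert not should_drop(world_rank=3, real_step=49)    # before halfway, nobody drops

With the fail_rank_drop check commented out above, the surviving workers keep training after the drop instead of raising.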