Patches for remote sync waiting and log_local checkpointing #290

Open. Wants to merge 11 commits into main.
7 changes: 6 additions & 1 deletion open_lm/distributed.py
@@ -3,6 +3,7 @@
import logging
import torch
import torch.distributed as dist
import datetime


def is_global_master(args):
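
For context on the `log_local` changes in `open_lm/main.py` below: `is_master(args, local=...)` presumably dispatches between the local and global master checks defined in this file, along the lines of the usual open_clip pattern. A sketch, not the verbatim source:

```python
# Sketch (assumed, following the open_clip convention) of the helpers that
# is_master(args, local=...) builds on; only is_global_master is visible in
# the hunk above, the rest is inferred.
def is_global_master(args):
    return args.rank == 0          # rank 0 across the whole job

def is_local_master(args):
    return args.local_rank == 0    # rank 0 within its node

def is_master(args, local=False):
    # with local=True (e.g. local=args.log_local), every node's rank-0 process
    # counts as a master, so each node can write its own logs and checkpoints
    return is_local_master(args) if local else is_global_master(args)
```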
@@ -59,6 +60,9 @@ def init_distributed_device(args):
args.local_rank = 0
# For testing, allow forcing distributed mode to test distributed code path even on one gpu.
if is_using_distributed() or args.force_distributed:

timeout = datetime.timedelta(seconds=args.backend_timeout) if args.backend_timeout else None

if "SLURM_PROCID" in os.environ:
# DDP via SLURM
args.local_rank, args.rank, env_world_size = world_info_from_env()
@@ -79,13 +83,14 @@
init_method=args.dist_url,
world_size=args.world_size,
rank=args.rank,
timeout=timeout,
)
else:
# DDP via torchrun, torch.distributed.launch
# Note that this currently assumes that the world size is all gpus in a node.
assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported."
args.local_rank, _, _ = world_info_from_env()
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
args.world_size = torch.distributed.get_world_size()
args.rank = torch.distributed.get_rank()
args.distributed = True
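A minimal standalone sketch of what the new timeout wiring does. The backend, addresses, and the 5400-second value below are illustrative, not taken from the patch:

```python
# Standalone sketch: how a --backend-timeout value reaches init_process_group.
# A single-process gloo group keeps the example runnable on one machine.
import os
import datetime
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

backend_timeout = 5400  # hypothetical --backend-timeout value, in seconds
timeout = datetime.timedelta(seconds=backend_timeout) if backend_timeout else None

# The timeout bounds how long collectives (and, for some backends, rendezvous)
# may block before the process group errors out instead of hanging forever.
dist.init_process_group(backend="gloo", world_size=1, rank=0, timeout=timeout)
dist.destroy_process_group()
```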
1 change: 0 additions & 1 deletion open_lm/file_utils.py
@@ -76,7 +76,6 @@ def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol,
success = remote_sync(local_dir, remote_dir, protocol)
if success:
return True

return False


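The `sync_every=0` calls added in `open_lm/main.py` below make more sense given the shape of this helper. A rough sketch, under the assumption that it sleeps `sync_every * 2**attempt` before each try; the retry count and the `remote_sync` body are placeholders, not the verbatim code:

```python
# Rough sketch of the helper around the changed lines above; the retry count,
# sleep placement, and remote_sync body are assumptions.
import time

def remote_sync(local_dir, remote_dir, protocol):
    # stand-in for the real rsync / s3 sync call
    return True

def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol, max_retries=6):
    for attempt in range(max_retries):
        time.sleep(sync_every * 2**attempt)  # sync_every=0 means no wait at all
        if remote_sync(local_dir, remote_dir, protocol):
            return True
    return False
```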
16 changes: 9 additions & 7 deletions open_lm/main.py
@@ -211,7 +211,8 @@ def save_checkpoint(
):
cpu_state, optim_state = None, None
if args.logs and args.logs.lower() != "none" and args.fsdp:
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
rank0_only = not args.log_local
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer)
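
The effect of the `rank0_only` toggle above, as a small sketch (`gather_full_state` is a made-up helper name; open_lm does this inline in `save_checkpoint`):

```python
# Sketch: gathering a full (unsharded) state dict from an FSDP model.
# With rank0_only=True only rank 0 receives the full parameters; with
# --log_local (rank0_only=False) every rank gets a copy, so each local
# master can write its own checkpoint.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def gather_full_state(model: FSDP, log_local: bool):
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=not log_local)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        return model.state_dict()
```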
@@ -380,7 +381,7 @@ def main(args):
args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
args.failed_checkpoint_path = os.path.join(log_base_path, "checkpoints_failed")
if is_master(args):
if is_master(args, local=args.log_local):
args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ""
for dirname in [args.tensorboard_path, args.checkpoint_path, args.failed_checkpoint_path]:
if dirname:
@@ -424,9 +425,9 @@ def main(args):
# start the sync process if remote-sync is not None
remote_sync_process = None
if is_master(args) and args.remote_sync is not None:
# first make sure it works
# first make sure it works: here, remote_sync_frequency is set to 0 for this initial test
result = remote_sync_with_expon_backoff(
args.remote_sync_frequency,
0,
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol,
Expand Down Expand Up @@ -572,7 +573,7 @@ def main(args):
if args.resume is not None and averagers is not None:
load_avg_models(args, averagers)

if is_master(args):
if is_master(args, local=args.log_local):
logging.info(f"Model (has {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters):")
logging.info(f"{str(model)}")
logging.info("Params:")
@@ -717,7 +718,7 @@ def main(args):
raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const.")

# determine if this worker should save logs and checkpoints. only do so if it is rank == 0
args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args, local=args.log_local)
writer = None
if args.save_logs and args.tensorboard:
assert tensorboard is not None, "Please install tensorboard."
@@ -931,8 +932,9 @@ def main(args):
if remote_sync_process is not None:
logging.info("Final remote sync.")
terminate_sync_process(remote_sync_process)
# Can just pass in sync_every=0 for the last sync, otherwise it will unnecessarily sleep.
result = remote_sync_with_expon_backoff(
args.remote_sync_frequency,
0,
os.path.join(args.logs, args.name),
os.path.join(args.remote_sync, args.name),
args.remote_sync_protocol,
6 changes: 6 additions & 0 deletions open_lm/params.py
@@ -787,6 +787,12 @@ def parse_args(args):
default=0,
help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
)
parser.add_argument(
"--backend-timeout",
type=int,
default=None,
help="This the number of seconds passed into the timeout arg for torch.distributed.init_process_group.",
)

add_model_args(parser)

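A tiny standalone check of how the new flag parses (a hypothetical cut-down parser, not open_lm's real `parse_args`); note the dash-to-underscore mapping behind `args.backend_timeout` in `open_lm/distributed.py`:

```python
# Hypothetical minimal parser mirroring just the new flag.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend-timeout",
    type=int,
    default=None,  # None keeps torch.distributed's default timeout
    help="Number of seconds passed as the timeout arg of torch.distributed.init_process_group.",
)

args = parser.parse_args(["--backend-timeout", "5400"])  # e.g. 90 minutes
assert args.backend_timeout == 5400  # argparse maps --backend-timeout to args.backend_timeout
```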