diff --git a/open_lm/distributed.py b/open_lm/distributed.py
index 8c07d663..f49e564f 100644
--- a/open_lm/distributed.py
+++ b/open_lm/distributed.py
@@ -3,6 +3,7 @@ import logging
 
 import torch
 import torch.distributed as dist
+import datetime
 
 
 def is_global_master(args):
@@ -59,6 +60,9 @@ def init_distributed_device(args):
     args.local_rank = 0
     # For testing, allow forcing distributed mode to test distributed code path even on one gpu.
     if is_using_distributed() or args.force_distributed:
+
+        timeout = datetime.timedelta(seconds=args.backend_timeout) if args.backend_timeout else None
+
         if "SLURM_PROCID" in os.environ:
             # DDP via SLURM
             args.local_rank, args.rank, env_world_size = world_info_from_env()
@@ -79,13 +83,14 @@ def init_distributed_device(args):
                 init_method=args.dist_url,
                 world_size=args.world_size,
                 rank=args.rank,
+                timeout=timeout,
             )
         else:
             # DDP via torchrun, torch.distributed.launch
             # Note that this currently assumes that the world size is all gpus in a node.
             assert args.preset_world_size is None, "--preset_world_size with torchrun is not currently supported."
             args.local_rank, _, _ = world_info_from_env()
-            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url)
+            torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, timeout=timeout)
         args.world_size = torch.distributed.get_world_size()
         args.rank = torch.distributed.get_rank()
         args.distributed = True
diff --git a/open_lm/file_utils.py b/open_lm/file_utils.py
index f91919b2..12b6ea88 100644
--- a/open_lm/file_utils.py
+++ b/open_lm/file_utils.py
@@ -76,7 +76,6 @@ def remote_sync_with_expon_backoff(sync_every, local_dir, remote_dir, protocol,
         success = remote_sync(local_dir, remote_dir, protocol)
         if success:
             return True
-
     return False
 
 
diff --git a/open_lm/main.py b/open_lm/main.py
index 7c80f558..a3f77ce6 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -211,7 +211,8 @@ def save_checkpoint(
 ):
     cpu_state, optim_state = None, None
     if args.logs and args.logs.lower() != "none" and args.fsdp:
-        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        rank0_only = not args.log_local
+        save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=rank0_only)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
             cpu_state = model.state_dict()
             optim_state = FSDP.optim_state_dict(model, optimizer)
@@ -380,7 +381,7 @@ def main(args):
     args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to
     args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
     args.failed_checkpoint_path = os.path.join(log_base_path, "checkpoints_failed")
-    if is_master(args):
+    if is_master(args, local=args.log_local):
         args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ""
         for dirname in [args.tensorboard_path, args.checkpoint_path, args.failed_checkpoint_path]:
             if dirname:
@@ -424,9 +425,9 @@ def main(args):
     # start the sync proces if remote-sync is not None
     remote_sync_process = None
     if is_master(args) and args.remote_sync is not None:
-        # first make sure it works
+        # first make sure it works: sync_every is set to 0 for this initial test
         result = remote_sync_with_expon_backoff(
-            args.remote_sync_frequency,
+            0,
             os.path.join(args.logs, args.name),
             os.path.join(args.remote_sync, args.name),
             args.remote_sync_protocol,
@@ -572,7 +573,7 @@ def main(args):
     if args.resume is not None and averagers is not None:
         load_avg_models(args, averagers)
 
-    if is_master(args):
+    if is_master(args, local=args.log_local):
         logging.info(f"Model (has {sum(p.numel() for p in model.parameters() if p.requires_grad)} parameters):")
         logging.info(f"{str(model)}")
         logging.info("Params:")
@@ -717,7 +718,7 @@ def main(args):
             raise ValueError(f"Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const.")
 
     # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
-    args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args)
+    args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args, local=args.log_local)
     writer = None
     if args.save_logs and args.tensorboard:
         assert tensorboard is not None, "Please install tensorboard."
@@ -931,8 +932,9 @@ def main(args):
     if remote_sync_process is not None:
         logging.info("Final remote sync.")
         terminate_sync_process(remote_sync_process)
+        # Can just pass in sync_every=0 for the last sync, otherwise it will unnecessarily sleep.
         result = remote_sync_with_expon_backoff(
-            args.remote_sync_frequency,
+            0,
             os.path.join(args.logs, args.name),
             os.path.join(args.remote_sync, args.name),
             args.remote_sync_protocol,
diff --git a/open_lm/params.py b/open_lm/params.py
index 0a7a3f64..4c66c8ee 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -787,6 +787,12 @@ def parse_args(args):
         default=0,
         help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
     )
+    parser.add_argument(
+        "--backend-timeout",
+        type=int,
+        default=None,
+        help="This is the number of seconds passed into the timeout arg for torch.distributed.init_process_group.",
+    )
 
     add_model_args(parser)
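
A minimal standalone sketch (not part of the patch) of how the new --backend-timeout flag is consumed, following the distributed.py hunk above: the integer number of seconds is wrapped in a datetime.timedelta and handed to torch.distributed.init_process_group. The argparse setup and the single-process gloo/TCP rendezvous below are illustrative assumptions, not open_lm code.

import argparse
import datetime

import torch.distributed as dist

parser = argparse.ArgumentParser()
# Mirrors the new open_lm flag: seconds before the process group's collectives/rendezvous time out.
parser.add_argument("--backend-timeout", type=int, default=None)
args = parser.parse_args(["--backend-timeout", "7200"])  # e.g. allow up to 2 hours

# Same pattern as the patch: None keeps PyTorch's default timeout,
# otherwise build a timedelta from the requested number of seconds.
timeout = datetime.timedelta(seconds=args.backend_timeout) if args.backend_timeout else None

# Single-process group purely for illustration; real runs derive rank/world_size
# from SLURM or torchrun environment variables as init_distributed_device does.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29500",
    world_size=1,
    rank=0,
    timeout=timeout if timeout is not None else datetime.timedelta(minutes=30),
)
dist.destroy_process_group()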