From 801192e8fe3715c3a35b76000c12a06032e166b8 Mon Sep 17 00:00:00 2001 From: haileyschoelkopf Date: Fri, 15 Sep 2023 19:45:24 +0000 Subject: [PATCH] workaround for FSX quota --- megatron/checkpointing.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py index 2c08f46c2..74a0ec491 100644 --- a/megatron/checkpointing.py +++ b/megatron/checkpointing.py @@ -327,19 +327,19 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler): raise ValueError("Must be using deepspeed to use neox") torch.distributed.barrier() - upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None - if upload_to_s3: - upload_checkpoint(iteration, neox_args) - # Wait so everyone is done (necessary) - torch.distributed.barrier() if neox_args.keep_last_n_checkpoints is not None: delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints) # Wait so everyone is done (not necessary) torch.distributed.barrier() + upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None + if upload_to_s3: + upload_checkpoint(iteration, neox_args) - + # Wait so everyone is done (necessary) + torch.distributed.barrier() + def load_checkpoint( neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None ):