From 07aa037b57b2913138a6cf10c3e38027bb311b55 Mon Sep 17 00:00:00 2001 From: James Kunstle Date: Tue, 1 Oct 2024 00:30:37 +0000 Subject: [PATCH] adds Accelerate full-state (opt, lr_sched, params) saving and reloading for DeepSpeed and FSDP. Signed-off-by: James Kunstle --- src/instructlab/training/config.py | 3 +- src/instructlab/training/main_ds.py | 53 +++++++++---- src/instructlab/training/utils.py | 111 ++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 14 deletions(-) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 751976c6..256866db 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -169,7 +169,8 @@ class TrainingArgs(BaseModel): warmup_steps: int is_padding_free: bool random_seed: int = 42 - checkpoint_at_epoch: bool = False + checkpoint_at_epoch: bool = True + accelerate_full_state_at_epoch: bool = True mock_data: Optional[bool] = False mock_data_len: int = 0 diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 0f8b2507..bd99d80b 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -43,9 +43,11 @@ apply_gradient_checkpointing, convert_loss_to_reduce_sum, ensure_loadable_granite_checkpoint, + load_latest_full_state, prepare_peft_model, prepare_universal_checkpoint_from_latest, retrieve_chat_template, + save_checkpoint, save_hf_format_accelerate, set_random_seed, setup_logger, @@ -316,6 +318,10 @@ def train( batch_size = args.effective_batch_size // grad_accum samples_seen = 0 + if hasattr(args, "samples_seen"): + print(f"\033[93mUpdating 'samples_seen' {args.samples_seen}\033[0m") + samples_seen = args.samples_seen + if args.save_samples > 0: args.save_samples = (args.save_samples // batch_size) * batch_size ( @@ -335,7 +341,7 @@ def train( ) global_grad_norm = None - for epoch in range(args.num_epochs): + for epoch in range(args.current_epoch, args.num_epochs): if args.sampler in ("multipack"): train_loader.batch_sampler.set_epoch(epoch) elif args.sampler in ("distributed"): @@ -346,6 +352,7 @@ def train( if local_rank == 0: inner_pb = tqdm(range(len(train_loader)), desc=f"Epoch {epoch}") + # blast through the batches in the train loader up to the last step within the epoch. for batch in train_loader: if global_step <= args.last_step: # in the case of resuming, last_step > 0 @@ -437,13 +444,14 @@ def train( if args.save_samples > 0 and ( global_step * batch_size % args.save_samples == 0 ): - save_hf_format_accelerate( - args, - model, - tokenizer, - accelerator, - samples_seen, + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, is_lora=bool(args.lora_r), + hf_format=True, ) # if ( @@ -461,13 +469,16 @@ def train( inner_pb.update(1) torch.cuda.empty_cache() if args.checkpoint_at_epoch: - save_hf_format_accelerate( - args, - model, - tokenizer, - accelerator, - samples_seen, + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, is_lora=bool(args.lora_r), + full_state=args.accelerate_full_state_at_epoch, + hf_format=True, + epoch=epoch, ) if args.save_last: @@ -588,6 +599,8 @@ def main(args): args, tokenizer, train_loader, grad_accum ) + load_latest_full_state(args=args, accelerator=accelerator) + train( args, model, @@ -661,6 +674,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.checkpoint_at_epoch: command.append("--checkpoint_at_epoch") + if train_args.accelerate_full_state_at_epoch: + command.append("--accelerate_full_state_at_epoch") + if train_args.mock_data: command.append("--mock_data") if train_args.mock_len: @@ -775,6 +791,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: parser.add_argument("--data_path", type=str) parser.add_argument("--output_dir", type=str) parser.add_argument("--num_epochs", type=int, default=1) + parser.add_argument( + "--current_epoch", + type=int, + default=0, + help="Helpful flag for resuming on a later epoch. Sets dataloader correctly.", + ) parser.add_argument( "--last_step", type=int, @@ -820,6 +842,11 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: action="store_true", help="Save a model checkpoint after finishing an epoch.", ) + parser.add_argument( + "--accelerate_full_state_at_epoch", + action="store_true", + help="Save full model state using Accelerate after finishing an epoch.", + ) parser.add_argument("--log_level", type=str, default="INFO") parser.add_argument("--seed", type=int, default=42) parser.add_argument("--mock_data", action="store_true") diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index ec051441..9b1a6c8f 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -781,3 +781,114 @@ def set_random_seed(seed): np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + + +def save_checkpoint( + args, + accelerator: Accelerator, + model, + tokenizer, + samples_seen, + is_lora: bool, + epoch: int = None, + hf_format: bool = True, + full_state: bool = False, +) -> None: + if hf_format: + save_hf_format_accelerate( + args=args, + model=model, + accelerator=accelerator, + tokenizer=tokenizer, + samples_seen=samples_seen, + is_lora=is_lora, + ) + + if full_state: + save_full_state( + args=args, + accelerator=accelerator, + is_lora=is_lora, + epoch=epoch, + samples_seen=samples_seen, + ) + + +def save_full_state(args, accelerator, is_lora: bool, epoch: int, samples_seen: int): + """ + Saves model, optimizer, and lr_scheduler state. + TODO: save model config - decided not to do this. + TODO: save tokenizer - decided not to do this. + TODO: handle LoRA + TODO: handle granite + """ + if is_lora: + raise NotImplementedError("Can't save full state for LoRA at the moment.") + + # if args.is_granite: + # raise NotImplementedError("Can't save full state for Granite models yet.") + + output_dir = Path(args.output_dir) / "full_state" / f"epoch_{epoch}" + log_rank_0(f"\033[93mSaving full model state in {output_dir}\033[0m", to_print=True) + + # patch FSDP state dict method so it works correctly. + def _get_state_dict_patched(model, unwrap=False): + return get_state_dict_unpatched(model, unwrap=unwrap) + + if args.distributed_training_framework == "fsdp": + get_state_dict_unpatched = accelerator.get_state_dict + accelerator.get_state_dict = _get_state_dict_patched + + accelerator.save_state( + output_dir=output_dir, + # max_shard_size="5GB", + # safe_serialization=True, + ) + + # save metadata file for current training status + if accelerator.is_main_process: + # TODO: should we set the global_step here rather than calculating global_step + # based on samples_seen? + metadata = {"current_epoch": epoch, "samples_seen": samples_seen} + torch.save(metadata, output_dir / "training_metadata.json") + log_rank_0(f"\033[93mSaving training state: {metadata}\033[0m", to_print=True) + + log_rank_0(f"\033[93mModel state saved in: {output_dir}\033[0m", to_print=True) + + # cleanup + if args.distributed_training_framework == "fsdp": + accelerator.get_state_dict = get_state_dict_unpatched + + +def load_latest_full_state(args, accelerator) -> None: + """ + Loads accelerator state from most recently saved checkpoint + in `output_dir/full_state`. + """ + output_dir = Path(args.output_dir) / "full_state" + + if not output_dir.is_dir(): + return + + # picks checkpoint with the largest number of samples seen, by name. + checkpoint_list = sorted(list(output_dir.iterdir()), reverse=True) + + if len(checkpoint_list) == 0: + log_rank_0( + f"\033[93mNo checkpoints to load from: {output_dir}\033[0m", to_print=True + ) + return + + latest = checkpoint_list[0] + + log_rank_0(f"\033[93mLoading state from: {latest}\033[0m", to_print=True) + accelerator.load_state(latest) + + training_metadata = torch.load(latest / "training_metadata.json") + log_rank_0( + f"\033[93mTraining metadata loaded: {training_metadata}\033[0m", to_print=True + ) + + # previous epoch is basis for current epoch. + args.__dict__["current_epoch"] = training_metadata["current_epoch"] + 1 + args.__dict__["samples_seen"] = training_metadata["samples_seen"]