diff --git a/examples/sft/sft.py b/examples/sft/sft.py
index 2ced8591c..629b556c2 100644
--- a/examples/sft/sft.py
+++ b/examples/sft/sft.py
@@ -97,6 +97,8 @@ def train(config: SFTConfig):
     # Create supervised dataset using generic machinery
     logger.info("Creating supervised dataset")
     if config.dataset_type == DatasetType.CHAT_JSONL:
+        assert config.chat_train_urls is not None
+        assert config.supervised_data is not None
         chat_config = ChatSFTDatasetConfig(
             cache_dir=config.supervised_data.cache_dir,
             train_urls=config.chat_train_urls,  # No validation in this config
@@ -106,6 +108,7 @@ def train(config: SFTConfig):
         )
         train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
     else:
+        assert config.supervised_data is not None
         train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
     logger.info("Supervised dataset created")
     train_dataset = PermutationDataset(train_dataset, data_key)
@@ -122,7 +125,7 @@ def train(config: SFTConfig):
     # 1. Sets the device mesh
     # 2. Sets the axis mapping (for fsdp)
     # 3. Sets the global metrics tracker
-    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:
+    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:  # type: ignore
         parameter_axis_mapping = trainer.parameter_axis_mapping
 
         # We have two axis_mappings: one for storing the model and optimizer states, and one for compute
@@ -141,7 +144,7 @@ def train(config: SFTConfig):
             logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
             model: LmHeadModel = converter.load_pretrained(
                 model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
-            )
+            )  # type: ignore
             model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
             state = trainer.initial_state(training_key, model=model)
         else:
@@ -163,10 +166,14 @@ def train(config: SFTConfig):
                 next(loader)
 
         if config.hf_save_path is not None:
-            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            # bit gross to reach this far into the config, but it's fine
+            if config.trainer.checkpointer.append_run_id_to_base_path:
+                full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            else:
+                full_save_path = config.hf_save_path
 
             trainer.add_hook(
-                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload),
+                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
                 every=config.hf_save_steps,
             )
 
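A note on the changes above: the added "assert ... is not None" lines narrow Optional config fields for the type checker before they are dereferenced, and "config.hf_upload or False" guards against hf_upload being None. Below is a minimal standalone sketch of how the new save-path branch resolves; resolve_save_path is an illustrative helper, not part of the patch, and the POSIX-style paths are hypothetical:

import os


def resolve_save_path(hf_save_path: str, run_id: str, append_run_id_to_base_path: bool) -> str:
    # Mirrors the branch added in the patch: append the run id to the base
    # path only when the checkpointer is configured to do so.
    if append_run_id_to_base_path:
        return os.path.join(hf_save_path, run_id)
    return hf_save_path


# With the flag set, each run saves under its own subdirectory...
assert resolve_save_path("/checkpoints/sft", "run-abc123", True) == "/checkpoints/sft/run-abc123"
# ...and with it unset, checkpoints are written directly to the base path.
assert resolve_save_path("/checkpoints/sft", "run-abc123", False) == "/checkpoints/sft"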