Commit: type checkers
dlwh committed Nov 7, 2024
1 parent 5aa7e23 commit 009eb28
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions examples/sft/sft.py
@@ -97,6 +97,8 @@ def train(config: SFTConfig):
     # Create supervised dataset using generic machinery
     logger.info("Creating supervised dataset")
     if config.dataset_type == DatasetType.CHAT_JSONL:
+        assert config.chat_train_urls is not None
+        assert config.supervised_data is not None
         chat_config = ChatSFTDatasetConfig(
             cache_dir=config.supervised_data.cache_dir,
             train_urls=config.chat_train_urls,  # No validation in this config
@@ -106,6 +108,7 @@ def train(config: SFTConfig):
         )
         train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
     else:
+        assert config.supervised_data is not None
         train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
     logger.info("Supervised dataset created")
     train_dataset = PermutationDataset(train_dataset, data_key)
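
The two hunks above exist purely for static type checkers: fields like chat_train_urls and supervised_data are evidently Optional on SFTConfig, and an `assert x is not None` lets mypy/pyright narrow them to their non-None type before they are dereferenced. A minimal, self-contained sketch of the pattern (the config class and field below are invented for illustration):

from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoConfig:
    # Optional field, so its static type is "str | None".
    cache_dir: Optional[str] = None


def build_cache_path(config: DemoConfig) -> str:
    # Without this assert, checkers flag the attribute access below:
    # 'Item "None" of "Optional[str]" has no attribute "rstrip"'.
    assert config.cache_dir is not None
    # After the assert, the checker narrows config.cache_dir to plain str.
    return config.cache_dir.rstrip("/")
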
@@ -122,7 +125,7 @@ def train(config: SFTConfig):
     # 1. Sets the device mesh
     # 2. Sets the axis mapping (for fsdp)
     # 3. Sets the global metrics tracker
-    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:
+    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:  # type: ignore
         parameter_axis_mapping = trainer.parameter_axis_mapping

         # We have two axis_mappings: one for storing the model and optimizer states, and one for compute
@@ -141,7 +144,7 @@ def train(config: SFTConfig):
             logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
             model: LmHeadModel = converter.load_pretrained(
                 model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
-            )
+            )  # type: ignore
             model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
             state = trainer.initial_state(training_key, model=model)
         else:
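
The `# type: ignore` comments in these two hunks take the other route: rather than adjusting the annotations on Trainer or on converter.load_pretrained, the offending lines are simply excluded from checking. As a reminder of how the mechanism behaves (the environment-variable names below are arbitrary), a bare ignore silences every error on its line, while mypy also accepts a bracketed error code that silences only that diagnostic:

import os

# os.environ.get returns "str | None", so annotating the target as plain str
# makes strict checkers report an incompatible assignment on this line.
run_id: str = os.environ.get("RUN_ID")  # type: ignore

# Same situation, but suppressing only mypy's "assignment" error code.
run_dir: str = os.environ.get("RUN_DIR")  # type: ignore[assignment]
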
@@ -163,10 +166,14 @@ def train(config: SFTConfig):
                 next(loader)

         if config.hf_save_path is not None:
-            full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            # bit gross to reach this far into the config, but it's fine
+            if config.trainer.checkpointer.append_run_id_to_base_path:
+                full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
+            else:
+                full_save_path = config.hf_save_path

             trainer.add_hook(
-                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload),
+                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
                 every=config.hf_save_steps,
             )

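The final hunk changes two things: the HF export path now honors the checkpointer's append_run_id_to_base_path flag rather than always nesting under the run id, and a possibly-None hf_upload is coerced to a concrete bool for the callback. A standalone sketch of the resulting path logic, with made-up base path and run id values:

import os


def resolve_save_path(base_path: str, run_id: str, append_run_id: bool) -> str:
    # Mirrors the branch added above: only nest under the run id when asked to.
    return os.path.join(base_path, run_id) if append_run_id else base_path


print(resolve_save_path("gs://my-bucket/sft", "run-42", append_run_id=True))   # gs://my-bucket/sft/run-42
print(resolve_save_path("gs://my-bucket/sft", "run-42", append_run_id=False))  # gs://my-bucket/sft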
