Revise SFT File #793

Draft · wants to merge 1 commit into main
204 changes: 204 additions & 0 deletions src/levanter/main/sft.py
@@ -0,0 +1,204 @@
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional

import jax.random as jrandom
import transformers

import haliax as hax
from haliax import Axis
from haliax.partitioning import round_axis_for_partitioning

import levanter
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig, save_hf_checkpoint_callback
from levanter.data import PermutationDataset
from levanter.data.text import ChatSFTDatasetConfig, EpochDataset, mk_chat_sft_dataset, mk_supervised_dataset
from levanter.main.train_lm import TrainLmConfig
from levanter.models.lm_model import LmHeadModel, compute_next_token_loss
from levanter.trainer import Trainer


logger = logging.getLogger(__name__)

# Define default special tokens
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"


class DatasetType(str, Enum):
    """Type of dataset to use"""

    HUGGINGFACE = "huggingface"  # Use HF dataset
    CHAT_JSONL = "chat_jsonl"  # Use JSONL files with chat format


@dataclass
class SFTConfig(TrainLmConfig):
    # inherit most of the config from TrainLmConfig
    max_tune_length: int = 2048
    model_name_or_path: str = "meta-llama/Llama-2-7b-hf"
    tokenizer: str = "meta-llama/Llama-2-7b-hf"

    # Add dataset type and chat-specific fields
    dataset_type: DatasetType = DatasetType.HUGGINGFACE
    chat_train_urls: Optional[List[str]] = None
    messages_field: str = "messages"
    input_role: str = "user"
    output_role: str = "assistant"
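    # Illustrative (hypothetical values): with Levanter's draccus-style YAML config loading,
    # the chat-specific fields above would typically be set from a config file, e.g.
    #   dataset_type: chat_jsonl
    #   chat_train_urls: ["gs://my-bucket/sft/*.jsonl"]   # hypothetical bucket path
    #   messages_field: messages
    #   input_role: user
    #   output_role: assistant
    # alongside the usual TrainLmConfig sections (trainer, model, optimizer, supervised_data).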


def train(config: SFTConfig):
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        config.tokenizer,
        model_max_length=config.max_tune_length,
        padding_side="right",
        trust_remote_code=True,
    )
    logger.info(f"Loaded tokenizer {tokenizer}")

    if config.initialize_from_hf:
        if config.trainer.initialize_from is not None:
            raise ValueError("Cannot use both --initialize_from_hf and --initialize_from")

        assert isinstance(config.model, HFCompatConfig)

        converter = HFCheckpointConverter.from_hf(config.model_name_or_path, trust_remote_code=True)
        if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
            logger.warning("The tokenizers appear to be different. You may want to check this.")
        if isinstance(config.initialize_from_hf, str):
            converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
        else:
            converter = converter.replaced(tokenizer=tokenizer)

        model_config = converter.default_config
    elif config.trainer.initialize_from is None:
        raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
    else:
        converter = None

    levanter.initialize(config)

    num_new_tokens = add_special_tokens(tokenizer)
    logger.info(f"Added {num_new_tokens} new tokens")
    # randomness in jax is tightly controlled by "keys" which are the states of the random number generators
    # this makes deterministic training pretty easy
    seed = config.trainer.seed
    data_key, _, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4)

    if config.data_seed is not None:
        logger.info(f"Overriding data seed with {config.data_seed}")
        data_key = jrandom.PRNGKey(config.data_seed)
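    # (note) splitting a PRNGKey is deterministic, so the same trainer seed reproduces both the
    # model init and the data order; data_seed only swaps the key used for data shuffling.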

    # Create supervised dataset using generic machinery
    logger.info("Creating supervised dataset")
    if config.dataset_type == DatasetType.CHAT_JSONL:
        assert config.chat_train_urls is not None
        assert config.supervised_data is not None
        chat_config = ChatSFTDatasetConfig(
            cache_dir=config.supervised_data.cache_dir,
            train_urls=config.chat_train_urls,  # No validation in this config
            messages_field=config.messages_field,
            input_role=config.input_role,
            output_role=config.output_role,
        )
        train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
    else:
        assert config.supervised_data is not None
        train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
    logger.info("Supervised dataset created")
    train_dataset = PermutationDataset(train_dataset, data_key)

    # Then wrap for epochs
    if config.epoch > 0:
        logger.info(f"Wrapping dataset for {config.epoch} epochs")
        train_dataset = EpochDataset(train_dataset, max_epochs=config.epoch)

    logger.info("Creating optimizer")
    optimizer = config.optimizer.build(config.trainer.num_train_steps)

    # Using the trainer as a context manager does 3 things:
    # 1. Sets the device mesh
    # 2. Sets the axis mapping (for fsdp)
    # 3. Sets the global metrics tracker
    with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:  # type: ignore
        # We have two axis_mappings: one for storing the model and optimizer states, and one for compute
        # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh
        parameter_axis_mapping = trainer.parameter_axis_mapping

        # some axes we need
        Pos = config.model.Pos

        # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
        # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
        # tokens: gpt-2 has 50257, for example. So we round up.
        vocab_size = len(tokenizer)
        Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
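        # e.g. (illustrative): a vocab of 50257 with the vocab axis mapped onto a mesh axis of
        # size 8 would be rounded up to 50264 so the embedding shards evenly.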
        if config.initialize_from_hf:
            logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
            model: LmHeadModel = converter.load_pretrained(
                model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
            )  # type: ignore
            # resize the embedding table in case add_special_tokens grew the tokenizer's vocabulary
            model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
            state = trainer.initial_state(training_key, model=model)
        else:
            if vocab_size != Vocab.size:
                logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")
            state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

        flops_per_token = config.model.flops_per_token(vocab_size)
        flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
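        # (note) the 3x multiplier above is the usual forward + backward estimate: the backward pass
        # costs roughly twice the forward pass, so training FLOPs per token ~= 3x forward FLOPs.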
        trainer.add_hook(
            callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
        )

        loader = trainer.data_loader(train_dataset, trainer.TrainBatch)

        if int(state.step) != 0:
            logger.info(f"Resuming training from step {state.step}")
            # when resuming, fast-forward the loader so the data stream lines up with the checkpointed step
            for _ in range(state.step):
                next(loader)

        if config.hf_save_path is not None:
            # bit gross to reach this far into the config, but it's fine
            if config.trainer.checkpointer.append_run_id_to_base_path:
                full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
            else:
                full_save_path = config.hf_save_path

            trainer.add_hook(
                save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
                every=config.hf_save_steps,
            )

        trainer.train(state, loader)


def add_special_tokens(tokenizer, use_unk_instead_of_adding=False):
    special_tokens_dict = dict()
    if use_unk_instead_of_adding:
        if tokenizer.unk_token is None:
            raise ValueError("use_unk_instead_of_adding is True but tokenizer doesn't have an unk token")

    unk = tokenizer.unk_token if use_unk_instead_of_adding else None

    if tokenizer.pad_token is None:
        special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN if not use_unk_instead_of_adding else unk
    if tokenizer.eos_token is None:
        special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN if not use_unk_instead_of_adding else unk
    if tokenizer.bos_token is None:
        special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN if not use_unk_instead_of_adding else unk
    if tokenizer.unk_token is None:
        special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

    return tokenizer.add_special_tokens(special_tokens_dict)
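# Illustrative example (hypothetical outcome): the stock Llama-2 tokenizer ships with bos/eos/unk
# but no pad token, so add_special_tokens would register "[PAD]" and return 1; train() then calls
# resize_vocab so the embedding table has a row for the new id.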


if __name__ == "__main__":
    levanter.config.main(train)()
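For reference, a minimal sketch of a chat_jsonl training record, assuming the default messages_field/input_role/output_role above and a launch style mirroring other Levanter entry points (file names, paths, and values are hypothetical):

    # make_sft_example.py: write one chat-format record of the kind the chat_jsonl path expects
    import json

    record = {
        "messages": [
            {"role": "user", "content": "What does supervised fine-tuning do?"},
            {"role": "assistant", "content": "It trains a pretrained LM on prompt/response pairs."},
        ]
    }

    with open("sft_train.jsonl", "w") as f:
        f.write(json.dumps(record) + "\n")

    # assumed launch, mirroring other Levanter mains:
    #   python -m levanter.main.sft --config_path config/sft.yaml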