[WIP] MLM Training Objective #680

Draft
wants to merge 7 commits into base: main
62 changes: 62 additions & 0 deletions src/levanter/data/text.py
@@ -14,6 +14,7 @@
import equinox as eqx
import fsspec
import jax
import jax.numpy as jnp
import numpy as np
import pyarrow as pa
import regex
@@ -64,6 +65,65 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

class MaskedLmDataset(ShardableDataset[LmExample]):
Member: fyi, we're gonna do a big refactor on datasets soon, but I'll either handle the refactor or guide you through it.

def __init__(
self,
dataset: ShardableDataset[np.ndarray],
QPos: Axis,
KPos: Axis,
mask_prob: float = 0.15,
key: Optional[PRNGKeyArray] = None,
ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX,
):
self.dataset = dataset
self.QPos = QPos
self.KPos = KPos
self.mask_prob = mask_prob
self.key = key
self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX

if self.mask_prob > 0.0 and self.key is None:
raise ValueError("must provide key if mask_prob > 0.0")

def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset":
return MaskedLmDataset(
self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.key, self.ignore_id
)

def __iter__(self) -> Iterator[LmExample]:
key = self.key
sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0])

with use_cpu_device():
@functools.partial(eqx.filter_jit, out_shardings=sharding)
def _create_mlm_example(tokens, key):
tokens_array = tokens.array

example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id)
Member: you need a non-causal attention mask for RoBERTa, and you need to set a loss_mask that covers only the masked tokens.
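A rough sketch of what that could look like inside _create_mlm_example, reusing the boolean mask from the bernoulli draw below; the names and the all-ones attention mask here are illustrative assumptions, not levanter's actual API:

    # loss is computed only at positions that were selected for masking
    loss_mask = hax.named(mask.astype(jnp.float32), self.QPos)
    # bidirectional (non-causal) attention: every query position may attend to every key position
    attn_mask = hax.ones((self.QPos, self.KPos))  # the mask type the model actually expects may differ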

Member: you also can't use the current LmExample, actually, because you need a separate targets field (holding the original, non-masked tokens). With more work you could avoid the need for targets (keeping just the masked tokens), but it's probably better to add a targets: Optional[NamedArray] field to the class (or make your own class).
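One possible shape for such a class, sketched as an equinox Module with haliax NamedArray fields; the field names and types are illustrative, not an existing levanter definition:

    class MlmExample(eqx.Module):
        tokens: hax.NamedArray     # corrupted input ids fed to the model
        targets: hax.NamedArray    # original ids, used as labels where loss_mask is set
        loss_mask: hax.NamedArray  # 1.0 at masked positions, 0.0 elsewhere
        attn_mask: hax.NamedArray  # full (bidirectional) attention; levanter's own mask type may fit better

_create_mlm_example could then return an MlmExample built from the original tokens, the corrupted tokens, and the bernoulli mask, instead of patching a causal LmExample.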


if self.mask_prob > 0:
# use independent keys for the position selection, the 80/10/10 choice, and the random tokens
mask_key, rand_key, random_key = jax.random.split(key, 3)
mask_shape = tokens_array.shape
mask = jax.random.bernoulli(mask_key, self.mask_prob, mask_shape)

# Of the selected positions: 80% become [MASK], 10% a random token, 10% keep the original token
rand = jax.random.uniform(rand_key, mask_shape)
random_tokens = jax.random.randint(random_key, mask_shape, 0, tokens_array.max() + 1)
replacement = jnp.where(rand < 0.8, self.ignore_id, tokens_array)  # TODO: this should be the tokenizer's [MASK] id, not ignore_id
replacement = jnp.where((rand >= 0.8) & (rand < 0.9), random_tokens, replacement)
# positions not selected for masking keep their original tokens
masked_tokens = jnp.where(mask, replacement, tokens_array)

masked_tokens_named = hax.named(masked_tokens, self.QPos)
example = dataclasses.replace(example, tokens=masked_tokens_named)

return example

for tokens in self.dataset:
# advance the key per example so each sequence gets a different mask pattern
key, example_key = jax.random.split(key) if key is not None else (None, None)
tokens_array = jnp.array(tokens)
tokens_named = hax.named(tokens_array, self.QPos)
example = _create_mlm_example(tokens_named, example_key)
yield example


class CausalLmDataset(ShardableDataset[LmExample]):
def __init__(
@@ -120,6 +180,8 @@ def _create_lm_example(tokens, key):
yield example




class TokenSeqDataset(ShardableDataset[np.ndarray]):
"""
A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache.
200 changes: 200 additions & 0 deletions src/levanter/main/train_mlm.py
@@ -0,0 +1,200 @@
# train_mlm.py

import dataclasses
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Optional, Union

import jax.random as jrandom

import haliax as hax
from haliax import Axis
from haliax.partitioning import named_jit, round_axis_for_partitioning

import levanter
from levanter import callbacks
from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig
from levanter.models.gpt2 import Gpt2Config
from levanter.models.llama import LlamaConfig
from levanter.models.lm_model import LmConfig
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.utils.jax_utils import parameter_count

logger = logging.getLogger(__name__)

@dataclass
class TrainMlmConfig:
data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig)
trainer: TrainerConfig = field(default_factory=TrainerConfig)
model: LmConfig = field(default_factory=LlamaConfig)
optimizer: OptimizerConfig = field(default_factory=AdamConfig)

# config related to continued pretraining
initialize_from_hf: Union[bool, str] = False
"""if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class"""
use_hf_model_config: bool = False # if true, replace the model config with the hf config from the checkpoint

# TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying
# TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint

mlm_prob: float = 0.15 # masking probability for MLM
hf_save_path: Optional[str] = None
hf_upload: Optional[str] = None
hf_save_steps: int = 10000

update_hessian_steps: int = 10
data_seed: Optional[int] = None # if provided, will override the data seed from the trainer

def main(config: TrainMlmConfig):
tokenizer = config.data.the_tokenizer

# this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through,
# I recommend skipping it for now
if config.initialize_from_hf:
if config.trainer.initialize_from is not None:
raise ValueError("Cannot specify both initialize_from_hf and initialize_from")

assert isinstance(config.model, HFCompatConfig)
converter = config.model.hf_checkpoint_converter()
if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab:
logger.warning("The tokenizers appear to be different. You may want to check this.")

if isinstance(config.initialize_from_hf, str):
converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer)
else:
converter = converter.replaced(tokenizer=tokenizer)

if config.use_hf_model_config:
# TODO: log diff of old and new config
# NB: gross mutability
config.model = converter.config_from_hf_config(converter.default_hf_config)
elif isinstance(config.model, HFCompatConfig):
converter = config.model.hf_checkpoint_converter()
converter = converter.replaced(tokenizer=tokenizer)
else:
converter = None

levanter.initialize(config)
optimizer = config.optimizer.build(config.trainer.num_train_steps)

# Using the trainer as a context manager does 3 things:
# 1. Sets the device mesh
# 2. Sets the axis mapping (for fsdp)
# 3. Sets the global metrics tracker
with Trainer(config.trainer, optimizer) as trainer:
# randomness in jax is tightly controlled by "keys" which are the states of the random number generators
# this makes deterministic training pretty easy
seed = config.trainer.seed
data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4)

if config.data_seed is not None:
logger.info(f"Overriding data seed with {config.data_seed}")
data_key = jrandom.PRNGKey(config.data_seed)

# We have two axis_mappings: one for storing the model and optimizer states, and one for compute
# This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh
compute_axis_mapping = trainer.compute_axis_mapping
parameter_axis_mapping = trainer.parameter_axis_mapping

# some axes we need
Batch = config.trainer.TrainBatch
EvalBatch = config.trainer.EvalBatch
Pos = config.model.Pos
KeyPos = config.model.KeyPos

tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size)
train_dataset = MaskedLmDataset(
config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id
)

# to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
# For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
# tokens: gpt-2 has 50257, for example. So we round up.
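# e.g. with the vocab axis sharded 128 ways, gpt-2's 50257 would round up to 50304 (393 * 128);
# the exact multiple depends on the mesh axis sizes in your config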
vocab_size = len(tokenizer)
Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping)
if vocab_size != Vocab.size:
logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key))

if int(state.step) == 0:
# TODO: I don't love that we init the model twice, but it's not a big deal i think?
if config.initialize_from_hf:
# initialize from an hf pretrained model
logger.info(
"No training checkpoint found. Initializing model from HF checkpoint"
f" '{converter.reference_checkpoint}'"
)
# this is a bit gross, but we want to free up the memory from the model we just built
state = dataclasses.replace(state, model=None)
gc.collect()
model = converter.load_pretrained(
config.model.model_type,
config.model,
axis_mapping=parameter_axis_mapping,
dtype=trainer.mp.compute_dtype,
)
model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model)
state = dataclasses.replace(state, model=model)
else:
logger.info("No checkpoint found. Starting from scratch.")

levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)})

if len(tagged_eval_datasets) == 0:
logger.warning("No evaluation datasets provided.")
else:
masked_datasets = [
(MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id), tags)
for ds, tags in tagged_eval_datasets
]
max_eval_examples_per_ds = config.trainer.max_eval_batches
if max_eval_examples_per_ds is not None:
max_eval_examples_per_ds *= config.trainer.eval_batch_size

cb = levanter.eval.cb_tagged_lm_evaluate(
EvalBatch, masked_datasets, trainer.device_mesh, compute_axis_mapping, max_eval_examples_per_ds
)
trainer.add_hook(cb, every=config.trainer.steps_per_eval)

flops_per_token = config.model.flops_per_token(vocab_size)
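# rough rule of thumb: the backward pass costs about twice the forward pass, hence ~3x flops per token per step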
flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None
trainer.add_hook(
callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1
)
if config.hf_save_path is not None:
full_save_path = os.path.join(config.hf_save_path, trainer.run_id)

trainer.add_hook(
save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
every=config.hf_save_steps,
)

# visualize log probs
@named_jit(
in_axis_resources=parameter_axis_mapping,
axis_resources=compute_axis_mapping,
out_axis_resources=compute_axis_mapping,
)
def compute_log_probs(model, example):
model = trainer.mp.cast_to_compute(model)
logprobs = model.compute_loss(example, key=None, reduction=None)
# roll forward to get the loss for each predicted token
logprobs = hax.roll(logprobs, 1, Pos)
return logprobs.rearrange((EvalBatch, Pos)).array

train_loader = iter(trainer.sharded_loader(train_dataset, Batch))

if int(state.step) > 0:
import tqdm
for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"):
next(train_loader)

trainer.train(state, train_loader)

if __name__ == "__main__":
levanter.config.main(main)()
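Assuming the same draccus-style entry point as train_lm.py, this would presumably be launched with something like python -m levanter.main.train_mlm --config_path path/to/your_mlm_config.yaml (the config path here is hypothetical), with the YAML file populating the TrainMlmConfig fields above.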