From 8f7402ef2cc63848bd7a8d81ed419c2a91a5451e Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 30 Jul 2024 11:33:46 -0700 Subject: [PATCH 1/7] Implements dynamic masking objective --- src/levanter/main/train_mlm.py | 200 +++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 src/levanter/main/train_mlm.py diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py new file mode 100644 index 000000000..208131310 --- /dev/null +++ b/src/levanter/main/train_mlm.py @@ -0,0 +1,200 @@ +# train_mlm.py + +import dataclasses +import gc +import logging +import os +from dataclasses import dataclass, field +from typing import Optional, Union + +import jax.random as jrandom + +import haliax as hax +from haliax import Axis +from haliax.partitioning import named_jit, round_axis_for_partitioning + +import levanter +from levanter import callbacks +from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback +from levanter.data.text import MaskedLmDataset, LMDatasetConfig, LMMixtureDatasetConfig +from levanter.models.gpt2 import Gpt2Config +from levanter.models.llama import LlamaConfig +from levanter.models.lm_model import LmConfig +from levanter.optim import AdamConfig, OptimizerConfig +from levanter.trainer import Trainer, TrainerConfig +from levanter.utils.jax_utils import parameter_count + +logger = logging.getLogger(__name__) + +@dataclass +class TrainMlmConfig: + data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) + trainer: TrainerConfig = field(default_factory=TrainerConfig) + model: LmConfig = field(default_factory=LlamaConfig) + optimizer: OptimizerConfig = field(default_factory=AdamConfig) + + # config related to continued pretraining + initialize_from_hf: Union[bool, str] = False + """if provided, this will override the model config in the config. if true, use the default hf checkpoint for this model class""" + use_hf_model_config: bool = False # if true, replace the model config with the hf config from the checkpoint + + # TODO: atm we don't support loading from a checkpoint that has a different tokenizer. this is a bit annoying + # TODO: atm you have to at least specify a levanter model config with the same type as the hf checkpoint + + mlm_prob: float = 0.15 # masking probability for MLM + hf_save_path: Optional[str] = None + hf_upload: Optional[str] = None + hf_save_steps: int = 10000 + + update_hessian_steps: int = 10 + data_seed: Optional[int] = None # if provided, will override the data seed from the trainer + +def main(config: TrainMlmConfig): + tokenizer = config.data.the_tokenizer + + # this is some unpleasant code to allow us to initialize from a hf checkpoint. If this is your first read through, + # I recommend skipping it for now + if config.initialize_from_hf: + if config.trainer.initialize_from is not None: + raise ValueError("Cannot specify both initialize_from_hf and initialize_from") + + assert isinstance(config.model, HFCompatConfig) + converter = config.model.hf_checkpoint_converter() + if hasattr(tokenizer, "vocab") and tokenizer.vocab != converter.tokenizer.vocab: + logger.warning("The tokenizers appear to be different. 
You may want to check this.") + + if isinstance(config.initialize_from_hf, str): + converter = converter.replaced(reference_checkpoint=config.initialize_from_hf, tokenizer=tokenizer) + else: + converter = converter.replaced(tokenizer=tokenizer) + + if config.use_hf_model_config: + # TODO: log diff of old and new config + # NB: gross mutability + config.model = converter.config_from_hf_config(converter.default_hf_config) + elif isinstance(config.model, HFCompatConfig): + converter = config.model.hf_checkpoint_converter() + converter = converter.replaced(tokenizer=tokenizer) + else: + converter = None + + levanter.initialize(config) + optimizer = config.optimizer.build(config.trainer.num_train_steps) + + # Using the trainer as a context manager does 3 things: + # 1. Sets the device mesh + # 2. Sets the axis mapping (for fsdp) + # 3. Sets the global metrics tracker + with Trainer(config.trainer, optimizer) as trainer: + # randomness in jax is tightly controlled by "keys" which are the states of the random number generators + # this makes deterministic training pretty easy + seed = config.trainer.seed + data_key, loader_key, model_key, training_key = jrandom.split(jrandom.PRNGKey(seed), 4) + + if config.data_seed is not None: + logger.info(f"Overriding data seed with {config.data_seed}") + data_key = jrandom.PRNGKey(config.data_seed) + + # We have two axis_mappings: one for storing the model and optimizer states, and one for compute + # This allows Zero-3-style parameter sharding, where we shard the parameters and optimizer state across the mesh + compute_axis_mapping = trainer.compute_axis_mapping + parameter_axis_mapping = trainer.parameter_axis_mapping + + # some axes we need + Batch = config.trainer.TrainBatch + EvalBatch = config.trainer.EvalBatch + Pos = config.model.Pos + KeyPos = config.model.KeyPos + + tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + train_dataset = MaskedLmDataset( + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id + ) + + # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to + # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of + # tokens: gpt-2 has 50257, for example. So we round up. + vocab_size = len(tokenizer) + Vocab = round_axis_for_partitioning(Axis("vocab", vocab_size), parameter_axis_mapping) + if vocab_size != Vocab.size: + logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning") + + state = trainer.initial_state(training_key, model_init=lambda: config.model.build(Vocab, key=model_key)) + + if int(state.step) == 0: + # TODO: I don't love that we init the model twice, but it's not a big deal i think? + if config.initialize_from_hf: + # initialize from an hf pretrained model + logger.info( + "No training checkpoint found. Initializing model from HF checkpoint" + f" '{converter.reference_checkpoint}'" + ) + # this is a bit gross, but we want to free up the memory from the model we just built + state = dataclasses.replace(state, model=None) + gc.collect() + model = converter.load_pretrained( + config.model.model_type, + config.model, + axis_mapping=parameter_axis_mapping, + dtype=trainer.mp.compute_dtype, + ) + model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model) + state = dataclasses.replace(state, model=model) + else: + logger.info("No checkpoint found. 
Starting from scratch.") + + levanter.tracker.log_summary({"parameter_count": parameter_count(state.model)}) + + if len(tagged_eval_datasets) == 0: + logger.warning("No evaluation datasets provided.") + else: + masked_datasets = [ + (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id), tags) + for ds, tags in tagged_eval_datasets + ] + max_eval_examples_per_ds = config.trainer.max_eval_batches + if max_eval_examples_per_ds is not None: + max_eval_examples_per_ds *= config.trainer.eval_batch_size + + cb = levanter.eval.cb_tagged_lm_evaluate( + EvalBatch, masked_datasets, trainer.device_mesh, compute_axis_mapping, max_eval_examples_per_ds + ) + trainer.add_hook(cb, every=config.trainer.steps_per_eval) + + flops_per_token = config.model.flops_per_token(vocab_size) + flops_per_example = 3 * flops_per_token * Pos.size if flops_per_token is not None else None + trainer.add_hook( + callbacks.log_performance_stats(Pos.size, trainer.config.train_batch_size, flops_per_example), every=1 + ) + if config.hf_save_path is not None: + full_save_path = os.path.join(config.hf_save_path, trainer.run_id) + + trainer.add_hook( + save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False), + every=config.hf_save_steps, + ) + + # visualize log probs + @named_jit( + in_axis_resources=parameter_axis_mapping, + axis_resources=compute_axis_mapping, + out_axis_resources=compute_axis_mapping, + ) + def compute_log_probs(model, example): + model = trainer.mp.cast_to_compute(model) + logprobs = model.compute_loss(example, key=None, reduction=None) + # roll forward to get the loss for each predicted token + logprobs = hax.roll(logprobs, 1, Pos) + return logprobs.rearrange((EvalBatch, Pos)).array + + train_loader = iter(trainer.sharded_loader(train_dataset, Batch)) + + if int(state.step) > 0: + import tqdm + for _ in tqdm.tqdm(range(state.step), desc="seeking data for resume"): + next(train_loader) + + trainer.train(state, train_loader) + +if __name__ == "__main__": + levanter.config.main(main)() From 670b053761c6806b39787bfd91626fea8be6876c Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 30 Jul 2024 12:08:08 -0700 Subject: [PATCH 2/7] Implements dynamic masked dataset --- src/levanter/data/text.py | 58 +++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 484a98bf6..d0dcc95d7 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -14,6 +14,7 @@ import equinox as eqx import fsspec import jax +import jax.numpy as jnp import numpy as np import pyarrow as pa import regex @@ -64,30 +65,29 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index - -class CausalLmDataset(ShardableDataset[LmExample]): +class MaskedLmDataset(ShardableDataset[LmExample]): def __init__( self, dataset: ShardableDataset[np.ndarray], QPos: Axis, KPos: Axis, - fcm_prob: float = 0.0, + mask_prob: float = 0.15, key: Optional[PRNGKeyArray] = None, - ignore_index: Optional[int] = None, + ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, ): self.dataset = dataset self.QPos = QPos self.KPos = KPos - self.fcm_prob = fcm_prob + self.mask_prob = mask_prob self.key = key - self.ignore_id = ignore_index + self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX - if self.fcm_prob > 0.0 and self.key is None: - raise ValueError("must provide key if fcm_prob > 0.0") + if self.mask_prob > 0.0 
and self.key is None: + raise ValueError("must provide key if mask_prob > 0.0") - def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": - return CausalLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.fcm_prob, self.key, self.ignore_id + def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": + return MaskedLmDataset( + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.key, self.ignore_id ) def __iter__(self) -> Iterator[LmExample]: @@ -95,31 +95,37 @@ def __iter__(self) -> Iterator[LmExample]: sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) with use_cpu_device(): - @functools.partial(eqx.filter_jit, out_shardings=sharding) - def _create_lm_example(tokens, key): - tokens = hax.named(tokens, self.QPos) - + def _create_mlm_example(tokens, key): + tokens_array = tokens.array + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - - if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 - assert self.key is not None + + if self.mask_prob > 0: this_key, key = jax.random.split(key) - fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) - attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) - example = dataclasses.replace(example, attn_mask=attn_mask) + mask_shape = tokens_array.shape + mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) + + # Create a mask for 80% [MASK], 10% random, 10% original token + rand = jax.random.uniform(this_key, mask_shape) + mask_token = jnp.where(rand < 0.8, self.ignore_id, tokens_array) + mask_token = jnp.where((rand >= 0.8) & (rand < 0.9), tokens_array, mask_token) + random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) + masked_tokens = jnp.where(mask, mask_token, random_tokens) + + masked_tokens_named = hax.named(masked_tokens, self.QPos) + example = dataclasses.replace(example, tokens=masked_tokens_named) return example for tokens in self.dataset: - example = _create_lm_example(tokens, key) + tokens_array = jnp.array(tokens) + tokens_named = hax.named(tokens_array, self.QPos) + example = _create_mlm_example(tokens_named, key) yield example + class TokenSeqDataset(ShardableDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache. 
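Note: the corruption scheme the dataset above is aiming at is the standard BERT dynamic-masking recipe: roughly 15% of positions are selected per example, and of those, 80% are replaced with the mask token, 10% with a random token, and 10% are left unchanged, while the loss targets keep the original ids at every selected position. The following is a minimal standalone sketch in plain JAX, not the patch's code; mask_token_id, vocab_size, and the example ids are illustrative assumptions rather than fields of MaskedLmDataset.

import jax
import jax.numpy as jnp

def bert_style_mask(tokens, key, mask_token_id, vocab_size,
                    mask_prob=0.15, noise_prob=0.1, ignore_id=-100):
    # tokens: [seq_len] int32 array of token ids
    k_select, k_split, k_rand = jax.random.split(key, 3)
    # select ~mask_prob of the positions for corruption
    selected = jax.random.bernoulli(k_select, mask_prob, tokens.shape)
    # among selected positions: 80% -> mask token, 10% -> random token, 10% -> unchanged
    split = jax.random.uniform(k_split, tokens.shape)
    random_tokens = jax.random.randint(k_rand, tokens.shape, 0, vocab_size)
    corrupted = jnp.where(split < 0.8, mask_token_id, tokens)
    corrupted = jnp.where((split >= 0.8) & (split < 0.8 + noise_prob), random_tokens, corrupted)
    inputs = jnp.where(selected, corrupted, tokens)
    # targets keep the original ids at selected positions (including the 10% left
    # unchanged) and the ignore index everywhere else
    targets = jnp.where(selected, tokens, ignore_id)
    return inputs, targets

inputs, targets = bert_style_mask(
    jnp.array([5, 17, 42, 8, 99], dtype=jnp.int32),
    jax.random.PRNGKey(0),
    mask_token_id=50264,  # assumed <mask> id for a roberta-base tokenizer
    vocab_size=50265,
)

The sketch draws the selection mask, the 80/10/10 split, and the random replacement tokens from three independent keys so the draws are uncorrelated, and it uses a dedicated mask_token_id for the replacement token; later patches in this series likewise introduce a mask_token_id parameter rather than reusing the ignore index.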
From 42f54042fe9a2b089c79607ab273fd46fd393b18 Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 30 Jul 2024 14:17:16 -0700 Subject: [PATCH 3/7] Reintroduced accidentally deleted CausalLMDataset class --- src/levanter/data/text.py | 56 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index d0dcc95d7..4e625ee8b 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -125,6 +125,62 @@ def _create_mlm_example(tokens, key): yield example +class CausalLmDataset(ShardableDataset[LmExample]): + def __init__( + self, + dataset: ShardableDataset[np.ndarray], + QPos: Axis, + KPos: Axis, + fcm_prob: float = 0.0, + key: Optional[PRNGKeyArray] = None, + ignore_index: Optional[int] = None, + ): + self.dataset = dataset + self.QPos = QPos + self.KPos = KPos + self.fcm_prob = fcm_prob + self.key = key + self.ignore_id = ignore_index + + if self.fcm_prob > 0.0 and self.key is None: + raise ValueError("must provide key if fcm_prob > 0.0") + + def shard(self, shard_id: int, num_shards: int) -> "CausalLmDataset": + return CausalLmDataset( + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.fcm_prob, self.key, self.ignore_id + ) + + def __iter__(self) -> Iterator[LmExample]: + key = self.key + sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) + + with use_cpu_device(): + + @functools.partial(eqx.filter_jit, out_shardings=sharding) + def _create_lm_example(tokens, key): + tokens = hax.named(tokens, self.QPos) + + example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) + + if self.fcm_prob > 0: + # masks for attention + # We support forgetful causal masking (FCM) which is a technique that improves training speed by + # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention + # mask instead of the activations. 
It's described in https://arxiv.org/abs/2210.13432 + assert self.key is not None + this_key, key = jax.random.split(key) + fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) + attn_mask = example.attn_mask & AttentionMask.explicit(fcm_mask) + example = dataclasses.replace(example, attn_mask=attn_mask) + + return example + + for tokens in self.dataset: + example = _create_lm_example(tokens, key) + yield example + + + class TokenSeqDataset(ShardableDataset[np.ndarray]): """ From 53fd8d23061acfd0e936c9b86006f386874e051f Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Mon, 5 Aug 2024 14:51:47 -0700 Subject: [PATCH 4/7] [WIP] Re-implements MLM training objective --- src/levanter/data/text.py | 56 ++++++++++++++----------- src/levanter/models/lm_model.py | 74 +++++++++++++++++++++++---------- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index 4e625ee8b..dfc7df4ea 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -26,13 +26,11 @@ from levanter.data.mixture import MixtureDataset, StopStrategy -# intercept the logging nonsense here from levanter.logging import silence_transformer_nag # noqa from levanter.models.attention import AttentionMask -from levanter.models.lm_model import LmExample +from levanter.models.lm_model import MaskedLmExample, LmExample from levanter.utils.hf_utils import num_cpus_used_by_tokenizer - silence_transformer_nag() # noqa from transformers import BatchEncoding, PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast # noqa @@ -54,7 +52,6 @@ from levanter.shapes import NamedShapeSpec, ShapeSpec # noqa from levanter.utils.jax_utils import use_cpu_device # noqa - logger = logging.getLogger("levanter.data.text") # TASKS: @@ -65,13 +62,14 @@ DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index -class MaskedLmDataset(ShardableDataset[LmExample]): +class MaskedLmDataset(ShardableDataset[MaskedLmExample]): def __init__( self, dataset: ShardableDataset[np.ndarray], QPos: Axis, KPos: Axis, mask_prob: float = 0.15, + noise_prob: float = 0.1, key: Optional[PRNGKeyArray] = None, ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, ): @@ -79,6 +77,7 @@ def __init__( self.QPos = QPos self.KPos = KPos self.mask_prob = mask_prob + self.noise_prob = noise_prob self.key = key self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX @@ -87,10 +86,10 @@ def __init__( def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": return MaskedLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.key, self.ignore_id + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.noise_prob, self.key, self.ignore_id ) - def __iter__(self) -> Iterator[LmExample]: + def __iter__(self) -> Iterator[MaskedLmExample]: key = self.key sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) @@ -98,31 +97,44 @@ def __iter__(self) -> Iterator[LmExample]: @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_mlm_example(tokens, key): tokens_array = tokens.array - - example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) - + targets = tokens_array.copy() + if self.mask_prob > 0: this_key, key = jax.random.split(key) mask_shape = tokens_array.shape mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) - # Create a mask for 80% [MASK], 10% random, 10% original token rand = 
jax.random.uniform(this_key, mask_shape) mask_token = jnp.where(rand < 0.8, self.ignore_id, tokens_array) - mask_token = jnp.where((rand >= 0.8) & (rand < 0.9), tokens_array, mask_token) random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) - masked_tokens = jnp.where(mask, mask_token, random_tokens) + mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token) + masked_tokens = jnp.where(mask, mask_token, tokens_array) + + # Set targets to the original tokens where mask is True, otherwise set to ignore_id + targets = jnp.where(mask, tokens_array, self.ignore_id) masked_tokens_named = hax.named(masked_tokens, self.QPos) - example = dataclasses.replace(example, tokens=masked_tokens_named) + targets_named = hax.named(targets, self.QPos) + + attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) + attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) + + example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) + else: + targets_named = hax.named(targets, self.QPos) + attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) + attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) + + example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) return example - for tokens in self.dataset: - tokens_array = jnp.array(tokens) - tokens_named = hax.named(tokens_array, self.QPos) - example = _create_mlm_example(tokens_named, key) - yield example + for tokens in self.dataset: + tokens_array = jnp.array(tokens) + tokens_named = hax.named(tokens_array, self.QPos) + example = _create_mlm_example(tokens_named, key) + yield example + class CausalLmDataset(ShardableDataset[LmExample]): @@ -155,7 +167,6 @@ def __iter__(self) -> Iterator[LmExample]: sharding = jax.sharding.SingleDeviceSharding(jax.local_devices(backend="cpu")[0]) with use_cpu_device(): - @functools.partial(eqx.filter_jit, out_shardings=sharding) def _create_lm_example(tokens, key): tokens = hax.named(tokens, self.QPos) @@ -163,10 +174,6 @@ def _create_lm_example(tokens, key): example = LmExample.causal(tokens=tokens, ignore_id=self.ignore_id) if self.fcm_prob > 0: - # masks for attention - # We support forgetful causal masking (FCM) which is a technique that improves training speed by - # randomly masking out some of the context. This is a bit like dropout, but it's applied to the attention - # mask instead of the activations. It's described in https://arxiv.org/abs/2210.13432 assert self.key is not None this_key, key = jax.random.split(key) fcm_mask = hax.nn.attention.forgetful_causal_mask(self.KPos, self.fcm_prob, key=this_key) @@ -181,7 +188,6 @@ def _create_lm_example(tokens, key): - class TokenSeqDataset(ShardableDataset[np.ndarray]): """ A dataset that yields sequences of tokens of fixed length from a TokenizedDocumentCache. 
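Note: the lm_model.py diff below wires these targets into the loss: MaskedLmExample builds a loss_mask from the targets and compute_loss passes it as the where= argument to the cross entropy, so only corrupted positions contribute. The following is a minimal sketch of that mechanics in plain jax.numpy, independent of Levanter's haliax NamedArray API; the [seq_len, vocab] logits layout and the IGNORE_ID constant are assumptions for illustration.

import jax
import jax.numpy as jnp

IGNORE_ID = -100  # assumed ignore index, mirroring DEFAULT_IGNORE_INDEX above

def mlm_loss(logits, targets):
    # logits: [seq_len, vocab] float array; targets: [seq_len] int array with
    # IGNORE_ID at every position that was not corrupted
    loss_mask = (targets != IGNORE_ID).astype(logits.dtype)
    safe_targets = jnp.where(targets == IGNORE_ID, 0, targets)  # any valid id; masked out below
    logprobs = jax.nn.log_softmax(logits, axis=-1)
    nll = -jnp.take_along_axis(logprobs, safe_targets[:, None], axis=-1)[:, 0]
    # average the negative log-likelihood over corrupted positions only
    return (nll * loss_mask).sum() / jnp.maximum(loss_mask.sum(), 1.0)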
diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index 543c6a5ca..edcbb59f9 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -3,6 +3,7 @@ import draccus import equinox as eqx +import jax import jax.numpy as jnp from jax.random import PRNGKey @@ -12,15 +13,36 @@ from levanter.models.attention import AttentionMask - LmConfigT = TypeVar("LmConfigT", bound="LmConfig") LmT = TypeVar("LmT", bound="LmHeadModel") +class MaskedLmExample(eqx.Module): + tokens: hax.NamedArray + loss_mask: hax.NamedArray + attn_mask: hax.NamedArray + targets: Optional[hax.NamedArray] = None + + @staticmethod + def masked_lm( + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None + ) -> "MaskedLmExample": + Pos = tokens.axes[0] + + mask = tokens.array != targets.array + loss_mask = hax.named(mask.astype(jnp.float32), Pos) + + if ignore_id is not None: + ignore_mask = targets.array != ignore_id + loss_mask = loss_mask * hax.named(ignore_mask.astype(jnp.float32), Pos) + + return MaskedLmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) + class LmExample(eqx.Module): tokens: hax.NamedArray loss_mask: hax.NamedArray - attn_mask: AttentionMask | NamedArray = AttentionMask.causal() + attn_mask: hax.NamedArray + targets: Optional[hax.NamedArray] = None @staticmethod def causal( @@ -34,20 +56,38 @@ def causal( Pos = tokens.axes[0] - # don't predict the last token. if loss_mask is None: loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jnp.float32) if ignore_id is not None: - # we don't compute loss for any tokens matching the ignore index ignore_mask = hax.roll(tokens, -1, Pos) != ignore_id loss_mask = loss_mask * ignore_mask attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) + @staticmethod + def masked_lm( + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None + ) -> "LmExample": + Pos = tokens.axes[0] + + mask = tokens.array != targets.array + loss_mask = mask.astype(jnp.float32) + + if ignore_id is not None: + ignore_mask = targets.array != ignore_id + loss_mask = loss_mask * ignore_mask.astype(jnp.float32) + + print(f"tokens shape: {tokens.shape}") + print(f"targets shape: {targets.shape}") + print(f"loss_mask shape: {loss_mask.shape}") + print(f"attn_mask shape: {attn_mask.shape}") + + return LmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) + + -# TODO: for some reason, mypy doesn't like the discover_packages_path argument? class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property @abc.abstractmethod @@ -70,12 +110,7 @@ def flops_per_token(self, vocab_size: int) -> Optional[float]: def build(self, Vocab: Axis, *, key: PRNGKey) -> "LmT": return self.model_type.init(Vocab, self, key=key) # type: ignore - class LmHeadModel(Generic[LmConfigT], abc.ABC): - """ - Superclass for models with a language modeling head. 
- """ - @property @abc.abstractmethod def config(self) -> LmConfigT: @@ -103,14 +138,11 @@ def init(cls, Vocab: Axis, config: LmConfigT, *, key: PRNGKey) -> "LmHeadModel[L def __call__( self, input_ids: NamedArray, attn_mask: Optional[AttentionMask | NamedArray] = None, *, key=None ) -> NamedArray: - pass + print(f"input_ids shape: {input_ids.shape}") + print(f"attn_mask shape: {attn_mask.shape}") @abc.abstractmethod def resize_vocab(self, new_size: int, key: Optional[PRNGKey] = None) -> "LmHeadModel[LmConfigT]": - """ - Resizes the vocabulary of the model. Key may be provided to use random initialization, otherwise, there - should be some deterministic initialization of any new parameters. - """ pass def compute_loss( @@ -121,15 +153,13 @@ def compute_loss( reduction: Optional[hax.ReductionFunction] = hax.mean, reduction_axis: Optional[hax.AxisSelection] = None, ) -> jnp.ndarray | NamedArray: - """ - Computes the cross-entropy loss for a language modeling example. If reduction is not None, the loss is reduced - across the reduction axis (with reduction_axis=None meaning all axes). If reduction is None, the loss is not - reduced, and the result is a named array with axes (*batch axes, sequence_length). - """ logits = self(example.tokens, example.attn_mask, key=key) - # TODO: would be nice if we made the dtype configurable logits = logits.astype(jnp.float32) - targets = hax.roll(example.tokens, -1, axis=self.Pos.name) + if example.targets is not None: + targets = example.targets + else: + targets = hax.roll(example.tokens, -1, axis=self.Pos.name) + target_y = hax.nn.one_hot(targets, self.Vocab, dtype=logits.dtype) loss = cross_entropy_loss( logits, self.Vocab, target_y, reduction, reduction_axis=reduction_axis, where=example.loss_mask From dcd45b209b81946efa6a57253545c03a618db28d Mon Sep 17 00:00:00 2001 From: prady-saligram Date: Tue, 6 Aug 2024 11:38:09 -0700 Subject: [PATCH 5/7] Adds error handling and reverts LmExample class to original --- src/levanter/models/lm_model.py | 33 ++++++++++----------------------- 1 file changed, 10 insertions(+), 23 deletions(-) diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index edcbb59f9..c36e0e622 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -26,6 +26,15 @@ class MaskedLmExample(eqx.Module): def masked_lm( tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None ) -> "MaskedLmExample": + if tokens.ndim != 1: + raise ValueError("tokens must be a 1D array") + + if not jnp.issubdtype(tokens.dtype, jnp.integer): + raise ValueError("tokens must be an integer array") + + if tokens.shape != targets.shape: + raise ValueError("tokens and targets must have the same shape") + Pos = tokens.axes[0] mask = tokens.array != targets.array @@ -41,8 +50,7 @@ def masked_lm( class LmExample(eqx.Module): tokens: hax.NamedArray loss_mask: hax.NamedArray - attn_mask: hax.NamedArray - targets: Optional[hax.NamedArray] = None + attn_mask: AttentionMask | NamedArray = AttentionMask.causal() @staticmethod def causal( @@ -66,27 +74,6 @@ def causal( attn_mask = AttentionMask.causal() return LmExample(tokens=tokens, loss_mask=loss_mask, attn_mask=attn_mask) - @staticmethod - def masked_lm( - tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None - ) -> "LmExample": - Pos = tokens.axes[0] - - mask = tokens.array != targets.array - loss_mask = mask.astype(jnp.float32) - - if ignore_id is not None: - ignore_mask 
= targets.array != ignore_id - loss_mask = loss_mask * ignore_mask.astype(jnp.float32) - - print(f"tokens shape: {tokens.shape}") - print(f"targets shape: {targets.shape}") - print(f"loss_mask shape: {loss_mask.shape}") - print(f"attn_mask shape: {attn_mask.shape}") - - return LmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) - - class LmConfig(draccus.PluginRegistry, abc.ABC, Generic[LmT], discover_packages_path="levanter.models"): # type: ignore @property From 027b17623913017408345bd5817bc443d8295961 Mon Sep 17 00:00:00 2001 From: Prady Saligram Date: Mon, 26 Aug 2024 15:03:35 -0700 Subject: [PATCH 6/7] Sets RobertaConfig as model architecture and creates default config file --- config/roberta.yaml | 38 ++ src/levanter/main/train_mlm.py | 3 +- src/levanter/models/roberta.py | 847 +++++++++++++++++++++++++++++++++ 3 files changed, 887 insertions(+), 1 deletion(-) create mode 100644 config/roberta.yaml create mode 100644 src/levanter/models/roberta.py diff --git a/config/roberta.yaml b/config/roberta.yaml new file mode 100644 index 000000000..81f5d4d35 --- /dev/null +++ b/config/roberta.yaml @@ -0,0 +1,38 @@ +data: + train_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" + validation_urls: + - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "gs://levanter-data/tokenized/openwebtext_roberta/" + tokenizer: "roberta-base" + +model: + type: roberta + vocab_size: 50265 + hidden_size: 768 + intermediate_size: 3072 + num_hidden_layers: 12 + num_attention_heads: 12 + max_position_embeddings: 512 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + gradient_checkpointing: true + +trainer: + tracker: + - type: wandb + project: "levanter" + tags: ["openwebtext", "roberta", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 32 + num_train_steps: 20000 + +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 208131310..80e941d5a 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -20,6 +20,7 @@ from levanter.models.gpt2 import Gpt2Config from levanter.models.llama import LlamaConfig from levanter.models.lm_model import LmConfig +from levanter.models.roberta import RobertaConfig from levanter.optim import AdamConfig, OptimizerConfig from levanter.trainer import Trainer, TrainerConfig from levanter.utils.jax_utils import parameter_count @@ -30,7 +31,7 @@ class TrainMlmConfig: data: Union[LMDatasetConfig, LMMixtureDatasetConfig] = field(default_factory=LMDatasetConfig) trainer: TrainerConfig = field(default_factory=TrainerConfig) - model: LmConfig = field(default_factory=LlamaConfig) + model: LmConfig = field(default_factory=RobertaConfig) optimizer: OptimizerConfig = field(default_factory=AdamConfig) # config related to continued pretraining diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py new file mode 100644 index 000000000..f51771eff --- /dev/null +++ b/src/levanter/models/roberta.py @@ -0,0 +1,847 @@ +import dataclasses +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import equinox as eqx +import jax +import jax.numpy as jnp +import jax.random as jrandom +from jaxtyping import PRNGKeyArray + +import haliax as hax +import haliax.nn as hnn +from haliax import Axis, AxisSpec, 
NamedArray +from haliax.jax_utils import maybe_rng_split, named_call, shaped_rng_split +from haliax.nn.scan import Stacked + +from levanter.compat.hf_checkpoints import HFCheckpointConverter, HFCompatConfig +from levanter.compat.torch_serialization import ( + StateDict, + StateDictSerializationMixin, + apply_prefix, + flatten_linear_layers, + stack_state_dict, + unflatten_linear_layers, + unstack_state_dict, +) +from levanter.logging import silence_transformer_nag +from levanter.models.attention import AttentionBackend, AttentionMask, simple_attention_with_dropout +from levanter.models.gpt2 import ACT2FN +from levanter.models.lm_model import LmConfig, LmHeadModel +from levanter.types import BlockFoldable +from levanter.utils.flop_utils import lm_flops_per_token + +silence_transformer_nag() +from transformers import PretrainedConfig as HfConfig +from transformers import RobertaConfig as HfRobertaConfig + + + +@LmConfig.register_subclass("roberta") +@dataclass(frozen=True) +class RobertaConfig(HFCompatConfig): + r""" + + Adapted from HuggingFace RobertaConfig, description below + + + This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is + used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa + [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50265): + Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from transformers import RobertaConfig, RobertaModel + + >>> # Initializing a RoBERTa configuration + >>> configuration = RobertaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = RobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + vocab_size: int = 50265 + hidden_size: int = 768 + num_hidden_layers: int = 12 + num_attention_heads: int = 12 + intermediate_size: int = 3072 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + max_position_embeddings: int = 512 + type_vocab_size: int = 2 + initializer_range: float = 0.02 + layer_norm_eps: float = 1e-12 + pad_token_id: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + position_embedding_type: Optional[str] = "absolute" + use_cache: bool = False + classifier_dropout: Optional[float] = None + + scan_layers: bool = True + gradient_checkpointing: bool = True + + reference_checkpoint: str = "FacebookAI/roberta-base" + tokenizer: Optional[str] = None + + # Axes + Pos = property(lambda self: Axis(name="position", size=self.max_position_embeddings)) + KeyPos = property(lambda self: self.Pos.alias("key_position")) + Embed = property(lambda self: Axis(name="embed", size=self.hidden_size)) + EmbedAtt = property(lambda self: self.Embed.alias("embed_att")) + FinalEmbed = property(lambda self: self.Embed.alias("final_embed")) + Heads = property(lambda self: Axis(name="heads", size=self.num_attention_heads)) + Layers = property(lambda self: Axis(name="layers", size=self.num_hidden_layers)) + Mlp = property(lambda self: Axis(name="mlp", size=self.intermediate_size)) + HeadSize = property(lambda self: Axis(name="head_size", size=self.hidden_size // self.num_attention_heads)) + + + @classmethod + def from_hf_config(cls, hf_config: HfConfig) -> "RobertaConfig": + return RobertaConfig( + vocab_size = hf_config.vocab_size, + hidden_size = hf_config.hidden_size, + num_hidden_layers = hf_config.num_hidden_layers, + num_attention_heads = hf_config.num_attention_heads, + intermediate_size = hf_config.intermediate_size, + hidden_act = hf_config.hidden_act, + hidden_dropout_prob= hf_config.hidden_dropout_prob, + attention_probs_dropout_prob = hf_config.attention_probs_dropout_prob, + max_position_embeddings = hf_config.max_position_embeddings, 
+ type_vocab_size = hf_config.type_vocab_size, + initializer_range = hf_config.initializer_range, + layer_norm_eps = hf_config.layer_norm_eps, + pad_token_id = hf_config.pad_token_id, + bos_token_id = hf_config.bos_token_id, + eos_token_id = hf_config.eos_token_id, + position_embedding_type = hf_config.position_embedding_type, + use_cache = hf_config.use_cache, + classifier_dropout = hf_config.classifier_dropout, + ) + + def hf_checkpoint_converter(self) -> HFCheckpointConverter["RobertaConfig"]: # type: ignore + return HFCheckpointConverter( + self.__class__, + reference_checkpoint=self.reference_checkpoint, + trust_remote_code=True, + tokenizer=self.tokenizer if self.tokenizer else self.reference_checkpoint, + HfConfigClass=HfRobertaConfig, + ) + + def to_hf_config(self, vocab_size: int, config_overrides: Optional[Dict] = None) -> HfRobertaConfig: + """Convert to HuggingFace's LlamaConfig + + Args: + vocab_size (int, optional): Vocabulary size of the tokenizer. Defaults to 32000. + config_overrides (dict, optional): Overrides for the config. Defaults to None. + + Returns: + HfRobertaConfig: HuggingFace's RobertaConfig + """ + + if config_overrides is None: + config_overrides = {} + + return HfRobertaConfig( + vocab_size = vocab_size, + hidden_size = self.hidden_size, + num_hidden_layers = self.num_hidden_layers, + num_attention_heads = self.num_attention_heads, + intermediate_size = self.intermediate_size, + hidden_act = self.hidden_act, + hidden_dropout_prob = self.hidden_dropout_prob, + attention_probs_dropout_prob = self.attention_probs_dropout_prob, + max_position_embeddings = self.max_position_embeddings, + type_vocab_size = self.type_vocab_size, + initializer_range = self.initializer_range, + layer_norm_eps = self.layer_norm_eps, + pad_token_id = self.pad_token_id, + bos_token_id = self.bos_token_id, + eos_token_id = self.eos_token_id, + position_embedding_type = self.position_embedding_type, + use_cache = self.use_cache, + classifier_dropout = self.classifier_dropout, + ) + + @property + def model_type(self) -> Type["RobertaModel"]: + return RobertaModel + + def flops_per_token(self, vocab_size: int): + return lm_flops_per_token( + hidden_dim=self.hidden_size, + intermediate_dim=self.intermediate_size, + num_layers=self.num_hidden_layers, + num_kv_heads=self.num_attention_heads, + num_heads=self.num_attention_heads, + seq_len=self.max_position_embeddings, + vocab_size=vocab_size, + glu=True, + ) + +class RobertaSelfAttention(eqx.Module, StateDictSerializationMixin): + + config: RobertaConfig + Heads: Axis + HeadSize: Axis + EmbedAtt: Axis + + q_proj: hnn.Linear + k_proj: hnn.Linear + v_proj: hnn.Linear + + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + Pos: Axis + KeyPos: Axis + distance_embedding: Optional[hnn.Embedding] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfAttention": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + + k_q, k_k, k_v, k_e = jrandom.split(key, 4) + q_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_q, out_first=True) + k_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_k, out_first=True) + v_proj = hnn.Linear.init(In=Embed, Out=EmbedAtt, key=k_v, out_first=True) + + dropout = hnn.Dropout(config.attention_probs_dropout_prob) + + distance_embedding = None + position_embedding_type = config.position_embedding_type + + if position_embedding_type == "relative_key" or position_embedding_type == "relative_key_query": + RelPos = Axis("rel_pos", 2 * config.max_position_embeddings - 1) + distance_embedding 
= hnn.Embedding.init(RelPos, config.HeadSize, k_e) + + return RobertaSelfAttention(config, config.Heads, config.HeadSize, EmbedAtt, + q_proj, k_proj, v_proj, + dropout, position_embedding_type, + config.Pos, config.KeyPos, distance_embedding, + ) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"q_proj": "query", "k_proj": "key", "v_proj": "value"} + + def _rope_scale_factor(self) -> float: + # hasattr for gemma and I'm feeling lazy + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + assert self.config.rope_scaling["type"] == "linear" + return self.config.rope_scaling["factor"] + return 1.0 + + def transpose_for_scores(self, x: NamedArray) -> NamedArray: + # Makes sure to have the correct output order as well + y = hax.rearrange(x, "... position (embed_att: heads head_size) -> ... heads position head_size", heads=self.Heads, head_size=self.HeadSize) + return y + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key = None + ) -> Tuple[NamedArray]: + + query_layer = self.transpose_for_scores(self.q_proj(hidden_states)) + key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) + value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) + + if self.position_embedding_type == "rope": + cos, sin = llama_rotary_pos_emb( + self.config.HeadSize, hidden_states.resolve_axis("position"), scale=self._rope_scale_factor() + ) + query_layer, key_layer = _apply_rotary_pos_emb(query_layer, key_layer, cos, sin) + + key_layer = key_layer.rename({"position": "key_position"}) + value_layer = value_layer.rename({"position": "key_position"}) + + attention_scores = hax.dot(query_layer, key_layer, axis=self.HeadSize) # aka hax.einsum("bhld, bhrd -> bhlr") + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + Left = self.Pos # Queries + Right = self.KeyPos # Keys + + position_ids_l = hax.arange(Left).broadcast_to((Left,Right)) + position_ids_r = hax.arange(Right).broadcast_to((Left,Right)) + + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.Pos.size) + + if self.position_embedding_type == "relative_key": + relative_position_scores = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = hax.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = hax.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores /= jnp.sqrt(self.HeadSize.size) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + # Attention_mask should have shape Batch Pos, so it should broadcast to shape Batch Heads Pos KeyPos for summation + attention_scores = attention_scores + attention_mask + + attention_probs = hnn.softmax(attention_scores, axis=self.KeyPos) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs = self.dropout(attention_probs, key=key) + + hax.dot(query_layer, key_layer, axis=self.HeadSize) + + context_layer = hax.dot(attention_probs, value_layer, axis=self.KeyPos) + + outputs = hax.rearrange(context_layer, ("... heads position head_size -> ... position (embed_att: heads head_size)"), heads=self.Heads, head_size=self.HeadSize) + + return outputs + +class RobertaSelfOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + EmbedAtt = config.EmbedAtt + dense = hnn.Linear.init(In=EmbedAtt, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray,*, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class RobertaAttention(eqx.Module, StateDictSerializationMixin): + self_attn: RobertaSelfAttention + output: RobertaSelfOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaAttention": + k_a, k_o = jrandom.split(key, 2) + + self_attn = RobertaSelfAttention.init(config, key=k_a) + output = RobertaSelfOutput.init(config, key=k_o) + + return RobertaAttention(self_attn, output) + + def _state_dict_key_map(self) -> Dict[str, Optional[str]]: + return {"self_attn": "self"} + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> NamedArray: + k_a, k_o = maybe_rng_split(key, 2) + + self_outputs = self.self_attn( + hidden_states, + attention_mask, + key=k_a + ) + attention_output = self.output(self_outputs, hidden_states, key=k_o) + return attention_output + +class RobertaIntermediate(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + intermediate_act_fn: Callable = eqx.static_field() + + @staticmethod + def init(config, *, key) -> "RobertaIntermediate": + dense = hnn.Linear.init(config.Embed, config.Mlp, key=key, out_first=True) + if isinstance(config.hidden_act, str): + intermediate_act_fn = ACT2FN[config.hidden_act] + else: + intermediate_act_fn = config.hidden_act + + return RobertaIntermediate(dense, intermediate_act_fn) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key = None) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + +class RobertaOutput(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaSelfOutput": + Embed = config.Embed + dense = hnn.Linear.init(In=config.Mlp, Out=Embed, key=key, out_first=True) + LayerNorm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + return RobertaSelfOutput(dense, LayerNorm, dropout) + + @named_call + def __call__(self, hidden_states: NamedArray, input: NamedArray, *, key) -> NamedArray: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, key=key) + hidden_states = self.LayerNorm(hidden_states + input) + return hidden_states + +class 
RobertaLayer(eqx.Module, StateDictSerializationMixin): + attention: RobertaAttention + intermediate: RobertaIntermediate + output: RobertaOutput + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaLayer": + k_a, k_i, k_o = jrandom.split(key, 3) + + attention = RobertaAttention.init(config, key=k_a) + intermediate = RobertaIntermediate.init(config, key=k_i) + output = RobertaOutput.init(config, key=k_o) + + return RobertaLayer(attention, intermediate, output) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + k_a, k_o = maybe_rng_split(key, 2) + + attention_output = self.attention( + hidden_states, + attention_mask, + key=k_a, + ) + + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output, key=k_o) + + return layer_output + + +class RobertaEncoder(eqx.Module, StateDictSerializationMixin): + config: RobertaConfig + layer: BlockFoldable[RobertaLayer] + + @staticmethod + def init(config: RobertaConfig, *, key) -> "RobertaEncoder": + S = Stacked + if not config.scan_layers: + from haliax.nn.scan import BlockSeq + + S = BlockSeq + + layer = S.init(config.Layers, RobertaLayer, gradient_checkpointing=config.gradient_checkpointing)( + config, + key=shaped_rng_split(key, config.num_hidden_layers), #TODO: config.gradient_checkpointing + ) + + return RobertaEncoder(config, layer) + + @named_call + def __call__( + self, + hidden_states: NamedArray, + attention_mask: Optional[NamedArray] = None, + *, + key + ) -> Tuple[NamedArray]: + + keys = maybe_rng_split(key, self.config.num_hidden_layers) if key is not None else None + x = self.layer.fold(hidden_states, attention_mask, key=keys) + + return x + + def from_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None): + if isinstance(self.layer, Stacked): + state_dict = stack_state_dict(state_dict, prefix=apply_prefix(prefix, "layer")) + + out = super().from_state_dict(state_dict, prefix=prefix) + return out + + def update_state_dict(self, state_dict: StateDict, prefix: Optional[str] = None) -> StateDict: + my_state_dict: StateDict = {} + super().update_state_dict(my_state_dict, prefix=prefix) + + if isinstance(self.layer, Stacked): + stacked_dict = unstack_state_dict(my_state_dict, prefix=apply_prefix(prefix, "layer")) + state_dict.update(stacked_dict) + else: + state_dict.update(my_state_dict) + + return state_dict + +class RobertaEmbedding(eqx.Module, StateDictSerializationMixin): + Vocab: Axis = eqx.static_field() + Pos: Axis = eqx.static_field() + + word_embeddings: hnn.Embedding + position_embeddings: hnn.Embedding + token_type_embeddings: Optional[hnn.Embedding] + padding_idx: NamedArray + + LayerNorm: hnn.LayerNorm + dropout: hnn.Dropout + position_embedding_type: Optional[str] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key) -> "RobertaEmbedding": + key_w, key_p, key_t = jrandom.split(key, 3) + + padding_idx = config.pad_token_id + + word_embeddings = hnn.Embedding.init(Vocab, config.Embed, key=key_w) # padding_idx not specified + position_embeddings = hnn.Embedding.init(config.Pos, config.Embed, key=key_p) + + Token = hax.Axis("token", config.type_vocab_size) + + token_type_embeddings = hnn.Embedding.init(Token, config.Embed, key=key_t) + + LayerNorm = hnn.LayerNorm.init(config.Embed, config.layer_norm_eps) + dropout = hnn.Dropout(config.hidden_dropout_prob) + + return RobertaEmbedding(Vocab, config.Pos, 
word_embeddings, position_embeddings, token_type_embeddings, padding_idx, LayerNorm, dropout, config.position_embedding_type) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + mask = hax.not_equal(input_ids, self.padding_idx) * 1 + incremental_indices = (hax.cumsum(mask, axis=self.Pos).astype(mask) + past_key_values_length) * mask + return incremental_indices + self.padding_idx + + def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): + position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + return hax.broadcast_to(position_ids, input_axes) + + @named_call + def embed(self, input_ids=None, token_type_ids=None, position_ids=None, input_embeds=None, past_key_values_length=0, *, key = None): + """ + Note: When inputting your own embeds into input_embeds, make sure that the embeds axis has the name "embed" + for compatibility with the position_id creation function. Make sures its length is not equal to + """ + + # Get Axes + if input_ids is not None: + input_axes = input_ids.axes + else: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + + # Get position_ids + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids(input_ids, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(input_axes, input_embeds.resolve_axis("position")) + + # Get token_type_ids + if token_type_ids is None: + token_type_ids = hax.zeros(input_axes, dtype=jnp.int32) + + if input_embeds is None: + input_embeds = self.word_embeddings(input_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = input_embeds + token_type_embeddings + + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings, key=key) + return embeddings + +class RobertaPooler(eqx.Module, StateDictSerializationMixin): + dense: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(config: RobertaConfig, *, key): + dense = hnn.Linear.init(In=config.Embed, Out=config.FinalEmbed, key=key, out_first=True) + + return RobertaPooler(dense, config) + + @named_call + def __call__(self, hidden_states: NamedArray, *, key=None) -> NamedArray: + first_token = hidden_states[{"position" : 0}] + x = self.dense(first_token, key=key).rename({self.config.FinalEmbed: self.config.Embed}) + x = hax.tanh(x) + return x + + +class RobertaModel(eqx.Module, StateDictSerializationMixin): + encoder: RobertaEncoder + embeddings: RobertaEmbedding + pooler : Optional[RobertaPooler] + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, add_pooling_layer: bool = True, *, key) -> "RobertaModel": + k_t, k_emb, k_p = jrandom.split(key, 3) + encoder = RobertaEncoder.init(config=config, key=k_t) + embeddings = RobertaEmbedding.init(Vocab, config, key=k_emb) + + pooler = RobertaPooler.init(config, key=k_p) if add_pooling_layer else None + return RobertaModel(encoder, embeddings, pooler) + + @property + def config(self): + return self.encoder.config + + @property + def vocab_size(self) -> int: + return self.Vocab.size + + @property + def Vocab(self) -> Axis: + return self.embeddings.Vocab + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def 
set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + *, + key, + ) -> Tuple[NamedArray]: + """ + Not Used: meant to be used to improve performance in decoder implementations + + head_mask: Optional[NamedArray] = None, + encoder_hidden_states: Optional[NamedArray] = None, + encoder_attention_mask: Optional[NamedArray] = None, + past_key_values_length = 0, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + """ + k_emb, k_e, k_p = maybe_rng_split(key, 3) + + if input_ids is not None and input_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_axes = input_ids.axes + elif input_embeds is not None: + input_axes = hax.eliminate_axes(input_embeds.axes, "embed") + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if attention_mask is None: + attention_mask = hax.ones(input_axes) + + # Attention mask from mask to actual numbers + attention_mask = (attention_mask == 0) * -jnp.inf + + embedding_output = self.embeddings.embed(input_ids, token_type_ids, position_ids, input_embeds, key=k_emb) + sequence_output = self.encoder(embedding_output, attention_mask=attention_mask, key=k_e) + + pooled_output = self.pooler(sequence_output, key=k_p) if self.pooler is not None else None + + return (sequence_output, pooled_output) + +class RobertaLMHead(eqx.Module, StateDictSerializationMixin): + """Roberta Head for masked language modeling.""" + + dense: hnn.Linear + layer_norm: hnn.LayerNorm + decoder: hnn.Linear + config: RobertaConfig + + @staticmethod + def init(Vocab: Axis, config: RobertaConfig, *, key): + k_dense, k_decoder = jrandom.split(key, 2) + Embed = config.Embed + + dense = hnn.Linear.init(In=Embed, Out=config.FinalEmbed, key=k_dense, out_first=True) + layer_norm = hnn.LayerNorm.init(axis=Embed, eps=config.layer_norm_eps) + + decoder = hnn.Linear.init(Embed, Vocab, key=k_decoder, out_first=True) + + # idk what this is: TODO + # self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + # self.decoder.bias = self.bias + + return RobertaLMHead(dense, layer_norm, decoder, config) + + @named_call + def __call__(self, features: NamedArray, *, key=None) -> NamedArray: + x = self.dense(features).rename({self.config.FinalEmbed: self.config.Embed}) + x = hnn.gelu(x, approximate=False) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + +class RobertaForMaskedLM(eqx.Module, StateDictSerializationMixin): + roberta: RobertaModel + lm_head: RobertaLMHead + Vocab: Axis + + @classmethod + def init(self, Vocab: Axis, config: RobertaConfig, *, key): + + # if config.is_decoder: + # raise AttributeError("Model is being run as a MaskedLM aka an encoder model, but is_decoder is true") + + key_rob, key_head = jrandom.split(key, 2) + roberta = RobertaModel.init(Vocab, config, add_pooling_layer=False, key=key_rob) + lm_head = RobertaLMHead.init(Vocab, config, key=key_head) + + return RobertaForMaskedLM(roberta, lm_head, Vocab) + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, 
new_embeddings): + self.lm_head.decoder = new_embeddings + + @named_call + def __call__( + self, + input_ids: Optional[NamedArray] = None, + attention_mask: Optional[NamedArray] = None, + token_type_ids: Optional[NamedArray] = None, + position_ids: Optional[NamedArray] = None, + input_embeds: Optional[NamedArray] = None, + *, + key=None + ) -> Tuple[NamedArray]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + """ + + k_rob, k_lm = maybe_rng_split(key, 2) + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + input_embeds=input_embeds, + key=k_rob + ) + + prediction_scores = self.lm_head(outputs[0], key=k_lm) + + return prediction_scores + + +def _rotate_half(x: NamedArray) -> NamedArray: + """Rotates half of the hidden dims of the input and concatenates them.""" + HeadSize = x.axes[-1] + x1 = x[HeadSize, : HeadSize.size // 2] + x2 = x[HeadSize, HeadSize.size // 2 :] + out = hax.concatenate(HeadSize, (-x2, x1)) + return out + + +def _apply_rotary_pos_emb( + q: NamedArray, # [batch, position, kv_heads, q_heads_per_group, head_size] + k: NamedArray, # [batch, position, kv_heads, head_size] + cos: NamedArray, # [position, head_size] + sin: NamedArray, # [position, head_size] +) -> Tuple[NamedArray, NamedArray]: + """Applies rotary position embedding to q and k.""" + q_embed = q * cos + _rotate_half(q) * sin + k_embed = k * cos + _rotate_half(k) * sin + return q_embed, k_embed + + +def llama_rotary_pos_emb( + HeadSize: Axis, Pos: Axis, base: float = 10000, scale: float = 1.0 +) -> Tuple[NamedArray, NamedArray]: + with jax.ensure_compile_time_eval(): + HeadHalfSize = HeadSize.resize(HeadSize.size // 2) + inv_freq: NamedArray = 1.0 / (base ** (hax.arange(HeadHalfSize, step=2) / HeadSize.size)) + + position_ids: NamedArray = hax.arange(Pos) / scale + + freqs = position_ids * inv_freq.broadcast_axis(Pos) + # This is different from the paper but aligns with HF implementation: + # It uses a different permutation in order to obtain the same calculation + emb = hax.concatenate(HeadSize, (freqs, freqs)) + cos = hax.cos(emb) + sin = hax.sin(emb) + # This is different from the paper but aligns with HF implementation: + return cos, sin \ No newline at end of file From 399e08c5024a7d4f80e8d90e4375b19c7ad13366 Mon Sep 17 00:00:00 2001 From: Prady Saligram Date: Sat, 31 Aug 2024 17:23:09 -0700 Subject: [PATCH 7/7] Adds compute_loss to roberta and changes positional ids to begin from 0 --- config/roberta-tiny.yaml | 39 +++++++++++++++++++++++++++++++++ config/roberta.yaml | 2 +- src/levanter/data/text.py | 21 ++++++++++-------- src/levanter/main/train_mlm.py | 12 ++++++---- src/levanter/models/lm_model.py | 8 +++---- src/levanter/models/roberta.py | 2 +- 6 files changed, 65 insertions(+), 19 deletions(-) create mode 100644 config/roberta-tiny.yaml diff --git a/config/roberta-tiny.yaml b/config/roberta-tiny.yaml new file mode 100644 index 000000000..4b61ff7e4 --- /dev/null +++ b/config/roberta-tiny.yaml @@ -0,0 +1,39 @@ +data: + id: dlwh/wikitext_103_detokenized +# train_urls: +# - 
"gs://pubmed-mosaic/openwebtext-sharded/openwebtext_train.{1..128}-of-128.jsonl.gz" +# validation_urls: +# - "gs://pubmed-mosaic/openwebtext-sharded/openwebtext_val.{1..8}-of-8.jsonl.gz" + cache_dir: "cache/roberta-tiny" + tokenizer: "roberta-base" + +model: + type: roberta + vocab_size: 50265 + hidden_size: 32 + intermediate_size: 64 + num_hidden_layers: 4 + num_attention_heads: 2 + max_position_embeddings: 512 + hidden_act: "gelu" + hidden_dropout_prob: 0.1 + attention_probs_dropout_prob: 0.1 + gradient_checkpointing: true + +trainer: + tracker: + - type: wandb + project: "levanter" + tags: ["openwebtext", "roberta", "itest"] + + mp: p=f32,c=bfloat16 + model_axis_size: 1 + per_device_parallelism: -1 + + train_batch_size: 32 + num_train_steps: 20000 + +optimizer: + learning_rate: 1E-3 + weight_decay: 0.1 + warmup: 0.01 \ No newline at end of file diff --git a/config/roberta.yaml b/config/roberta.yaml index 81f5d4d35..c854f8109 100644 --- a/config/roberta.yaml +++ b/config/roberta.yaml @@ -35,4 +35,4 @@ trainer: optimizer: learning_rate: 1E-3 weight_decay: 0.1 - warmup: 0.01 + warmup: 0.01 \ No newline at end of file diff --git a/src/levanter/data/text.py b/src/levanter/data/text.py index dfc7df4ea..d0898f2f0 100644 --- a/src/levanter/data/text.py +++ b/src/levanter/data/text.py @@ -68,10 +68,11 @@ def __init__( dataset: ShardableDataset[np.ndarray], QPos: Axis, KPos: Axis, + mask_token_id: int, mask_prob: float = 0.15, noise_prob: float = 0.1, key: Optional[PRNGKeyArray] = None, - ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, + # ignore_index: Optional[int] = DEFAULT_IGNORE_INDEX, ): self.dataset = dataset self.QPos = QPos @@ -79,14 +80,16 @@ def __init__( self.mask_prob = mask_prob self.noise_prob = noise_prob self.key = key - self.ignore_id = ignore_index if ignore_index is not None else DEFAULT_IGNORE_INDEX + self.mask_token_id = mask_token_id if self.mask_prob > 0.0 and self.key is None: raise ValueError("must provide key if mask_prob > 0.0") def shard(self, shard_id: int, num_shards: int) -> "MaskedLmDataset": return MaskedLmDataset( - self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, self.mask_prob, self.noise_prob, self.key, self.ignore_id + self.dataset.shard(shard_id, num_shards), self.QPos, self.KPos, + self.mask_token_id, + self.mask_prob, self.noise_prob, self.key ) def __iter__(self) -> Iterator[MaskedLmExample]: @@ -105,13 +108,13 @@ def _create_mlm_example(tokens, key): mask = jax.random.bernoulli(this_key, self.mask_prob, mask_shape) rand = jax.random.uniform(this_key, mask_shape) - mask_token = jnp.where(rand < 0.8, self.ignore_id, tokens_array) + mask_token = jnp.where(rand < 0.8, self.mask_token_id, tokens_array) random_tokens = jax.random.randint(this_key, mask_shape, 0, tokens_array.max() + 1) mask_token = jnp.where((rand >= 0.8) & (rand < 0.8 + self.noise_prob), random_tokens, mask_token) masked_tokens = jnp.where(mask, mask_token, tokens_array) - # Set targets to the original tokens where mask is True, otherwise set to ignore_id - targets = jnp.where(mask, tokens_array, self.ignore_id) + # Set targets to the original tokens where mask is True, otherwise set to mask_token_id + targets = jnp.where(mask, tokens_array, self.mask_token_id) masked_tokens_named = hax.named(masked_tokens, self.QPos) targets_named = hax.named(targets, self.QPos) @@ -119,13 +122,13 @@ def _create_mlm_example(tokens, key): attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, 
self.KPos)) - example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) + example = MaskedLmExample.masked_lm(tokens=masked_tokens_named, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask) else: targets_named = hax.named(targets, self.QPos) attn_mask_shape = (tokens_array.shape[0], tokens_array.shape[0]) attn_mask = hax.named(jnp.ones(attn_mask_shape, dtype=jnp.bool_), (self.QPos, self.KPos)) - example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, ignore_id=self.ignore_id, attn_mask=attn_mask) + example = MaskedLmExample.masked_lm(tokens=tokens, targets=targets_named, mask_token_id=self.mask_token_id, attn_mask=attn_mask) return example @@ -900,4 +903,4 @@ def build_caches( @property def sources(self) -> dict[str, LMDatasetSourceConfig]: - return self.configs + return self.configs \ No newline at end of file diff --git a/src/levanter/main/train_mlm.py b/src/levanter/main/train_mlm.py index 80e941d5a..435abe5bf 100644 --- a/src/levanter/main/train_mlm.py +++ b/src/levanter/main/train_mlm.py @@ -7,6 +7,7 @@ from dataclasses import dataclass, field from typing import Optional, Union +import jax import jax.random as jrandom import haliax as hax @@ -86,7 +87,7 @@ def main(config: TrainMlmConfig): # 1. Sets the device mesh # 2. Sets the axis mapping (for fsdp) # 3. Sets the global metrics tracker - with Trainer(config.trainer, optimizer) as trainer: + with Trainer(config.trainer, optimizer) as trainer, jax.disable_jit(True): # randomness in jax is tightly controlled by "keys" which are the states of the random number generators # this makes deterministic training pretty easy seed = config.trainer.seed @@ -108,8 +109,11 @@ def main(config: TrainMlmConfig): KeyPos = config.model.KeyPos tagged_eval_datasets = config.data.tagged_eval_sets(Pos.size) + mask_id = tokenizer.mask_token_id train_dataset = MaskedLmDataset( - config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id + config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, + mask_token_id=mask_id, + mask_prob=config.mlm_prob, key=data_key, #ignore_index=config.data.ignore_token_id ) # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to @@ -150,7 +154,7 @@ def main(config: TrainMlmConfig): logger.warning("No evaluation datasets provided.") else: masked_datasets = [ - (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, ignore_index=config.data.ignore_token_id), tags) + (MaskedLmDataset(ds, Pos, KeyPos, mask_prob=config.mlm_prob, key=data_key, mask_token_id=mask_id), tags) for ds, tags in tagged_eval_datasets ] max_eval_examples_per_ds = config.trainer.max_eval_batches @@ -198,4 +202,4 @@ def compute_log_probs(model, example): trainer.train(state, train_loader) if __name__ == "__main__": - levanter.config.main(main)() + levanter.config.main(main)() \ No newline at end of file diff --git a/src/levanter/models/lm_model.py b/src/levanter/models/lm_model.py index c36e0e622..01e1252a8 100644 --- a/src/levanter/models/lm_model.py +++ b/src/levanter/models/lm_model.py @@ -24,7 +24,7 @@ class MaskedLmExample(eqx.Module): @staticmethod def masked_lm( - tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, ignore_id: Optional[int] = None + tokens: hax.NamedArray, targets: hax.NamedArray, attn_mask: hax.NamedArray, mask_token_id: Optional[int] = 
None ) -> "MaskedLmExample": if tokens.ndim != 1: raise ValueError("tokens must be a 1D array") @@ -40,8 +40,8 @@ def masked_lm( mask = tokens.array != targets.array loss_mask = hax.named(mask.astype(jnp.float32), Pos) - if ignore_id is not None: - ignore_mask = targets.array != ignore_id + if mask_token_id is not None: + ignore_mask = targets.array != mask_token_id loss_mask = loss_mask * hax.named(ignore_mask.astype(jnp.float32), Pos) return MaskedLmExample(tokens=tokens, targets=targets, loss_mask=loss_mask, attn_mask=attn_mask) @@ -156,4 +156,4 @@ def compute_loss( @property def vocab_size(self) -> int: - return self.Vocab.size + return self.Vocab.size \ No newline at end of file diff --git a/src/levanter/models/roberta.py b/src/levanter/models/roberta.py index f51771eff..ed23708f7 100644 --- a/src/levanter/models/roberta.py +++ b/src/levanter/models/roberta.py @@ -574,7 +574,7 @@ def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0 return incremental_indices + self.padding_idx def create_position_ids_from_inputs_embeds(self, input_axes, PosInput): - position_ids = hax.arange(axis = PosInput, start = self.padding_idx + 1, dtype=jnp.int32) + position_ids = hax.arange(axis = PosInput, start = 0, dtype=jnp.int32) return hax.broadcast_to(position_ids, input_axes) @named_call
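
Note on the masking objective implemented above: MaskedLmDataset (src/levanter/data/text.py) applies the standard dynamic BERT/RoBERTa corruption. Roughly mask_prob of the positions are selected each time an example is drawn; of those, about 80% become the tokenizer's mask token, about noise_prob become a random vocabulary id, and the remainder keep their original token, while the loss is taken only on the selected positions. The snippet below restates that rule as a minimal, self-contained JAX sketch for readers who want to check the behaviour outside Levanter; the helper name mask_tokens, the per-draw key splitting, and the example ids (50264/50265 assume the roberta-base tokenizer) are illustrative assumptions, not part of this patch. One deliberate difference: the sketch splits the PRNG key for each random draw, whereas _create_mlm_example reuses this_key for the Bernoulli selection, the uniform draw, and the random-token draw.

    import jax
    import jax.numpy as jnp

    def mask_tokens(tokens, key, mask_token_id, vocab_size, mask_prob=0.15, noise_prob=0.1):
        k_select, k_decide, k_rand = jax.random.split(key, 3)
        selected = jax.random.bernoulli(k_select, mask_prob, tokens.shape)  # positions in the MLM objective
        r = jax.random.uniform(k_decide, tokens.shape)
        random_tokens = jax.random.randint(k_rand, tokens.shape, 0, vocab_size)
        corrupted = jnp.where(r < 0.8, mask_token_id, tokens)  # ~80% of selected: mask token
        corrupted = jnp.where((r >= 0.8) & (r < 0.8 + noise_prob), random_tokens, corrupted)  # ~10%: random token
        inputs = jnp.where(selected, corrupted, tokens)  # remaining ~10% of selected stay unchanged
        targets = jnp.where(selected, tokens, mask_token_id)  # non-selected positions are filled, so no loss is taken on them
        return inputs, targets

    # e.g. with a toy sequence:
    inputs, targets = mask_tokens(jnp.arange(16), jax.random.PRNGKey(0), mask_token_id=50264, vocab_size=50265)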