Add epochs to levanter #768

Merged on Nov 7, 2024 (38 commits). Changes shown below are from 20 of the 38 commits.

Commits (38)
All commits are by ahmeda14960.

234a945  wip epochs (Oct 16, 2024)
f0b1eaa  fix (Oct 16, 2024)
020a1b2  add epoch flag, sanity check tulu one epoch (Oct 16, 2024)
50500b9  epochs work (Oct 16, 2024)
49afb5d  minor fix (Oct 16, 2024)
c2ed3ee  fix ci (Oct 16, 2024)
667a5a3  fix ci (Oct 16, 2024)
37e77fb  fix config file (Oct 17, 2024)
7c195ba  add suggested fix from david (Oct 18, 2024)
e71ed16  Merge remote-tracking branch 'origin/main' into sft (Oct 23, 2024)
54a6007  restore toml (Oct 23, 2024)
e2646d6  Update src/levanter/callbacks.py (Oct 23, 2024)
fd18cae  refactor (Oct 23, 2024)
1706803  add suggested fix from david (Oct 23, 2024)
f0ca163  update for v4 so we don't crash (Oct 23, 2024)
c971ebf  remove changes that break epochs (Oct 23, 2024)
4733f3b  final fixes (Oct 23, 2024)
e82eec2  final fixes (Oct 24, 2024)
08fd427  substatial changes to save on epochs w callback (Oct 24, 2024)
18a5352  epoch tracking still broken (Oct 24, 2024)
f1ef2c7  Merge remote-tracking branch 'origin/main' into sft (Oct 25, 2024)
c38b076  WIP (Oct 25, 2024)
7331774  update epochs to save latest checkpoints (Oct 28, 2024)
aa47d4e  Update src/levanter/checkpoint.py (Oct 28, 2024)
0148cd0  update tulu config to match olmo sft (Oct 28, 2024)
a7459e0  Merge remote-tracking branch 'origin/sft' into sft (Oct 28, 2024)
dde75ac  Merge remote-tracking branch 'origin/main' into sft (Oct 28, 2024)
5343096  pre commit (Oct 28, 2024)
fd39828  fix sft bug caused by exemplar (Oct 29, 2024)
313a3f4  add actual sft file (Oct 29, 2024)
b3718c1  precommit (Oct 29, 2024)
5f36eb8  sft working w levanter chkpt (Oct 29, 2024)
f5533d6  add back option for hf models on sft (Oct 29, 2024)
91fc5df  WIP for david (Oct 30, 2024)
ba682ca  debug epochs (Oct 31, 2024)
812accb  load data from marin sources (Nov 6, 2024)
2d7170c  merge main (Nov 6, 2024)
caf0a38  merge main (Nov 7, 2024)
39 changes: 39 additions & 0 deletions config/llama_7b_tulu.yaml
@@ -0,0 +1,39 @@
data:
  train_urls:
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-000.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-001.jsonl.gz"
    - "gs://marin-us-central2/documents/instruct/tulu_v2_mix/text/tulu-v2-sft-mixture-002.jsonl.gz"
  cache_dir: "gs://marin-us-central2/tokenized/OLMo-1B/tuluv2_sft/"
  tokenizer: "allenai/OLMo-1B"
model: # 7B class model
  type: llama
  seq_len: 4096
  hidden_dim: 4096
  intermediate_dim: 11008
  num_layers: 32
  num_heads: 32
  num_kv_heads: 32
  use_flash_attention: True
  flash_attention_block_size: 512
  use_bias: false
  use_layer_norm_weight: false
trainer:
  tracker:
    type: wandb
    project: "marin"
    tags: ["dolma", "olmo", "llama"]

  mp: p=f32,c=bfloat16
  train_batch_size: 256
  num_train_steps: 750000  # 3,000,000,000,000 / 4,000,000 = 750,000
  steps_per_eval: 1000
  tensor_parallel_axes: ["mlp", "heads"]
  fsdp_axis: "embed"
  batch_axis: "batch"
optimizer:
  learning_rate: 4E-4
  weight_decay: 0.1
  min_lr_ratio: 0.1
  warmup: 5000

epoch: 0
4 changes: 3 additions & 1 deletion examples/alpaca/alpaca.py
@@ -162,11 +162,13 @@ def _prepare_example(ex: dict) -> LmExample:
         # mask out padding and anything before the start of the target
         Pos = input_ids.resolve_axis("position")
         if config.mask_inputs:
-            loss_mask = hax.arange(Pos) >= ex["source_lens"]
+            loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1  # should be minus 1?

             # don't predict the padding
             targets = hax.roll(input_ids, -1, Pos)
             loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
+            # to not predict EOS token since we don't have target!
+            loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
         else:
             loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
         lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
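For readers following the loss-mask change above, here is a minimal standalone sketch of the same masking arithmetic using plain NumPy in place of Haliax. The token ids, prompt length, and pad id are made-up values for illustration; the last-position mask mirrors the effect of the `one_hot(-1, Pos)` term in the diff.

```python
import numpy as np

# Hypothetical example: 8-token sequence, prompt ("source") is the first 4 tokens,
# 2 = EOS, 0 = pad. Mirrors the masking logic in _prepare_example, but with NumPy.
input_ids = np.array([5, 6, 7, 8, 9, 10, 2, 0])
source_len = 4
pad_token_id = 0
Pos = len(input_ids)

# compute loss only at positions >= source_len - 1 (the loss at position i predicts token i + 1)
loss_mask = np.arange(Pos) >= source_len - 1

# don't predict the padding: the target at position i is input_ids[i + 1]
targets = np.roll(input_ids, -1)
loss_mask = loss_mask & (targets != pad_token_id)

# never compute loss at the final position, since it has no real target after the roll
loss_mask[-1] = False

print(loss_mask)  # [False False False  True  True  True False False]
```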
50 changes: 48 additions & 2 deletions src/levanter/callbacks.py
@@ -8,6 +8,7 @@
 import threading
 import time
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from datetime import timedelta
 from typing import Callable, Optional

@@ -17,7 +18,7 @@
 from tqdm_loggable.auto import tqdm

 import levanter.tracker
-from levanter.data import DataLoader
+from levanter.data import DataLoader, AsyncDataset
 from levanter.logging import save_xla_dumps_to_wandb
 from levanter.tracker.helpers import log_optimizer_hyperparams
 from levanter.tracker.wandb import WandbConfig
@@ -26,9 +27,54 @@
 from levanter.utils.jax_utils import barrier_sync, jnp_to_python
 from levanter.visualization import compute_and_visualize_log_probs as viz_probs


 logger = pylogging.getLogger(__name__)

+
+def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size, max_epochs: Optional[int] = None):
+    total_tokens = None
+
+    def log_epoch(step_info: StepInfo):
+        nonlocal total_tokens
+        if total_tokens is None:
+            if not total_tokens_future.done():
+                if step_info.step % 1000 == 0:
+                    logger.info("Dataset not finished. Can't compute epochs.")
+                return  # We don't have the total tokens yet, so we can't calculate epoch
+            total_tokens = total_tokens_future.result()
+
+        # Get the total processed tokens from the metrics logged by log_performance_stats
+        processed_tokens = tokens_per_example * batch_size * step_info.step
+
+        # If we're doing multiple epochs, adjust the denominator
+        total_tokens_for_epochs = total_tokens * max_epochs if max_epochs else total_tokens
+        current_epoch = processed_tokens / total_tokens_for_epochs
+
+        levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step)
+
+    return log_epoch
+
+
+def get_total_dataset_tokens(ds: AsyncDataset, seq_length: int):
+
+    def log_length():
+        # If ds.async_len() is the only option, run it in an event loop inside the thread
+        import asyncio
+
+        async def compute_length():
+            length = await ds.dataset.async_len()
+            return length
+
+        # Run the async function synchronously in this thread
+        length = asyncio.run(compute_length())
+        total_tokens = length * seq_length
+        levanter.tracker.log_summary({"dataset/total_tokens": total_tokens})
+        return total_tokens
+
+    # Create a ThreadPoolExecutor with a single worker thread
+    executor = ThreadPoolExecutor(max_workers=1)
+    # Submit the log_length function to be executed in a separate thread
+    future = executor.submit(log_length)
+    return future
+
+
 def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
     total_loss = 0.0
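As a sanity check on the arithmetic in `log_epoch_progress`, here is a minimal standalone sketch with hypothetical numbers (a 1B-token dataset, seq_len 4096, batch size 256, 3 epochs; none of these values come from the PR). It mirrors the hook's math without importing levanter, using a plain `Future` as a stand-in for the one returned by `get_total_dataset_tokens`.

```python
from concurrent.futures import Future

# Hypothetical dataset/run sizes (not taken from the PR).
total_tokens = 1_000_000_000   # what get_total_dataset_tokens would report for the dataset
tokens_per_example = 4096      # Pos.size in train_lm.py
batch_size = 256
max_epochs = 3

# Stand-in for the future returned by get_total_dataset_tokens.
total_tokens_future = Future()
total_tokens_future.set_result(total_tokens)


def current_epoch(step: int) -> float:
    # Same arithmetic as log_epoch: tokens processed so far over the whole multi-epoch budget.
    processed_tokens = tokens_per_example * batch_size * step
    total_tokens_for_epochs = total_tokens_future.result() * max_epochs
    return processed_tokens / total_tokens_for_epochs


for step in (0, 500, 1000, 2000):
    print(step, current_epoch(step))
# Prints roughly 0.0, 0.17, 0.35, 0.70. As written, the metric is the fraction of the
# entire max_epochs budget consumed, not the index of the current epoch.
```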
34 changes: 34 additions & 0 deletions src/levanter/checkpoint.py
@@ -27,6 +27,7 @@

 from levanter.tensorstore_serialization import tree_deserialize_leaves_tensorstore, tree_serialize_leaves_tensorstore
 from levanter.types import FilterSpec
+# from levanter.trainer import StepInfo


 logger = logging.getLogger(__name__)
@@ -261,6 +262,39 @@ def _async_checkpoint_remover(self):
             self._do_rm_checkpoint(checkpoint)
             self._checkpoint_being_removed = None

+# In callbacks.py - Add a new callback that handles epoch checkpointing
+class EpochCheckpointer:
+    """
+    A separate checkpointing system that saves based on epochs.
+    Works alongside the regular step-based checkpointer without modifying core state.
+    """
+    def __init__(self,
+                 checkpointer: Checkpointer,
+                 every_n_epochs: int = 1,
+                 total_dataset_size: Optional[int] = None,
+                 batch_size: int = 1):
+        self.checkpointer = checkpointer
+        self.every_n_epochs = every_n_epochs
+        self.total_dataset_size = total_dataset_size
+        self.batch_size = batch_size
+        self._last_saved_epoch = -1
+
+    def __call__(self, step_info):
+        if self.total_dataset_size is None:
+            return  # Can't calculate epochs without dataset size
+
+        # Calculate current epoch from steps without modifying StepInfo
+        current_epoch = (step_info.step * self.batch_size) // self.total_dataset_size
+
+        # Only save if we've moved to a new epoch and it matches our interval
+        if (current_epoch > self._last_saved_epoch and
+                current_epoch % self.every_n_epochs == 0):
+            # Use existing checkpointer's save_checkpoint method
+            self.checkpointer.save_checkpoint(
+                step_info,
+                f"epoch-{current_epoch}",
+            )
+            self._last_saved_epoch = current_epoch
+

 def save_checkpoint(
     tree: M,
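To make the epoch-boundary trigger in `EpochCheckpointer.__call__` concrete, here is a minimal standalone simulation with a stub checkpointer and made-up sizes (batch_size 8, total_dataset_size 100); it re-creates only the triggering arithmetic above and is not levanter API.

```python
from types import SimpleNamespace


class FakeCheckpointer:
    """Stub that records which epoch checkpoints would have been written."""
    def __init__(self):
        self.saved = []

    def save_checkpoint(self, step_info, path):
        self.saved.append((step_info.step, path))


# Re-create the trigger logic from EpochCheckpointer above, with made-up sizes.
checkpointer = FakeCheckpointer()
every_n_epochs = 1
total_dataset_size = 100   # measured in whatever unit is passed in at construction time
batch_size = 8
last_saved_epoch = -1

for step in range(30):
    step_info = SimpleNamespace(step=step)  # stand-in for levanter's StepInfo; only .step is used
    current_epoch = (step_info.step * batch_size) // total_dataset_size
    if current_epoch > last_saved_epoch and current_epoch % every_n_epochs == 0:
        checkpointer.save_checkpoint(step_info, f"epoch-{current_epoch}")
        last_saved_epoch = current_epoch

# [(0, 'epoch-0'), (13, 'epoch-1'), (25, 'epoch-2')]: note that epoch-0 fires at step 0
# because _last_saved_epoch starts at -1 in the class above.
print(checkpointer.saved)
```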
89 changes: 88 additions & 1 deletion src/levanter/data/text.py
@@ -64,6 +64,82 @@
 DEFAULT_IGNORE_INDEX = -100  # Mirrors pytorch's default ignore index


+class EpochDataset(AsyncDataset[T_co]):
+    """
+    A dataset that wraps another dataset, providing infinite epochs by recycling indices.
+    If `max_epochs` is specified, it limits the number of cycles before raising StopIteration.
+
+    :param dataset: The dataset to wrap.
+    :param max_epochs: The maximum number of epochs to cycle through. If None, cycle indefinitely.
+    """
+    def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
+        self.dataset = dataset
+        self.max_epochs = max_epochs
+
+    async def async_len(self) -> int:
+        if self.max_epochs is None:
+            raise ValueError("Cannot determine length of an infinite dataset without max_epochs.")
+        # Return the total number of samples: max_epochs * length of the dataset
+        return self.max_epochs * await self.dataset.async_len()
+
+    async def final_length_is_known(self) -> bool:
+        return await self.dataset.final_length_is_known()
+
+    def is_finite(self) -> bool:
+        # EpochDataset can be finite if max_epochs is set.
+        return self.max_epochs is not None
+
+    async def current_len(self) -> Optional[int]:
+        # If max_epochs is None, the dataset is effectively infinite.
+        if self.max_epochs is None:
+            return None
+
+        # If the final length of the dataset is not known, return the current length of the underlying dataset.
+        if not await self.dataset.final_length_is_known():
+            return await self.dataset.current_len()
+
+        # If the final length is known, return the max_epochs * async_len of the dataset.
+        return self.max_epochs * await self.dataset.async_len()
+
+    async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
+        # Use self.wait_until_len_at_least to ensure we have enough data for the batch.
+        max_index = max(indices)
+        ds_len = await self.dataset.wait_until_len_at_least(max_index + 1)
+
+        # Determine the epoch based on the largest index
+        epoch = max_index // ds_len
+
+        # If max_epochs is specified, raise an error if the epoch exceeds the allowed number of epochs
+        if self.max_epochs is not None and epoch >= self.max_epochs:
+            raise StopIteration(f"Reached maximum number of epochs: epoch {epoch} exceeds the maximum allowed {self.max_epochs}")
+
+        # Wrap the indices within the bounds of the dataset length
+        wrapped_indices = [idx % ds_len for idx in indices]
+
+        # Delegate to the underlying dataset's get_batch
+        return await self.dataset.get_batch(wrapped_indices)
+
+    async def wait_until_len_at_least(self, length: int) -> int:
+        """
+        Returns the length of the dataset once it is at least `length` or if the dataset has a known (finished) length.
+        If the dataset's actual length is less than `length`, it returns the minimum of async_len and the current length.
+        """
+        # Wait until the underlying dataset's length is at least `length`
+        if not self.is_finite():
+            return length
+
+        if await self.dataset.final_length_is_known():
+            base_length = await self.dataset.async_len()
+        else:
+            base_length = await self.dataset.wait_until_len_at_least(length)
+
+        if base_length < length:
+            # hit epoch boundary
+            assert self.max_epochs is not None
+            return self.max_epochs * base_length
+
+        return base_length
+
+
 class TokenSeqDataset(AsyncDataset[np.ndarray]):
     """
     A dataset that yields sequences of tokens of fixed length from an underlying TreeCache.
@@ -648,9 +724,20 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
     cache_dir: Optional[str] = "cache/"

     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
+        self,
+        seq_len: int,
+        monitors: Union[bool, List[MetricsMonitor]] = True,
+        *,
+        key: Optional[PRNGKeyArray] = None,
+        epochs: int = 0,
     ) -> AsyncDataset[np.ndarray]:
+
         ds = self.token_seq_dataset("train", seq_len, monitors)
+        if epochs:
+            logger.info("Wrapping dataset in epoch dataset")
+            ds = EpochDataset(ds, max_epochs=epochs)
+
+        # add epoch flag here.
         if ds is None:
             raise ValueError("No training set!")
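A minimal sketch of the index-recycling idea behind `EpochDataset.get_batch`, using a plain Python list as a stand-in for the underlying async dataset (the sizes and values are made up):

```python
# Stand-in "dataset": 5 examples, so one epoch = 5 indices.
base = ["a", "b", "c", "d", "e"]
ds_len = len(base)
max_epochs = 2


def get_batch(indices):
    # Same arithmetic as EpochDataset.get_batch: the largest requested index decides
    # which epoch we are in, and every index is wrapped back into range.
    max_index = max(indices)
    epoch = max_index // ds_len
    if epoch >= max_epochs:
        raise StopIteration(f"Reached maximum number of epochs: {epoch} >= {max_epochs}")
    wrapped = [i % ds_len for i in indices]
    return [base[i] for i in wrapped]


print(get_batch([3, 4, 5, 6]))  # ['d', 'e', 'a', 'b'] -- indices 5 and 6 wrap into epoch 1
print(get_batch([8, 9]))        # ['d', 'e'] -- still within epoch 1 (max index 9 < 10)
# get_batch([10]) would raise StopIteration: index 10 would start epoch 2.
```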
28 changes: 26 additions & 2 deletions src/levanter/main/train_lm.py
@@ -14,7 +14,7 @@

 import levanter
 from levanter import callbacks
-from levanter.checkpoint import load_checkpoint
+from levanter.checkpoint import EpochCheckpointer, load_checkpoint
 from levanter.compat.hf_checkpoints import HFCompatConfig, save_hf_checkpoint_callback
 from levanter.data.text import CausalLmDataset, LMDatasetConfig, LMMixtureDatasetConfig, LMSupervisedDatasetConfig
 from levanter.models.gpt2 import Gpt2Config
@@ -54,6 +54,7 @@ class TrainLmConfig:
     data_seed: Optional[int] = None  # if provided, will override the data seed from the trainer
     initialize_from_checkpoint_path: Optional[str] = None
     # if provided, will initialize from this checkpoint, used for llama style data mixture
+    epoch: int = 0


 def main(config: TrainLmConfig):
@@ -117,10 +118,32 @@ def main(config: TrainLmConfig):

         # TODO: fix this
         tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
+        # TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)
+
         train_dataset = CausalLmDataset(
-            config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id
+            config.data.train_set(Pos.size, key=data_key, epochs=config.epoch),
+            Pos,
+            KeyPos,
+            ignore_index=config.data.ignore_token_id,
         )
+
+        # add epoch logging if epochs specified
+        if config.epoch > 0:
+            total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
+            trainer.add_hook(
+                callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size, max_epochs=config.epoch), every=1
+            )
+
+            # Add epoch checkpoint callback
+            epoch_checkpointer = EpochCheckpointer(
+                checkpointer=trainer.config.checkpointer.create(trainer.run_id),
+                every_n_epochs=1,  # Or configure as needed
+                total_dataset_size=total_tokens_future.result(),
+                batch_size=trainer.config.train_batch_size,
+            )
+            trainer.add_hook(epoch_checkpointer, every=1)
+
         # to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
         # For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
         # tokens: gpt-2 has 50257, for example. So we round up.
@@ -236,6 +259,7 @@ def compute_log_probs(model, example):

         ## OK, actually run training!
         trainer.train(state, train_loader)
+
         # checkpointer.on_step(last_step, force=True)


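For readers wiring this up, a small back-of-the-envelope sketch of how steps, tokens, and the epoch metric relate under this setup. Only `seq_len` and `train_batch_size` come from the Tulu config above; the number of cached sequences is a made-up placeholder.

```python
# Hypothetical sizes -- only seq_len and train_batch_size come from config/llama_7b_tulu.yaml.
num_sequences = 500_000       # assumed number of packed training sequences in the cache
seq_len = 4096
train_batch_size = 256
epochs = 3                    # e.g. TrainLmConfig with epoch: 3

total_tokens = num_sequences * seq_len               # what get_total_dataset_tokens reports
tokens_per_step = seq_len * train_batch_size
steps_per_epoch = total_tokens / tokens_per_step     # == num_sequences / train_batch_size
num_train_steps = int(steps_per_epoch * epochs)      # steps needed to finish `epochs` passes

print(f"{total_tokens=:,} {tokens_per_step=:,}")
print(f"{steps_per_epoch=:.1f} {num_train_steps=}")  # about 1953.1 steps/epoch, 5859 steps for 3 epochs
```

Since `EpochDataset.get_batch` raises `StopIteration` once the requested indices pass `max_epochs`, `num_train_steps` would need to stay within this budget under these assumptions.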
1 change: 0 additions & 1 deletion src/levanter/trainer.py
@@ -376,7 +376,6 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi
         while int(state.step) < self.num_train_steps:
             with capture_time() as loading_time:
                 example = next(iter_data)
-
             info = self.train_step(state, example)
             state = info.state
