Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add epochs to levanter #768

Merged
merged 38 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
234a945
wip epochs
ahmeda14960 Oct 16, 2024
f0b1eaa
fix
ahmeda14960 Oct 16, 2024
020a1b2
add epoch flag, sanity check tulu one epoch
ahmeda14960 Oct 16, 2024
50500b9
epochs work
ahmeda14960 Oct 16, 2024
49afb5d
minor fix
ahmeda14960 Oct 16, 2024
c2ed3ee
fix ci
ahmeda14960 Oct 16, 2024
667a5a3
fix ci
ahmeda14960 Oct 16, 2024
37e77fb
fix config file
ahmeda14960 Oct 17, 2024
7c195ba
add suggested fix from david
ahmeda14960 Oct 18, 2024
e71ed16
Merge remote-tracking branch 'origin/main' into sft
ahmeda14960 Oct 23, 2024
54a6007
restore toml
ahmeda14960 Oct 23, 2024
e2646d6
Update src/levanter/callbacks.py
ahmeda14960 Oct 23, 2024
fd18cae
refactor
ahmeda14960 Oct 23, 2024
1706803
add suggested fix from david
ahmeda14960 Oct 23, 2024
f0ca163
update for v4 so we don't crash
ahmeda14960 Oct 23, 2024
c971ebf
remove changes that break epochs
ahmeda14960 Oct 23, 2024
4733f3b
final fixes
ahmeda14960 Oct 23, 2024
e82eec2
final fixes
ahmeda14960 Oct 24, 2024
08fd427
substatial changes to save on epochs w callback
ahmeda14960 Oct 24, 2024
18a5352
epoch tracking still broken
ahmeda14960 Oct 24, 2024
f1ef2c7
Merge remote-tracking branch 'origin/main' into sft
ahmeda14960 Oct 25, 2024
c38b076
WIP
ahmeda14960 Oct 25, 2024
7331774
update epochs to save latest checkpoints
ahmeda14960 Oct 28, 2024
aa47d4e
Update src/levanter/checkpoint.py
ahmeda14960 Oct 28, 2024
0148cd0
update tulu config to match olmo sft
ahmeda14960 Oct 28, 2024
a7459e0
Merge remote-tracking branch 'origin/sft' into sft
ahmeda14960 Oct 28, 2024
dde75ac
Merge remote-tracking branch 'origin/main' into sft
ahmeda14960 Oct 28, 2024
5343096
pre commit
ahmeda14960 Oct 28, 2024
fd39828
fix sft bug caused by exemplar
ahmeda14960 Oct 29, 2024
313a3f4
add actual sft file
ahmeda14960 Oct 29, 2024
b3718c1
precommit
ahmeda14960 Oct 29, 2024
5f36eb8
sft working w levanter chkpt
ahmeda14960 Oct 29, 2024
f5533d6
add back option for hf models on sft
ahmeda14960 Oct 29, 2024
91fc5df
WIP for david
ahmeda14960 Oct 30, 2024
ba682ca
debug epochs
ahmeda14960 Oct 31, 2024
812accb
load data from marin sources
ahmeda14960 Nov 6, 2024
2d7170c
merge main
ahmeda14960 Nov 6, 2024
caf0a38
merge main
ahmeda14960 Nov 7, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions config/llama_7b_with_olmo_config.yaml
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ trainer:
project: "marin"
tags: ["dolma", "olmo", "llama"]

checkpointer:
keep:
- every: 250

mp: p=f32,c=bfloat16
train_batch_size: 2048
num_train_steps: 750000 # 3,000,000,000,000 / 4,000,000 = 750,000
Expand All @@ -27,3 +31,5 @@ optimizer:
weight_decay: 0.1
min_lr_ratio: 0.1
warmup: 0.01

data_shuffle: true
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 3 additions & 1 deletion examples/alpaca/alpaca.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,13 @@ def _prepare_example(ex: dict) -> LmExample:
# mask out padding and anything before the start of the target
Pos = input_ids.resolve_axis("position")
if config.mask_inputs:
loss_mask = hax.arange(Pos) >= ex["source_lens"]
loss_mask = hax.arange(Pos) >= ex["source_lens"] - 1 # should be minus 1?

# don't predict the padding
targets = hax.roll(input_ids, -1, Pos)
loss_mask = loss_mask & (targets != tokenizer.pad_token_id)
# to not predict EOS token since we don't have target!
loss_mask = loss_mask & (1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.bool_))
else:
loss_mask = 1 - hax.nn.one_hot(-1, Pos, dtype=jax.numpy.float32)
lm_ex = LmExample.causal(input_ids, loss_mask=loss_mask)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ dependencies = [
"pydantic<3",
"rich~=13.0",
"filelock~=3.13",
# "ai2-olmo",
"async-lru~=2.0",
"tqdm-loggable>=0.2",
"deepdiff"
Expand Down
44 changes: 44 additions & 0 deletions src/levanter/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,55 @@
from levanter.utils import flop_utils
from levanter.utils.jax_utils import barrier_sync, jnp_to_python
from levanter.visualization import compute_and_visualize_log_probs as viz_probs
from levanter.data.text import TokenSeqEpochDataset
from concurrent.futures import ThreadPoolExecutor



logger = pylogging.getLogger(__name__)


def log_epoch_progress(total_tokens_future, tokens_per_example, batch_size):
total_tokens = None

def log_epoch(step_info: StepInfo):
nonlocal total_tokens
if total_tokens is None:
if not total_tokens_future.done():
return # We don't have the total tokens yet, so we can't calculate epoch
dlwh marked this conversation as resolved.
Show resolved Hide resolved
total_tokens = total_tokens_future.result()

# Get the total processed tokens from the metrics logged by log_performance_stats
processed_tokens = tokens_per_example * batch_size * step_info.step
if processed_tokens is None:
return # No token count available yet

current_epoch = processed_tokens / total_tokens
levanter.tracker.log_metrics({"train/current_epoch": current_epoch}, step=step_info.step)

return log_epoch

def get_total_dataset_tokens(ds: TokenSeqEpochDataset, seq_length: int):
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
def log_length():
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
# If ds.async_len() is the only option, run it in an event loop inside the thread
import asyncio

async def compute_length():
length = await ds.async_len()
return length

# Run the async function synchronously in this thread
length = asyncio.run(compute_length())
total_tokens = length * seq_length
levanter.tracker.log_summary({"dataset/total_tokens": total_tokens})
return total_tokens

# Create a ThreadPoolExecutor with a single worker thread
executor = ThreadPoolExecutor(max_workers=1)
# Submit the log_length function to be executed in a separate thread
future = executor.submit(log_length)
return future

def eval_loss_loop(loss_fn, model, dataset, max_batches: Optional[int] = None, name: Optional[str] = None):
total_loss = 0.0
total_load_time = 0.0
Expand Down
69 changes: 67 additions & 2 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,57 @@

DEFAULT_IGNORE_INDEX = -100 # Mirrors pytorch's default ignore index

class TokenSeqEpochDataset(AsyncDataset[np.ndarray]):
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
self.doc_cache = doc_cache
self.seq_len = seq_len
self._store: Optional[TreeStore] = None
self._cached_len: Optional[int] = None

async def async_len(self) -> int:
await self.doc_cache.finished()
token_arrays = await self._await_token_cache()
return token_arrays.data_size // self.seq_len

async def _await_token_cache(self) -> JaggedArrayStore:
if self._store is None:
self._store = await self.doc_cache.store_async()
return self._store.tree["input_ids"]

async def final_length_is_known(self) -> bool:
return await self.doc_cache.final_length_is_known()

def is_finite(self) -> bool:
return False # Now infinite due to epoch wrapping

async def current_len(self) -> Optional[int]:
store = await self._await_token_cache()
return store.data_size // self.seq_len

async def get_batch(self, indices: Sequence[int]) -> Sequence[T_co]:
token_arrays = await self._await_token_cache()
dataset_len = await self.async_len()

wrapped_indices = [idx % dataset_len for idx in indices]
offsets = np.array(wrapped_indices) * self.seq_len

with ts.Batch():
out = []
for offset in offsets:
out.append(token_arrays.data[offset : offset + self.seq_len].read())

out = await asyncio.gather(*out)
return out

async def wait_until_len_at_least(self, length: int) -> int:
# length is brutally slow to compute, so we cache it
if self._cached_len is not None:
return self._cached_len

# TODO: would be better to listen for cache updates
length = await super().wait_until_len_at_least(length)
self._cached_len = length
return length

class TokenSeqDataset(AsyncDataset[np.ndarray]):
"""
Expand Down Expand Up @@ -640,9 +691,15 @@ class LMDatasetConfig(LMDatasetSourceConfig, LMTaskConfig):
cache_dir: Optional[str] = "cache/"

def train_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray] = None, epochs: bool = False
) -> AsyncDataset[np.ndarray]:
ds = self.token_seq_dataset("train", seq_len, monitors)

if epochs:
ds = self.token_epoch_dataset("train", seq_len, monitors)
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
else:
ds = self.token_seq_dataset("train", seq_len, monitors)

# add epoch flag here.
if ds is None:
raise ValueError("No training set!")

Expand Down Expand Up @@ -693,6 +750,14 @@ def token_seq_dataset(
if cache is None:
return None
return TokenSeqDataset(cache, seq_len)

def token_epoch_dataset(
self, split: str, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True
) -> Optional[TokenSeqDataset]:
cache = self.build_or_load_cache(split, monitors=monitors)
if cache is None:
return None
return TokenSeqEpochDataset(cache, seq_len)

def build_or_load_cache(
self, split: str, monitors: Union[bool, List[MetricsMonitor]] = True, logger_name: Optional[str] = None
Expand Down
11 changes: 10 additions & 1 deletion src/levanter/main/train_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class TrainLmConfig:
data_seed: Optional[int] = None # if provided, will override the data seed from the trainer
initialize_from_checkpoint_path: Optional[str] = None
# if provided, will initialize from this checkpoint, used for llama style data mixture
epoch: bool = False # if true, will keep epoching over the dataset and track epochs
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved


def main(config: TrainLmConfig):
Expand Down Expand Up @@ -117,10 +118,17 @@ def main(config: TrainLmConfig):

# TODO: fix this
tagged_eval_datasets: list = config.data.tagged_eval_sets(Pos.size)
# TokenSeqDataset is config.data.train_set(Pos.size, key=data_key)

train_dataset = CausalLmDataset(
config.data.train_set(Pos.size, key=data_key), Pos, KeyPos, ignore_index=config.data.ignore_token_id
config.data.train_set(Pos.size, key=data_key, epochs=config.epoch), Pos, KeyPos, ignore_index=config.data.ignore_token_id
)

if config.epoch:
ahmeda14960 marked this conversation as resolved.
Show resolved Hide resolved
# add epoch logging
total_tokens_future = callbacks.get_total_dataset_tokens(train_dataset.dataset, config.model.seq_len)
trainer.add_hook(callbacks.log_epoch_progress(total_tokens_future, Pos.size, trainer.config.train_batch_size), every=1)

# to do partitioning, our dimensions have to be divisible by the size of the physical axes they're mapped to
# For most things, we just insist you specify the config right, but tokenizers often have strange numbers of
# tokens: gpt-2 has 50257, for example. So we round up.
Expand Down Expand Up @@ -236,6 +244,7 @@ def compute_log_probs(model, example):

## OK, actually run training!
trainer.train(state, train_loader)

# checkpointer.on_step(last_step, force=True)


Expand Down
1 change: 0 additions & 1 deletion src/levanter/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,6 @@ def training_steps(self, state: S, train_loader, run_hooks: bool = True) -> typi
while int(state.step) < self.num_train_steps:
with capture_time() as loading_time:
example = next(iter_data)

info = self.train_step(state, example)
state = info.state

Expand Down
Loading