Commit

fix internal_eval lengths
dlwh committed Nov 8, 2024
1 parent be80580 commit cf222d4
Showing 11 changed files with 41 additions and 12 deletions.
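
The change has two parts. First, supervised and chat SFT dataset construction (mk_supervised_dataset, mk_chat_sft_dataset, and _prepare_supervised_example) now takes the model's Pos axis explicitly, so every example is padded to the model's sequence length rather than a length inferred from the tokenizer; the supervised eval dataset built in train_lm.py gets the same treatment, which is evidently the internal_eval length fix the title refers to. Second, AsyncDataset gains an __init__ that initializes _min_known_len, a monotonically growing lower bound on the dataset's length used to short-circuit wait_until_len_at_least, so every AsyncDataset subclass now calls super().__init__().
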
5 changes: 3 additions & 2 deletions examples/sft/sft.py
@@ -80,6 +80,7 @@ def train(config: SFTConfig):
         raise ValueError("Must specify either --initialize_from_hf or --initialize_from")
     else:
         converter = None
+    model_config = config.model

     levanter.initialize(config)

@@ -106,10 +107,10 @@ def train(config: SFTConfig):
             input_role=config.input_role,
             output_role=config.output_role,
         )
-        train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
+        train_dataset = mk_chat_sft_dataset(chat_config, tokenizer, model_config.Pos)
     else:
         assert config.supervised_data is not None
-        train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
+        train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer, model_config.Pos)
     logger.info("Supervised dataset created")
     train_dataset = PermutationDataset(train_dataset, data_key)
1 change: 1 addition & 0 deletions src/levanter/data/audio.py
@@ -270,6 +270,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]):
     """

     def __init__(self, cache: TreeCache[AudioTextDict]):
+        super().__init__()
         self.cache = cache

     async def async_len(self) -> int:
13 changes: 12 additions & 1 deletion src/levanter/data/dataset.py
@@ -48,6 +48,9 @@ class AsyncDataset(DatasetBase[T_co]):
     * `current_len`: Returns the current length of the dataset. This may be None if no current length is known.
     """

+    def __init__(self):
+        self._min_known_len = 0
+
     @abc.abstractmethod
     async def async_len(self) -> int:
         raise NotImplementedError
@@ -95,7 +98,12 @@ async def wait_until_len_at_least(self, length: int) -> int:
         The default implementation is a naive busy-wait loop. You should override this method for more efficient
         implementations.
         """
-        return await naive_busy_wait_until_len_at_least(self, length)
+        if length <= self._min_known_len:
+            return self._min_known_len
+
+        res_len = await naive_busy_wait_until_len_at_least(self, length)
+        self._min_known_len = max(self._min_known_len, res_len)
+        return res_len

     def as_sync_dataset(self):
         return SyncifiedDataset(self)
@@ -206,6 +214,7 @@ def __getitem__(self, index: int) -> T_co:

 class AsyncifiedDataset(AsyncDataset[T_co]):
     def __init__(self, dataset: SyncDataset[T_co]):
+        super().__init__()
         self.dataset = dataset

     async def async_len(self) -> int:
@@ -239,6 +248,7 @@ class ListAsyncDataset(AsyncDataset[T]):
     """

     def __init__(self, data: list[T], is_complete: bool = False):
+        super().__init__()
         self.data = data
         self.is_complete = is_complete
         if not is_complete:
@@ -315,6 +325,7 @@ def __init__(
         *extra_args,
         **extra_kwargs,
     ):
+        super().__init__()
         self.dataset = dataset
         self.fn = fn
         self._extra_args = extra_args
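
Aside: the reason every __init__ in this commit gains a super().__init__() call is visible in the hunks above: _min_known_len must exist on an instance before wait_until_len_at_least is first awaited, and a subclass that skips super().__init__() would raise AttributeError on the fast-path check. A minimal self-contained sketch of the caching pattern (toy code, not levanter's actual classes; the slow path mirrors what naive_busy_wait_until_len_at_least presumably does):

    import asyncio

    class GrowingSequence:
        """Toy stand-in for AsyncDataset's new length cache."""

        def __init__(self):
            # The invariant the commit adds: a monotone lower bound on length.
            self._min_known_len = 0
            self._data = []

        async def current_len(self) -> int:
            return len(self._data)

        async def wait_until_len_at_least(self, length: int) -> int:
            # Fast path: a previously observed length answers immediately.
            if length <= self._min_known_len:
                return self._min_known_len
            # Slow path: poll until enough items exist.
            while (n := await self.current_len()) < length:
                await asyncio.sleep(0.01)
            self._min_known_len = max(self._min_known_len, n)
            return n

    async def demo():
        ds = GrowingSequence()

        async def producer():
            for i in range(5):
                await asyncio.sleep(0.02)
                ds._data.append(i)

        task = asyncio.create_task(producer())
        await ds.wait_until_len_at_least(3)  # polls until 3 items exist
        await ds.wait_until_len_at_least(2)  # returns instantly from the cache
        await task

    asyncio.run(demo())
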
1 change: 1 addition & 0 deletions src/levanter/data/mixture.py
@@ -53,6 +53,7 @@ def __init__(
         key: PRNGKeyArray | int,
         stop_strategy: str = StopStrategy.RESTART_STRATEGY,
     ):
+        super().__init__()
         self.weights = MixtureDataset._normalize_weights(weights)
         self.datasets = {name: dataset for name, dataset in datasets.items() if self.weights.get(name, 0) > 0}
         self.dataset_index = Index(self.datasets.keys())
2 changes: 2 additions & 0 deletions src/levanter/data/permutation.py
@@ -14,6 +14,7 @@ class PermutationDataset(AsyncDataset[T_co]):
     # TODO: add epoch reshuffling

     def __init__(self, dataset: AsyncDataset[T_co], key: jax.random.PRNGKey):
+        super().__init__()
         self.dataset = dataset
         self.key = key
         self._permutation: Optional[Permutation] = None
@@ -72,6 +73,7 @@ class EraShufflingDataset(AsyncDataset[T_co]):
     """

     def __init__(self, dataset: AsyncDataset[T_co], era_length: int, *, key: jax.random.PRNGKey):
+        super().__init__()
         self.dataset = dataset
         self.era_length = era_length
         self.key = key
24 changes: 16 additions & 8 deletions src/levanter/data/text.py
@@ -75,6 +75,7 @@ class EpochDataset(AsyncDataset[T_co]):
     """

     def __init__(self, dataset: AsyncDataset[T_co], max_epochs: Optional[int] = None):
+        super().__init__()
         self.dataset = dataset
         self.max_epochs = max_epochs

@@ -154,6 +155,7 @@ class TokenSeqDataset(AsyncDataset[np.ndarray]):
     """

     def __init__(self, doc_cache: TreeCache[dict], seq_len: int):
+        super().__init__()
         self.doc_cache = doc_cache
         self.seq_len = seq_len
         self._store: Optional[TreeStore] = None
@@ -687,7 +689,7 @@ def preprocess_supervised_example(
     }


-def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
+def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis) -> LmExample:
     """
     Prepare an example for training. This function converts the (cached) batch encoding into an LmExample.
@@ -699,11 +701,15 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
     """
     with local_cpu_mesh():
         # annoyingly, pad expects things to be batched so we have to prepend a batch axis
-        ex = tokenizer.pad({k: np.expand_dims(v, 0) for k, v in ex.items()}, return_tensors="np", padding="max_length")
+        ex = tokenizer.pad(
+            {k: np.expand_dims(v, 0) for k, v in ex.items()},
+            return_tensors="np",
+            padding="max_length",
+            max_length=Pos.size,
+        )
         ex = {k: v[0] for k, v in ex.items()}
-        input_ids = hax.named(ex["input_ids"], "position")
+        input_ids = hax.named(ex["input_ids"], Pos)
         # mask out padding and anything before the start of the target
-        Pos = input_ids.resolve_axis("position")
         loss_mask = hax.arange(Pos) >= ex["sources_len"] - 1

     # don't predict the padding
@@ -714,7 +720,7 @@ def _prepare_supervised_example(ex: dict, tokenizer: PreTrainedTokenizerBase) -> LmExample:
     return lm_ex


-def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
+def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis):
     import levanter.data

     # Choose data source based on config
@@ -746,7 +752,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrainedTokenizerBase):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

-    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))


 @dataclass
@@ -799,7 +805,9 @@ def preprocess_chat_example(batch, tokenizer: PreTrainedTokenizerBase) -> dict:
     }


-def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]:
+def mk_chat_sft_dataset(
+    config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase, Pos: hax.Axis
+) -> AsyncDataset[LmExample]:
     """Creates a dataset from JSONL files containing chat format data for SFT."""
     source = config.get_shard_source("train")
     if source is None:
@@ -824,7 +832,7 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedTokenizerBase) -> AsyncDataset[LmExample]:
         tokenizer.pad_token = tokenizer.eos_token

     # Reuse the supervised prepare function directly
-    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer, Pos))


 @dataclass
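
For intuition, here is roughly what the reworked _prepare_supervised_example computes, with a toy axis size in place of a real model's (values illustrative; the real code additionally masks out padding afterward):

    import haliax as hax

    Pos = hax.Axis("position", 8)  # the model's position axis; real sizes are e.g. 2048
    sources_len = 3                # number of prompt tokens, from the cached example

    # Tokens are padded/truncated to exactly Pos.size, and only positions from
    # the last prompt token onward contribute to the loss.
    loss_mask = hax.arange(Pos) >= sources_len - 1
    print(loss_mask.array)  # [False False  True  True  True  True  True  True]
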
1 change: 1 addition & 0 deletions src/levanter/eval.py
@@ -60,6 +60,7 @@ def tags(self):
     def __init__(
         self, datasets: Sequence[tuple[AsyncDataset[T], Sequence[str]]], max_examples_per_dataset: Optional[int] = None
     ):
+        super().__init__()
         self.datasets = []
         tag_index: dict[str, int] = {}
         for i, (dataset, tags) in enumerate(datasets):
2 changes: 1 addition & 1 deletion src/levanter/main/train_lm.py
@@ -209,7 +209,7 @@ def main(config: TrainLmConfig):

     if config.supervised_data is not None:
         logger.info("Using supervised data")
-        supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer), "")]
+        supervised_eval = [(levanter.data.text.mk_supervised_dataset(config.supervised_data, tokenizer, Pos), "")]
         # TODO Add tags
         cb = levanter.eval.cb_tagged_lm_evaluate(
             EvalBatch,
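
(Pos here is, presumably, the model's position axis already in scope in main; with it threaded through, the supervised eval dataset used for internal evaluation is padded to the same sequence length as the training data.)
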
1 change: 1 addition & 0 deletions src/levanter/store/cache.py
@@ -191,6 +191,7 @@ def __init__(
         ledger: Optional["CacheLedger"],
         _broker,  # handle of _TreeStoreCacheBuilder
     ):
+        super().__init__()
         self.cache_dir = cache_dir
         self.ledger = ledger
         self._was_already_finished = ledger is not None and ledger.is_finished
1 change: 1 addition & 0 deletions tests/test_doremi.py
@@ -38,6 +38,7 @@ def platform_of_array(x):

 class LogitDataset(AsyncDataset[Example]):
     def __init__(self, W, noise, x_mask, x_bias, *, key):
+        super().__init__()
         self.W = W
         self.noise = noise
         self.x_mask = x_mask
2 changes: 2 additions & 0 deletions tests/test_new_loader.py
@@ -64,6 +64,7 @@ def test_local_batched_data_loading_model_axis_1():

 class StructuredDataset(AsyncDataset):
     def __init__(self, seq_len):
+        super().__init__()
         self.seq_len = seq_len
         self.begin = 0
         self.end = 256
@@ -138,6 +139,7 @@ def test_structured_batches_model_axis_2():

 class StructuredDatasetWithNames(AsyncDataset):
     def __init__(self, Height: Axis, Width: Axis, begin, end, stride):
+        super().__init__()
         self.Height = Height
         self.Width = Width
         self.begin = begin
