fix epochs in type signature, fix type checker
dlwh committed Nov 7, 2024
1 parent 0f94ff2 commit 5aa7e23
Showing 1 changed file with 27 additions and 12 deletions.
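
In short: every `train_set` implementation now accepts an optional `epochs` keyword, matching the abstract signature, and the supervised dataset builders get explicit annotations so the type checker passes. As a rough sketch of the resulting call site (the `config` object and the values here are illustrative assumptions, not part of this commit):

    import jax

    # Hypothetical usage: `config` is assumed to be a dataset config
    # exposing the train_set signature below.
    train_ds = config.train_set(
        seq_len=2048,
        key=jax.random.PRNGKey(0),
        epochs=3,  # None (the default) leaves the dataset unwrapped
    )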
39 changes: 27 additions & 12 deletions src/levanter/data/text.py
@@ -35,6 +35,7 @@
 from levanter.store.cache import CacheOptions, TreeCache
 from levanter.store.jagged_array import JaggedArrayStore
 from levanter.store.tree_store import TreeStore
+from levanter.utils import fsspec_utils
 from levanter.utils.fsspec_utils import expand_glob
 from levanter.utils.hf_utils import num_cpus_used_by_tokenizer
 
@@ -616,7 +617,12 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase:
 
     @abc.abstractmethod
     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray]
+        self,
+        seq_len: int,
+        monitors: Union[bool, List[MetricsMonitor]] = True,
+        *,
+        key: Optional[PRNGKeyArray],
+        epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
         pass
 
@@ -717,7 +723,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
         dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split)
     else:
         # Using local files
-        validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)]
+        validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)]
         if not validation_urls:
            raise ValueError("Must specify either hf_dataset_name or validation_urls")
         dataset = levanter.data.datasource_from_jsonl(validation_urls)
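
The removed line called `fsspec_expand_glob`, a name that didn't match any import; the fix routes through the `fsspec_utils` module imported at the top of the file. For illustration only, a minimal sketch of what a glob-expanding helper along these lines can look like, built on plain fsspec primitives (this body is an assumption, not necessarily Levanter's implementation):

    import fsspec

    def expand_glob(url_pattern: str) -> list[str]:
        # Non-glob URLs pass through unchanged.
        if "*" not in url_pattern:
            return [url_pattern]
        fs, path = fsspec.core.url_to_fs(url_pattern)
        # fs.glob returns protocol-stripped paths; unstrip_protocol restores the scheme.
        return [fs.unstrip_protocol(match) for match in fs.glob(path)]

Something like `expand_glob("gs://bucket/data/train-*.jsonl")` would then return the concrete shard URLs.
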
@@ -735,12 +741,12 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
         output_exemplar=output_exemplar,
     )
 
-    dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
+    cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
 
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
 
 
 @dataclass
@@ -811,14 +817,14 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken
     )
 
     # Cache the processed data
-    dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
+    cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
 
     # Ensure padding token is set (needed by _prepare_supervised_example)
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
     # Reuse the supervised prepare function directly
-    return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
+    return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
 
 
 @dataclass
@@ -833,18 +839,19 @@ def train_set(
         monitors: Union[bool, List[MetricsMonitor]] = True,
         *,
         key: Optional[PRNGKeyArray] = None,
-        epochs: int = 0,
+        epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
 
-        ds = self.token_seq_dataset("train", seq_len, monitors)
-        if epochs:
-            logger.info("Wrapping dataset in epoch dataset")
-            ds = EpochDataset(ds, max_epochs=epochs)
+        ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)
 
         # add epoch flag here.
         if ds is None:
             raise ValueError("No training set!")
 
+        if epochs:
+            logger.info("Wrapping dataset in epoch dataset")
+            ds = EpochDataset(ds, max_epochs=epochs)
+
         if self.shuffle is True:
             ds = ds.shuffle(key)
         elif isinstance(self.shuffle, int) and self.shuffle > 0:
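
For context, `EpochDataset` wraps a finite dataset so iteration restarts until `max_epochs` passes are exhausted. Levanter's real class operates on `AsyncDataset`; the sketch below shows only the core idea on a simplified synchronous interface, which is an assumption for illustration:

    from typing import Iterator, Optional, Sequence, TypeVar

    T = TypeVar("T")

    class SimpleEpochDataset:
        # Repeats a finite dataset for up to `max_epochs` passes; None repeats forever.
        def __init__(self, dataset: Sequence[T], max_epochs: Optional[int] = None):
            self.dataset = dataset
            self.max_epochs = max_epochs

        def __iter__(self) -> Iterator[T]:
            epoch = 0
            while self.max_epochs is None or epoch < self.max_epochs:
                yield from self.dataset
                epoch += 1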
@@ -989,11 +996,19 @@ def __post_init__(self):
             )
 
     def train_set(
-        self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray]
+        self,
+        seq_len: int,
+        monitors: Union[bool, List[MetricsMonitor]] = True,
+        *,
+        key: Optional[PRNGKeyArray],
+        epochs: Optional[int] = None,
     ) -> AsyncDataset[np.ndarray]:
         doc_caches = self.build_caches("train", monitors=monitors)
         token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()}
 
+        if epochs:
+            raise ValueError("Epochs are not supported for mixture datasets")
+
         if key is None:
             key = jax.random.PRNGKey(0)
