Skip to content

Commit

Permalink
fix epochs in type signature, fix type checker (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh authored Nov 7, 2024
1 parent 0f94ff2 commit be80580
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 16 deletions.
15 changes: 11 additions & 4 deletions examples/sft/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def train(config: SFTConfig):
# Create supervised dataset using generic machinery
logger.info("Creating supervised dataset")
if config.dataset_type == DatasetType.CHAT_JSONL:
assert config.chat_train_urls is not None
assert config.supervised_data is not None
chat_config = ChatSFTDatasetConfig(
cache_dir=config.supervised_data.cache_dir,
train_urls=config.chat_train_urls, # No validation in this config
Expand All @@ -106,6 +108,7 @@ def train(config: SFTConfig):
)
train_dataset = mk_chat_sft_dataset(chat_config, tokenizer)
else:
assert config.supervised_data is not None
train_dataset = mk_supervised_dataset(config.supervised_data, tokenizer)
logger.info("Supervised dataset created")
train_dataset = PermutationDataset(train_dataset, data_key)
Expand All @@ -122,7 +125,7 @@ def train(config: SFTConfig):
# 1. Sets the device mesh
# 2. Sets the axis mapping (for fsdp)
# 3. Sets the global metrics tracker
with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer:
with Trainer(config.trainer, optimizer, loss_fn=compute_next_token_loss) as trainer: # type: ignore
parameter_axis_mapping = trainer.parameter_axis_mapping

# We have two axis_mappings: one for storing the model and optimizer states, and one for compute
Expand All @@ -141,7 +144,7 @@ def train(config: SFTConfig):
logger.info(f"Loading pretrained model from {converter.reference_checkpoint}")
model: LmHeadModel = converter.load_pretrained(
model_config.model_type, axis_mapping=parameter_axis_mapping, dtype=trainer.mp.param_dtype
)
) # type: ignore
model = hax.named_jit(lambda m: m.resize_vocab(len(tokenizer)))(model)
state = trainer.initial_state(training_key, model=model)
else:
Expand All @@ -163,10 +166,14 @@ def train(config: SFTConfig):
next(loader)

if config.hf_save_path is not None:
full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
# bit gross to reach this far into the config, but it's fine
if config.trainer.checkpointer.append_run_id_to_base_path:
full_save_path = os.path.join(config.hf_save_path, trainer.run_id)
else:
full_save_path = config.hf_save_path

trainer.add_hook(
save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload),
save_hf_checkpoint_callback(full_save_path, converter, upload_to_hf=config.hf_upload or False),
every=config.hf_save_steps,
)

Expand Down
39 changes: 27 additions & 12 deletions src/levanter/data/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from levanter.store.cache import CacheOptions, TreeCache
from levanter.store.jagged_array import JaggedArrayStore
from levanter.store.tree_store import TreeStore
from levanter.utils import fsspec_utils
from levanter.utils.fsspec_utils import expand_glob
from levanter.utils.hf_utils import num_cpus_used_by_tokenizer

Expand Down Expand Up @@ -616,7 +617,12 @@ def the_tokenizer(self) -> PreTrainedTokenizerBase:

@abc.abstractmethod
def train_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray]
self,
seq_len: int,
monitors: Union[bool, List[MetricsMonitor]] = True,
*,
key: Optional[PRNGKeyArray],
epochs: Optional[int] = None,
) -> AsyncDataset[np.ndarray]:
pass

Expand Down Expand Up @@ -717,7 +723,7 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
dataset = levanter.data.datasource_from_hf(config.hf_dataset_name, split=config.hf_dataset_split)
else:
# Using local files
validation_urls = [url for url_pat in config.validation_urls for url in fsspec_expand_glob(url_pat)]
validation_urls = [url for url_pat in config.validation_urls for url in fsspec_utils.expand_glob(url_pat)]
if not validation_urls:
raise ValueError("Must specify either hf_dataset_name or validation_urls")
dataset = levanter.data.datasource_from_jsonl(validation_urls)
Expand All @@ -735,12 +741,12 @@ def mk_supervised_dataset(config: LMSupervisedDatasetConfig, tokenizer: PreTrain
output_exemplar=output_exemplar,
)

dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)

if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))


@dataclass
Expand Down Expand Up @@ -811,14 +817,14 @@ def mk_chat_sft_dataset(config: ChatSFTDatasetConfig, tokenizer: PreTrainedToken
)

# Cache the processed data
dataset = dataset.build_or_load_cache(config.cache_dir, await_finished=True)
cached_dataset: AsyncDataset[dict] = dataset.build_or_load_cache(config.cache_dir, await_finished=True)

# Ensure padding token is set (needed by _prepare_supervised_example)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# Reuse the supervised prepare function directly
return dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))
return cached_dataset.map(lambda ex: _prepare_supervised_example(ex, tokenizer))


@dataclass
Expand All @@ -833,18 +839,19 @@ def train_set(
monitors: Union[bool, List[MetricsMonitor]] = True,
*,
key: Optional[PRNGKeyArray] = None,
epochs: int = 0,
epochs: Optional[int] = None,
) -> AsyncDataset[np.ndarray]:

ds = self.token_seq_dataset("train", seq_len, monitors)
if epochs:
logger.info("Wrapping dataset in epoch dataset")
ds = EpochDataset(ds, max_epochs=epochs)
ds: AsyncDataset[np.ndarray] | None = self.token_seq_dataset("train", seq_len, monitors)

# add epoch flag here.
if ds is None:
raise ValueError("No training set!")

if epochs:
logger.info("Wrapping dataset in epoch dataset")
ds = EpochDataset(ds, max_epochs=epochs)

if self.shuffle is True:
ds = ds.shuffle(key)
elif isinstance(self.shuffle, int) and self.shuffle > 0:
Expand Down Expand Up @@ -989,11 +996,19 @@ def __post_init__(self):
)

def train_set(
self, seq_len: int, monitors: Union[bool, List[MetricsMonitor]] = True, *, key: Optional[PRNGKeyArray]
self,
seq_len: int,
monitors: Union[bool, List[MetricsMonitor]] = True,
*,
key: Optional[PRNGKeyArray],
epochs: Optional[int] = None,
) -> AsyncDataset[np.ndarray]:
doc_caches = self.build_caches("train", monitors=monitors)
token_datasets = {name: TokenSeqDataset(cache, seq_len) for name, cache in doc_caches.items()}

if epochs:
raise ValueError("Epochs are not supported for mixture datasets")

if key is None:
key = jax.random.PRNGKey(0)

Expand Down

0 comments on commit be80580

Please sign in to comment.