Skip to content

Commit

Permalink
skip over rows in pretraining dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 26, 2024
1 parent 7a38dbe commit b0d3bfb
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
skip: Optional[int] = None


class UserDefinedPrompterType(BaseModel):
Expand Down
8 changes: 7 additions & 1 deletion src/axolotl/utils/data/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,13 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset
split = "train"
name = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]

Expand All @@ -104,8 +106,12 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)

iter_ds = load_dataset(path, streaming=True, split=split, name=name)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name),
iter_ds,
tokenizer,
cfg,
ds_wrapper_partial,
Expand Down

0 comments on commit b0d3bfb

Please sign in to comment.