diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0781c6798..a191924f1 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
     text_column: Optional[str] = "text"
     type: Optional[str] = "pretrain"
     trust_remote_code: Optional[bool] = False
+    skip: Optional[int] = None  # number of leading samples to skip in the streamed dataset
 
 
 class UserDefinedPrompterType(BaseModel):
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 3e784ca3e..57f63f88e 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -88,11 +88,13 @@ def prepare_dataset(cfg, tokenizer, processor=None):
         path = cfg.pretraining_dataset
         split = "train"
         name = None
+        skip = 0
         if isinstance(cfg.pretraining_dataset, list) and isinstance(
            cfg.pretraining_dataset[0], dict
        ):
             path = cfg.pretraining_dataset[0]["path"]
             name = cfg.pretraining_dataset[0]["name"]
+            skip = cfg.pretraining_dataset[0].get("skip", 0)  # optional key; older configs omit it
             if "split" in cfg.pretraining_dataset[0]:
                 split = cfg.pretraining_dataset[0]["split"]
 
@@ -104,8 +106,12 @@ def prepare_dataset(cfg, tokenizer, processor=None):
                 cfg.pretraining_dataset[0]["type"] or "pretrain",
             )
 
+        iter_ds = load_dataset(path, streaming=True, split=split, name=name)
+        if skip:
+            LOG.info(f"Skipping {skip} samples from the dataset")
+            iter_ds = iter_ds.skip(skip)
         train_dataset = wrap_pretraining_dataset(
-            load_dataset(path, streaming=True, split=split, name=name),
+            iter_ds,
             tokenizer,
             cfg,
             ds_wrapper_partial,