Skip to content

Commit

Permalink
Efficiently get the length of the tokenized docs (#1063)
Browse files Browse the repository at this point in the history
* Efficiently get the length of the tokenized docs

* chore: lint

---------

Co-authored-by: Wing Lian <[email protected]>
  • Loading branch information
RicardoDominguez and winglian authored Jan 8, 2024
1 parent 732851f commit 81d3845
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 27 deletions.
16 changes: 3 additions & 13 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
DataCollatorForSeq2Seq,
MambaDataCollator,
)
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

try:
Expand Down Expand Up @@ -170,12 +170,7 @@ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=(
self.train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
Expand All @@ -189,12 +184,7 @@ def _get_eval_sampler(
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=(
eval_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
Expand Down
9 changes: 2 additions & 7 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
Expand Down Expand Up @@ -889,12 +889,7 @@ def encode_packed_pretraining(
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=(
train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
lengths=get_dataset_lengths(train_dataset),
)

chunked_data = defaultdict(list)
Expand Down
1 change: 1 addition & 0 deletions src/axolotl/utils/samplers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
axolotl samplers module
"""
from .multipack import MultipackBatchSampler # noqa: F401
from .utils import get_dataset_lengths # noqa: F401
17 changes: 17 additions & 0 deletions src/axolotl/utils/samplers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
helper util to calculate dataset lengths
"""
import numpy as np


def get_dataset_lengths(dataset):
if "length" in dataset.data.column_names:
lengths = np.array(dataset.data.column("length"))
else:
lengths = (
dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
)
return lengths
9 changes: 2 additions & 7 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFDPOTrainerBuilder
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths

LOG = get_logger("axolotl")

Expand Down Expand Up @@ -212,12 +212,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
drop_last=True,
batch_max_len=cfg.micro_batch_size
* (cfg.max_packed_sequence_len or cfg.sequence_len),
lengths=(
train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
lengths=get_dataset_lengths(train_dataset),
)

data_loader = DataLoader(
Expand Down

0 comments on commit 81d3845

Please sign in to comment.