Skip to content

Commit

Permalink
multipack w batch sampler (#795)
Browse files Browse the repository at this point in the history
* test batch sampler w varying batch lens

* wip

* multipack batchsampler wip

* wip

* fix for prepare data loader to get correct # of steps based on gpues

* lint and clean up

* calculate len estimate

* fix total num steps calc

* add options for dataloader_num_workers and dataloader_pin_memory

* remove gitbook

* support prefetch_factor for dataloader optimization

* fix the kwarg
  • Loading branch information
winglian authored Nov 8, 2023
1 parent 6dc68a6 commit 641e6f7
Show file tree
Hide file tree
Showing 9 changed files with 365 additions and 94 deletions.
1 change: 0 additions & 1 deletion gitbook/README.md

This file was deleted.

4 changes: 0 additions & 4 deletions gitbook/SUMMARY.md

This file was deleted.

3 changes: 0 additions & 3 deletions gitbook/small-dev-details.md

This file was deleted.

152 changes: 105 additions & 47 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import importlib
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
Expand All @@ -18,9 +17,9 @@
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from transformers.trainer_utils import seed_worker

from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
Expand All @@ -31,8 +30,9 @@
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.samplers import MultipackBatchSampler
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup

try:
Expand Down Expand Up @@ -102,6 +102,10 @@ class AxolotlTrainingArguments(TrainingArguments):
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)


class AxolotlTrainer(Trainer):
Expand Down Expand Up @@ -145,46 +149,69 @@ def create_scheduler(
return self.lr_scheduler

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
if self.args.sample_packing:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=(
self.train_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()

def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=(
eval_dataset.data.column("position_ids")
.to_pandas()
.apply(lambda x: x[-1] + 1)
.values
),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)

def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
)
train_dataset = self.train_dataset
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor

sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker

self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()

Expand All @@ -197,18 +224,29 @@ def get_eval_dataloader(
)

eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor

if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last

self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)

Expand All @@ -229,6 +267,8 @@ def get_bench_dataloader(
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
Expand Down Expand Up @@ -493,6 +533,19 @@ def build(self, total_num_steps):
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est

if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor

if self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
Expand Down Expand Up @@ -672,7 +725,7 @@ def build(self, total_num_steps):
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=DataCollatorForSeq2Seq(
data_collator=BatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
Expand All @@ -690,4 +743,9 @@ def build(self, total_num_steps):
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)

if self.cfg.deepspeed and self.cfg.sample_packing:
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
"train_micro_batch_size_per_gpu"
] = self.cfg.micro_batch_size

return trainer
27 changes: 27 additions & 0 deletions src/axolotl/utils/collators.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,30 @@ def __call__(self, features, return_tensors=None):
features["decoder_input_ids"] = decoder_input_ids

return features


@dataclass
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
Collator for multipack specific to the using the BatchSampler
"""

def __call__(self, features, return_tensors=None):
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for item in features
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
4 changes: 2 additions & 2 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def prepare_dataset(cfg, tokenizer):
)
if cfg.max_steps:
total_num_steps = min(
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
)
LOG.info(f"Maximum number of steps set at {total_num_steps}")
else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
return train_dataset, eval_dataset, total_num_steps, prompters


Expand Down
4 changes: 4 additions & 0 deletions src/axolotl/utils/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""
axolotl samplers module
"""
from .multipack import MultipackBatchSampler # noqa: F401
Loading

0 comments on commit 641e6f7

Please sign in to comment.