Skip to content

Commit

Permalink
use DataCollatorWithFlattening when not sample packing
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 10, 2024
1 parent 6aa31b4 commit 06e7461
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
TrainerCallback,
Expand Down Expand Up @@ -1981,6 +1982,7 @@ def build_collator(
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
]
]
Expand All @@ -2003,6 +2005,8 @@ def build_collator(
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
elif self.cfg.flash_attention:
collator = DataCollatorWithFlattening
else:
collator = DataCollatorForSeq2Seq

Expand Down

0 comments on commit 06e7461

Please sign in to comment.