diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 73d9e0e655..c5cfa4fbe3 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -28,6 +28,7 @@ from torch.optim.lr_scheduler import OneCycleLR from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from transformers import ( + DataCollatorWithFlattening, EarlyStoppingCallback, Trainer, TrainerCallback, @@ -1981,6 +1982,7 @@ def build_collator( V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, DataCollatorForSeq2Seq, + DataCollatorWithFlattening, RewardDataCollatorWithPadding, ] ] @@ -2003,6 +2005,8 @@ def build_collator( collator = MultiModalChatDataCollator kwargs["processor"] = self.processor kwargs["chat_template"] = training_args.chat_template + elif self.cfg.flash_attention: + collator = DataCollatorWithFlattening else: collator = DataCollatorForSeq2Seq