From c13a14850f2ea661c22ef5f94c8d9e1ad7dd0fb9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 10 Dec 2024 08:43:57 -0500 Subject: [PATCH] DataCollatorWithFlattening doesn't accept most args/kwargs --- src/axolotl/core/trainer_builder.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index c5cfa4fbe..c730079d2 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1986,6 +1986,7 @@ def build_collator( RewardDataCollatorWithPadding, ] ] + collator_args = [self.tokenizer] if self.cfg.reward_model: collator = RewardDataCollatorWithPadding if "max_length" in kwargs: @@ -2007,12 +2008,16 @@ def build_collator( kwargs["chat_template"] = training_args.chat_template elif self.cfg.flash_attention: collator = DataCollatorWithFlattening + collator_args.pop(0) + kwargs.pop("pad_to_multiple_of", None) + kwargs.pop("padding", None) else: collator = DataCollatorForSeq2Seq + kwargs["return_tensors"] = "pt" + return collator( - self.tokenizer, - return_tensors="pt", + *collator_args, **kwargs, )