Skip to content

Commit

Permalink
DataCollatorWithFlattening doesn't accept most args/kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 10, 2024
1 parent 06e7461 commit c13a148
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,7 @@ def build_collator(
RewardDataCollatorWithPadding,
]
]
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
Expand All @@ -2007,12 +2008,16 @@ def build_collator(
kwargs["chat_template"] = training_args.chat_template
elif self.cfg.flash_attention:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
else:
collator = DataCollatorForSeq2Seq

kwargs["return_tensors"] = "pt"

return collator(
self.tokenizer,
return_tensors="pt",
*collator_args,
**kwargs,
)

Expand Down

0 comments on commit c13a148

Please sign in to comment.