
Commit d0d0a3d
add validation for batch flattening
winglian committed Dec 12, 2024
1 parent 7bd5478 commit d0d0a3d
Showing 2 changed files with 16 additions and 5 deletions.
src/axolotl/core/trainer_builder.py (1 addition, 5 deletions)
@@ -2006,11 +2006,7 @@ def build_collator(
             collator = MultiModalChatDataCollator
             kwargs["processor"] = self.processor
             kwargs["chat_template"] = training_args.chat_template
-        elif (
-            self.cfg.flash_attention
-            and self.cfg.micro_batch_size > 1
-            and not self.cfg.sample_packing
-        ):
+        elif self.cfg.batch_flattening:
             collator = DataCollatorWithFlattening
             collator_args.pop(0)
             kwargs.pop("pad_to_multiple_of", None)
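In effect, the DataCollatorWithFlattening collator is now selected only when the new batch_flattening flag is set, rather than being inferred from flash_attention, micro_batch_size, and sample_packing. For context, here is a small hedged demo of transformers' DataCollatorWithFlattening (assuming a transformers release that ships it, roughly 4.44+, and example values chosen here for illustration): it concatenates a micro batch into one padding-free row and emits position_ids that restart per example, which is what lets flash attention keep the documents separate.

# Hedged demo, not axolotl code: transformers' DataCollatorWithFlattening,
# assuming a recent transformers release that provides it.
from transformers import DataCollatorWithFlattening

collator = DataCollatorWithFlattening()

features = [
    {"input_ids": [1, 2, 3], "labels": [1, 2, 3]},
    {"input_ids": [4, 5], "labels": [4, 5]},
]

batch = collator(features)

# Expected: a single flattened row instead of a padded (2, 3) batch, with
# position_ids restarting at 0 for each original example.
print(batch["input_ids"])     # tensor([[1, 2, 3, 4, 5]])
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1]])
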
src/axolotl/utils/config/models/input/v0_4_1/__init__.py (15 additions, 0 deletions)
@@ -696,6 +696,8 @@ class Config:
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
 
+    batch_flattening: Optional[bool] = None
+
     # for PoSE context length extension
     use_pose: Optional[bool] = None
     pose_split_on_token_ids: Optional[List[int]] = None
@@ -923,6 +925,19 @@ def check_sample_packing_wo_flash(cls, data):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_batch_flattening_fa(cls, data):
+        if data.get("batch_flattening"):
+            if not data.get("flash_attention"):
+                raise ValueError("batch_flattening requires flash attention")
+            if data.get("sample_packing"):
+                raise ValueError("batch_flattening not compatible with sample_packing")
+            if data.get("micro_batch_size") == 1:
+                LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_rl(cls, data):
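For reference, a minimal standalone sketch of how the new validator behaves, assuming pydantic v2. The CfgSketch model and its field subset are illustrative stand-ins, not axolotl's actual config class; only the validator body is taken from this commit.

# Minimal sketch, not axolotl code: reproduces the check_batch_flattening_fa
# logic from this commit on a small illustrative pydantic v2 model.
import logging
from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator

LOG = logging.getLogger(__name__)


class CfgSketch(BaseModel):
    flash_attention: Optional[bool] = None
    sample_packing: Optional[bool] = None
    micro_batch_size: Optional[int] = None
    batch_flattening: Optional[bool] = None

    @model_validator(mode="before")
    @classmethod
    def check_batch_flattening_fa(cls, data):
        if data.get("batch_flattening"):
            if not data.get("flash_attention"):
                raise ValueError("batch_flattening requires flash attention")
            if data.get("sample_packing"):
                raise ValueError("batch_flattening not compatible with sample_packing")
            if data.get("micro_batch_size") == 1:
                LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
        return data


# Valid: flash attention on, no sample packing, micro batch size > 1.
CfgSketch(batch_flattening=True, flash_attention=True, micro_batch_size=2)

# Invalid: rejected by the validator because flash attention is off.
try:
    CfgSketch(batch_flattening=True, flash_attention=False, micro_batch_size=2)
except ValidationError as err:
    print(err)  # error message includes "batch_flattening requires flash attention"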
