diff --git a/requirements.txt b/requirements.txt index d0942193f..0373548d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,22 +11,27 @@ liger-kernel==0.4.2 # END section packaging==23.2 + peft==0.14.0 -transformers==4.47.0 +transformers==4.47.1 tokenizers>=0.20.1 -accelerate==1.2.0 +accelerate==1.2.1 datasets==3.1.0 deepspeed==0.16.1 +trl==0.12.1 + +optimum==1.16.2 +hf_transfer +sentencepiece +gradio==3.50.2 + pydantic==2.6.3 addict fire PyYAML>=6.0 requests -sentencepiece wandb einops -optimum==1.16.2 -hf_transfer colorama numba numpy>=1.24.4,<=2.0.1 @@ -36,7 +41,6 @@ scipy scikit-learn==1.4.2 nvidia-ml-py==12.560.30 art -gradio==3.50.2 tensorboard python-dotenv==1.0.1 @@ -45,7 +49,6 @@ s3fs>=2024.5.0 gcsfs>=2024.5.0 # adlfs -trl==0.12.1 zstandard==0.22.0 fastcore @@ -55,5 +58,5 @@ langdetect==1.0.9 immutabledict==4.2.0 antlr4-python3-runtime==4.13.2 -torchao==0.5.0 +torchao==0.7.0 schedulefree==1.3.0 diff --git a/scripts/unsloth_install.py b/scripts/unsloth_install.py index a6570b4e9..bffab4670 100644 --- a/scripts/unsloth_install.py +++ b/scripts/unsloth_install.py @@ -32,5 +32,5 @@ raise RuntimeError(f"Torch = {v} too new!") x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "") print( - f'pip install unsloth-zoo==2024.11.7 && pip install --no-deps "unsloth[{x}]==2024.11.9"' + f'pip install unsloth-zoo==2024.12.1 && pip install --no-deps "unsloth[{x}]==2024.12.4"' ) diff --git a/src/axolotl/monkeypatch/trainer_grad_accum.py b/src/axolotl/monkeypatch/trainer_grad_accum.py index 76fcf57ce..550f00e30 100644 --- a/src/axolotl/monkeypatch/trainer_grad_accum.py +++ b/src/axolotl/monkeypatch/trainer_grad_accum.py @@ -6,6 +6,7 @@ import logging from transformers import LlamaForCausalLM, Trainer +from transformers.modeling_flash_attention_utils import _flash_attention_forward from axolotl.monkeypatch.unsloth_ import detab_code @@ -13,10 +14,7 @@ ORIGINAL_CONTEXT_CODE = """ with self.compute_loss_context_manager(): - if self.model_accepts_loss_kwargs: - loss = self.compute_loss(model, inputs) - else: - loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) + loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch) """ PATCHED_CONTEXT_CODE = """ @@ -288,3 +286,23 @@ def patch_training_loop_for_deepspeed_0_16_x(): Trainer._inner_training_loop = ( # pylint: disable=protected-access _fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821 ) + + +def patch_flash_attention_forward(): + """ + monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch + """ + + import transformers.modeling_flash_attention_utils + + def proxy_flash_attention_forward(*args, **kwargs): + kwargs.pop("num_items_in_batch", None) + + return _flash_attention_forward(*args, **kwargs) + + transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access + proxy_flash_attention_forward + ) + transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access + proxy_flash_attention_forward + ) diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 11f4c6d0f..523fd76fe 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -380,19 +380,6 @@ def apply_patches(self) -> None: plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) - if self.cfg.fsdp: - from axolotl.monkeypatch.trainer_fsdp_optim import ( - patch_training_loop_for_fsdp, - ) - - patch_training_loop_for_fsdp() - elif self.cfg.deepspeed and self.cfg.gradient_accumulation_steps > 1: - from axolotl.monkeypatch.trainer_grad_accum import ( - patch_training_loop_for_deepspeed_0_16_x, - ) - - patch_training_loop_for_deepspeed_0_16_x() - if self.cfg.gradient_checkpointing == "unsloth": transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper @@ -401,10 +388,12 @@ def apply_patches(self) -> None: if self.cfg.model_config_type == "llama": from axolotl.monkeypatch.trainer_grad_accum import ( + patch_flash_attention_forward, patch_forward_for_ga, patch_training_step_for_ga, ) + patch_flash_attention_forward() patch_forward_for_ga() patch_training_step_for_ga()