Commit

drop monkeypatches
winglian committed Dec 13, 2024
1 parent 8a4cae3 commit becca9d
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/axolotl/utils/models.py
@@ -399,14 +399,14 @@ def apply_patches(self) -> None:
         if self.cfg.flash_attention:
             self.patch_attention()
 
-        if self.cfg.model_config_type == "llama":
-            from axolotl.monkeypatch.trainer_grad_accum import (
-                patch_forward_for_ga,
-                patch_training_step_for_ga,
-            )
-
-            patch_forward_for_ga()
-            patch_training_step_for_ga()
+        # if self.cfg.model_config_type == "llama":
+        #     from axolotl.monkeypatch.trainer_grad_accum import (
+        #         patch_forward_for_ga,
+        #         patch_training_step_for_ga,
+        #     )
+        #
+        #     patch_forward_for_ga()
+        #     patch_training_step_for_ga()
 
         if self.cfg.sample_packing and self.cfg.s2_attention:
             raise ValueError(
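For context, the disabled block applied the llama gradient-accumulation monkeypatches unconditionally. A minimal sketch of how they could be re-enabled defensively, assuming the check_* helpers imported in the test file below take no arguments and return True while the upstream Transformers source still matches what the patches expect:

# Sketch only, not part of this commit. Assumes check_forward_is_patchable /
# check_training_step_is_patchable take no arguments and report whether the
# upstream code is still in a shape the patches can safely rewrite.
from axolotl.monkeypatch.trainer_grad_accum import (
    check_forward_is_patchable,
    check_training_step_is_patchable,
    patch_forward_for_ga,
    patch_training_step_for_ga,
)

# Inside the apply_patches method shown above (hence `self`):
if self.cfg.model_config_type == "llama":
    # Only patch when the upstream source still matches the patch target.
    if check_forward_is_patchable() and check_training_step_is_patchable():
        patch_forward_for_ga()
        patch_training_step_for_ga()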
3 changes: 3 additions & 0 deletions tests/patched/test_llama_trainer_ga.py
@@ -1,12 +1,15 @@
 """Test module for checking whether Hugging Face Transformers is working as expected."""
 import unittest
 
+import pytest
+
 from axolotl.monkeypatch.trainer_grad_accum import (
     check_forward_is_patchable,
     check_training_step_is_patchable,
 )
 
 
+@pytest.mark.skip("should be fixed upstream")
 class TestTrainerGAIntegration(unittest.TestCase):
     """llama monkeypatch integration tests."""
 
