Skip to content

Commit

Permalink
move liger kernels to modeling
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Sep 29, 2024
1 parent 3352899 commit ba76ded
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 7 deletions.
4 changes: 1 addition & 3 deletions turbo_alignment/common/tf/loaders/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from turbo_alignment.common.tf.liger_kernels.monkey_patch_liger import (
apply_liger_kernel_to_gemma2,
)
from turbo_alignment.common.tf.loaders.model.registry import (
PeftConfigRegistry,
TransformersAutoModelRegistry,
)
from turbo_alignment.modeling.liger_kernels import apply_liger_kernel_to_gemma2
from turbo_alignment.settings.model import (
ModelForPeftSettings,
PreTrainedAdaptersModelSettings,
Expand Down
1 change: 1 addition & 0 deletions turbo_alignment/modeling/liger_kernels/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from monkey_patch_liger import *
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import triton
import triton.language as tl

from turbo_alignment.common.tf.liger_kernels.utils import (
from turbo_alignment.modeling.liger_kernels.utils import (
calculate_settings,
compare_version,
ensure_contiguous,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from transformers import PretrainedConfig, PreTrainedModel

from turbo_alignment.common.logging import get_project_logger
from turbo_alignment.common.tf.liger_kernels.cross_entropy import LigerCrossEntropyLoss
from turbo_alignment.common.tf.liger_kernels.geglu import LigerGEGLUMLP
from turbo_alignment.common.tf.liger_kernels.rope import liger_rotary_pos_emb
from turbo_alignment.modeling.liger_kernels.cross_entropy import LigerCrossEntropyLoss
from turbo_alignment.modeling.liger_kernels.geglu import LigerGEGLUMLP
from turbo_alignment.modeling.liger_kernels.rope import liger_rotary_pos_emb

logger = get_project_logger()

Expand Down

0 comments on commit ba76ded

Please sign in to comment.