Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Малахов Алексей Павлович committed Nov 18, 2024
1 parent 3169439 commit ef4425e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions turbo_alignment/common/tf/loaders/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,23 +56,23 @@ def load_model(
cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
swiglu=model_settings.liger_kernels_settings.use_mlp,
rms_norm=model_settings.liger_kernels_settings.use_rms_norm,
- fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_liner_cross_entropy,
+ fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
)

apply_liger_kernel_to_gemma2(
rope=model_settings.liger_kernels_settings.use_rope,
cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
geglu=model_settings.liger_kernels_settings.use_mlp,
rms_norm=model_settings.liger_kernels_settings.use_rms_norm,
- fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_liner_cross_entropy,
+ fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
)

apply_liger_kernel_to_qwen2(
rope=model_settings.liger_kernels_settings.use_rope,
cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy,
swiglu=model_settings.liger_kernels_settings.use_mlp,
rms_norm=model_settings.liger_kernels_settings.use_rms_norm,
- fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_liner_cross_entropy,
+ fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy,
)

model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained(
Expand Down Expand Up @@ -119,3 +119,4 @@ def load_model(
model.base_model.model.score.weight.requires_grad = True

return model

4 changes: 2 additions & 2 deletions turbo_alignment/settings/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class ModelType(str, Enum):
class LigerKernelSettings(ExtraFieldsNotAllowedBaseModel):
use_rope: bool = True
use_cross_entropy: bool = False
- use_fused_liner_cross_entropy: bool = False
+ use_fused_linear_cross_entropy: bool = False
use_mlp: bool = True
use_rms_norm: bool = False

@model_validator(mode='after')
def correct_cross_entopy_kernels(self) -> 'LigerKernelSettings':
- if self.use_fused_liner_cross_entropy and self.use_cross_entropy:
+ if self.use_fused_linear_cross_entropy and self.use_cross_entropy:
raise ValueError(
'You cannot use both FusedLinearCrossEntropy and CrossEntropy kernels. '
'FusedLinearCrossEntropy is preferred if possible.'
Expand Down

0 comments on commit ef4425e

Please sign in to comment.