diff --git a/turbo_alignment/common/tf/loaders/model/model.py b/turbo_alignment/common/tf/loaders/model/model.py index ce942a8..94fa96f 100755 --- a/turbo_alignment/common/tf/loaders/model/model.py +++ b/turbo_alignment/common/tf/loaders/model/model.py @@ -56,7 +56,7 @@ def load_model( cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, swiglu=model_settings.liger_kernels_settings.use_mlp, rms_norm=model_settings.liger_kernels_settings.use_rms_norm, - fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_liner_cross_entropy, + fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, ) apply_liger_kernel_to_gemma2( @@ -64,7 +64,7 @@ def load_model( cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, geglu=model_settings.liger_kernels_settings.use_mlp, rms_norm=model_settings.liger_kernels_settings.use_rms_norm, - fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_liner_cross_entropy, + fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, ) apply_liger_kernel_to_qwen2( @@ -72,7 +72,7 @@ def load_model( cross_entropy=model_settings.liger_kernels_settings.use_cross_entropy, swiglu=model_settings.liger_kernels_settings.use_mlp, rms_norm=model_settings.liger_kernels_settings.use_rms_norm, - fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_liner_cross_entropy, + fused_linear_cross_entropy=model_settings.liger_kernels_settings.use_fused_linear_cross_entropy, ) model = TransformersAutoModelRegistry.by_name(model_settings.model_type).from_pretrained( @@ -119,3 +119,4 @@ def load_model( model.base_model.model.score.weight.requires_grad = True return model + diff --git a/turbo_alignment/settings/model.py b/turbo_alignment/settings/model.py index 99c87a6..c57a157 100755 --- a/turbo_alignment/settings/model.py +++ b/turbo_alignment/settings/model.py @@ -18,13 +18,13 @@ class ModelType(str, Enum): class LigerKernelSettings(ExtraFieldsNotAllowedBaseModel): use_rope: bool = True use_cross_entropy: bool = False - use_fused_liner_cross_entropy: bool = False + use_fused_linear_cross_entropy: bool = False use_mlp: bool = True use_rms_norm: bool = False @model_validator(mode='after') def correct_cross_entopy_kernels(self) -> 'LigerKernelSettings': - if self.use_fused_liner_cross_entropy and self.use_cross_entropy: + if self.use_fused_linear_cross_entropy and self.use_cross_entropy: raise ValueError( 'You cannot use both FusedLinearCrossEntropy and CrossEntropy kernels. ' 'FusedLinearCrossEntropy is preferred if possible.'