diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
index 0dc58ab2f..fe702403e 100644
--- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -25,6 +25,8 @@
 class LlamaDifferentialConfig(LlamaConfig):
     """Configuration class for Differential LLaMA model."""
 
+    model_type = "llama-differential"
+
     def __init__(
         self,
         split_heads: bool = False,
@@ -213,6 +215,9 @@ def from_llama(
 class LlamaDifferentialForCausalLM(LlamaForCausalLM):
     """LlamaForCausalLM with differential attention."""
 
+    config_class = LlamaDifferentialConfig
+    base_model_prefix = "llama_differential"
+
     def __init__(self, config):
         super().__init__(config)
         self.model = LlamaDifferentialModel(config)