diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index 594970716..4b97bfe10 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -96,6 +96,7 @@ def __init__(self, config: LlamaDifferentialConfig): ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # pylint: disable=duplicate-code def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -322,6 +323,7 @@ def __init__( config.hidden_size, eps=config.rms_norm_eps ) + # pylint: disable=duplicate-code def forward( self, hidden_states: torch.Tensor,