From c6def2732d8322f9ae65e7b57b73777aa4f11fb6 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 23 Dec 2024 14:14:51 -0500 Subject: [PATCH] added modeling code; cleanup + refactor --- .../integrations/diff_transformer/convert.py | 3 +- .../diff_transformer/diff_attn.py | 100 +++-- .../diff_transformer/modeling_diff_attn.py | 370 ++++++++++++++++++ 3 files changed, 439 insertions(+), 34 deletions(-) create mode 100644 src/axolotl/integrations/diff_transformer/modeling_diff_attn.py diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index d942567d5..298a0232e 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -50,7 +50,7 @@ def copy_attention_weights( new_attn.q_proj.weight.data.copy_(new_q) # For K projection (K1 and K2) - old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads + old_kv_size = old_attn.k_proj.weight.data.size(0) new_k = torch.empty_like(new_attn.k_proj.weight.data) new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1 if zero_init: @@ -99,6 +99,7 @@ def convert_module(module): # Iterate through module children, convert any attn layers to diff attn for name, child in module.named_children(): child_class_name = type(child).__name__ + if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]: # Find matching attention class by name for orig_class, diff_class in ATTENTION_MAPPING.items(): diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index a8d7536dd..cccb0adeb 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -7,7 +7,6 @@ import torch import torch.nn.functional as F -from flash_attn.flash_attn_interface import flash_attn_func from torch import nn from transformers.cache_utils import Cache from transformers.models.llama.modeling_llama import ( @@ -17,7 +16,14 @@ ) logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +LOG = logging.getLogger(__name__) + +try: + from flash_attn.flash_attn_interface import flash_attn_func + + FLASH_ATTENTION_AVAILABLE = True +except ImportError: + FLASH_ATTENTION_AVAILABLE = False def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -35,11 +41,12 @@ def lambda_init_fn(depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) -class DifferentialAttentionBase(nn.Module): +class LlamaDifferentialAttentionBase(nn.Module): """Base class for differential attention implementations.""" def __init__(self, config: Any, layer_idx: int): super().__init__() + self.config = config self._init_config(config, layer_idx) self._init_projections() self._init_differential_params() @@ -59,9 +66,9 @@ def _init_config(self, config: Any, layer_idx: int): if config.split_heads: # Split heads mode - single projections - self.head_dim = config.hidden_size // config.num_attention_heads // 2 + self.head_dim = config.hidden_size // config.num_attention_heads # NOTE: This rounds down `base_num_heads / 2` as opposed to the original - # implementation, which asserts `self.base_num_heads` is even. 
+ # implementation, which asserts `self.base_num_heads` is even self.heads_per_component = self.base_num_heads // 2 self.value_head_dim = 2 * self.head_dim else: @@ -110,36 +117,43 @@ def _init_differential_params(self): self.lambda_k2 = nn.Parameter( torch.zeros(self.head_dim).normal_(mean=0, std=0.1) ) - self.rotary_emb = LlamaRotaryEmbedding( - self.max_position_embeddings, self.head_dim, self.rope_theta - ) + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) def _init_normalization(self, config): """Initialize normalization layers.""" sublayer_norm = getattr(config, "sublayer_norm", True) - self.subln = ( - LlamaRMSNorm(self.value_head_dim, eps=1e-5) - if sublayer_norm - else nn.Identity() - ) + if sublayer_norm: + self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps) + else: + self.subln = nn.Identity() def _prepare_attention_inputs(self, hidden_states: torch.Tensor): """Prepare inputs for attention computation.""" bsz, q_len, _ = hidden_states.size() # Project and split - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) + q1, q2 = q.chunk(2, dim=-1) + k1, k2 = k.chunk(2, dim=-1) # Reshape - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2) + q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose( + 1, 2 + ) return q1, q2, k1, k2, v @@ -148,16 +162,16 @@ def _apply_rotary_embeddings( ): """Apply rotary embeddings to queries and keys.""" if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q1.size(-2), device=q1.device) + LOG.warning( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+            )
             cos, sin = self.rotary_emb(q1, position_ids)
         else:
             cos, sin = position_embeddings
 
-        if self.split_heads:
-            cos, _ = cos.chunk(2, dim=2)
-            sin, _ = sin.chunk(2, dim=2)
-
         q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
         q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
 
@@ -195,7 +209,7 @@ def _process_attention_output(self, attn, bsz, q_len):
         return self.o_proj(attn)
 
 
-class LlamaDifferentialAttention(DifferentialAttentionBase):
+class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
     """Standard implementation of differential attention."""
 
     def forward(
@@ -237,15 +251,16 @@ def forward(
         lambda_full = self._compute_lambda(q1)
         attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
-
         attn = self._process_attention_output(attn, bsz, q_len)
 
         if output_attentions:
-            return attn, attn1 - lambda_full * attn2, past_key_value
+            attn_weights = attn1 - lambda_full * attn2
+            attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
+            return attn, attn_weights, past_key_value
 
         return attn, None, past_key_value
 
 
-class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
+class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
     """SDPA-based implementation of differential attention."""
 
     # pylint: disable=duplicate-code
@@ -262,6 +277,11 @@ def forward(
         **kwargs,  # pylint: disable=unused-argument
     ):
         if output_attentions:
+            LOG.warning(
+                "LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+                + "`torch.nn.functional.scaled_dot_product_attention` does not support "
+                + "`output_attentions=True`. Falling back to the eager attention implementation."
+            )
             return LlamaDifferentialAttention.forward(
                 self,
                 hidden_states,
@@ -309,9 +329,18 @@ def forward(
         return attn, None, past_key_value
 
 
-class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
+class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
     """Flash Attention 2-based implementation of differential attention."""
 
+    def __init__(self, *args, **kwargs):
+        if not FLASH_ATTENTION_AVAILABLE:
+            raise ImportError(
+                "LlamaDifferentialFlashAttention2 requires flash-attn library. "
+                "Please install with `pip install flash-attn --no-build-isolation`"
+            )
+
+        super().__init__(*args, **kwargs)
+
     # pylint: disable=duplicate-code
     def forward(
         self,
@@ -326,6 +355,11 @@ def forward(
         **kwargs,  # pylint: disable=unused-argument
     ):
         if output_attentions:
+            LOG.warning(
+                "LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+                + "flash attention does not support `output_attentions=True`. Falling back "
+                + "to the eager attention implementation."
+            )
             return LlamaDifferentialAttention.forward(
                 self,
                 hidden_states,
diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
new file mode 100644
index 000000000..594970716
--- /dev/null
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -0,0 +1,370 @@
+"""Modeling for differential transformers."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.configuration_llama import LlamaConfig
+from transformers.models.llama.modeling_llama import (
+    LlamaMLP,
+    LlamaModel,
+    LlamaPreTrainedModel,
+    LlamaRMSNorm,
+)
+
+from .diff_attn import (
+    LlamaDifferentialAttention,
+    LlamaDifferentialAttentionBase,
+    LlamaDifferentialFlashAttention2,
+    LlamaDifferentialSdpaAttention,
+)
+
+
+class LlamaDifferentialConfig(LlamaConfig):
+    """Configuration class for Differential LLaMA model."""
+
+    def __init__(
+        self,
+        split_heads: bool = False,
+        sublayer_norm: bool = True,
+        zero_init: bool = False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.split_heads = split_heads
+        self.sublayer_norm = sublayer_norm
+        self.zero_init = zero_init
+        self.architectures = ["LlamaDifferentialModel"]
+        self._attn_implementations = {
+            "eager": "differential_eager",
+            "sdpa": "differential_sdpa",
+            "flash_attention_2": "differential_flash_attention_2",
+        }
+
+
+class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
+    """Base class for differential LLaMA models."""
+
+    config_class = LlamaDifferentialConfig
+    base_model_prefix = "llama_differential"
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)):
+            module.gradient_checkpointing = value
+
+
+def lambda_init_fn(depth: int) -> float:
+    """Initialize lambda parameter based on layer depth."""
+    return 0.8 - 0.6 * math.exp(-0.3 * depth)
+
+
+class LlamaDifferentialModel(LlamaDifferentialPreTrainedModel):
+    """Differential version of the LLaMA model."""
+
+    def __init__(self, config: LlamaDifferentialConfig):
+        super().__init__(config)
+        # Map attn implementations to classes
+        self.attn_implementation_to_class = {
+            "differential_eager": LlamaDifferentialAttention,
+            "differential_sdpa": LlamaDifferentialSdpaAttention,
+            "differential_flash_attention_2": LlamaDifferentialFlashAttention2,
+        }
+
+        # Get correct attention implementation
+        attn_implementation = getattr(config, "_attn_implementation", "eager")
+        if attn_implementation in config._attn_implementations:
+            attn_implementation = config._attn_implementations[attn_implementation]
+
+        self.attention_class = self.attn_implementation_to_class.get(
+            attn_implementation, LlamaDifferentialAttention
+        )
+
+        # Initialize model components
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, config.pad_token_id
+        )
+        self.layers = nn.ModuleList(
+            [
+                LlamaDifferentialDecoderLayer(
+                    config=config, layer_idx=i, attention_class=self.attention_class
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
+
+        # Check if either input_ids or inputs_embeds is provided
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        if input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+            device = input_ids.device
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+            device = inputs_embeds.device
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if position_ids is None:
+            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0)
+
+        # Initialize past_key_values if needed
+        if past_key_values is None:
+            past_key_values = tuple([None] * len(self.layers))
+
+        # Create attention mask if not provided
+        if attention_mask is not None:
+            attention_mask = self._prepare_attention_mask(
+                attention_mask, (batch_size, seq_length), device
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+
+        # Initialize lists to store outputs
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_cache = () if use_cache else None
+
+        for _, (layer, past_key_value) in enumerate(zip(self.layers, past_key_values)):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)  # type: ignore
+
+            layer_outputs = layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_cache += (layer_outputs[-1],)  # type: ignore
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)  # type: ignore
+
+        # Apply the final norm and add the last hidden state
+        hidden_states = self.norm(hidden_states)
+
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)  # type: ignore
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+    def _prepare_attention_mask(
+        self,
+        attention_mask: torch.Tensor,
+        input_shape: Tuple[int, int],
+        device: torch.device,
+    ) -> torch.Tensor:
+        """Prepare attention mask for computing attention."""
+        # Create causal mask
+        # [batch_size, seq_length] -> [batch_size, 1, seq_length, seq_length]
+        combined_attention_mask = None
+        _, seq_length = input_shape
+
+        if self.config.is_decoder:
+            seq_ids = torch.arange(seq_length, device=device)
+            causal_mask = (
+                seq_ids[None, None, :].repeat(1, seq_length, 1)
+                <= seq_ids[None, :, None]
+            )
+            causal_mask = causal_mask.to(attention_mask.dtype)
+
+            if causal_mask.shape[1:] != (seq_length, seq_length):
+                causal_mask = causal_mask[:, :seq_length, :seq_length]
+
+            # Extend attention mask
+            combined_attention_mask = (
+                causal_mask[None, None, :, :] * attention_mask[:, None, None, :]
+            )
+        else:
+            combined_attention_mask = attention_mask[:, None, None, :]
+
+        return combined_attention_mask
+
+    @classmethod
+    def from_llama(
+        cls,
+        llama_model: LlamaModel,
+        differential_config: Optional[LlamaDifferentialConfig] = None,
+    ) -> "LlamaDifferentialModel":
+        """Convert a standard LLaMA model to use differential attention."""
+        if differential_config is None:
+            # pylint: disable=protected-access
+            differential_config = LlamaDifferentialConfig.from_pretrained(
+                llama_model.config._name_or_path
+            )
+
+        # Create new model
+        new_model = cls(differential_config)
+
+        # Copy non-attention weights directly
+        new_model.embed_tokens.load_state_dict(llama_model.embed_tokens.state_dict())
+        new_model.norm.load_state_dict(llama_model.norm.state_dict())
+
+        # Copy layer weights, handling attention layers specially
+        for new_layer, old_layer in zip(new_model.layers, llama_model.layers):
+            # Copy self-attention weights with special handling
+            if differential_config.split_heads:
+                # Split heads mode
+                new_layer.self_attn.q_proj.weight.data.copy_(
+                    old_layer.self_attn.q_proj.weight.data
+                )
+                new_layer.self_attn.k_proj.weight.data.copy_(
+                    old_layer.self_attn.k_proj.weight.data
+                )
+            else:
+                # Double projection mode - copy weights to positive components
+                new_layer.self_attn.q_proj.weight.data[
+                    : differential_config.hidden_size
+                ].copy_(old_layer.self_attn.q_proj.weight.data)
+                new_layer.self_attn.k_proj.weight.data[
+                    : differential_config.hidden_size
+                ].copy_(old_layer.self_attn.k_proj.weight.data)
+
+                # Zero out relevant parameters for exact equivalence
+                if differential_config.zero_init:
+                    old_kv_size = old_layer.self_attn.k_proj.weight.data.size(0)
+                    new_layer.self_attn.q_proj.weight.data[
+                        new_layer.self_attn.hidden_size :
+                    ] = 0
+                    new_layer.self_attn.k_proj.weight.data[old_kv_size:] = 0
+                    nn.init.zeros_(new_layer.self_attn.lambda_q1)
+                    nn.init.zeros_(new_layer.self_attn.lambda_k1)
+                    nn.init.zeros_(new_layer.self_attn.lambda_q2)
+                    nn.init.zeros_(new_layer.self_attn.lambda_k2)
+                    nn.init.zeros_(new_layer.self_attn.lambda_init)
+
+            # Copy remaining weights
+            new_layer.self_attn.v_proj.load_state_dict(
+                old_layer.self_attn.v_proj.state_dict()
+            )
+            new_layer.self_attn.o_proj.load_state_dict(
+                old_layer.self_attn.o_proj.state_dict()
+            )
+
+            # Copy MLP and layer norm weights
+            new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
+            new_layer.input_layernorm.load_state_dict(
+                old_layer.input_layernorm.state_dict()
+            )
+            new_layer.post_attention_layernorm.load_state_dict(
+                old_layer.post_attention_layernorm.state_dict()
+            )
+
+        return new_model
+
+
+class LlamaDifferentialDecoderLayer(nn.Module):
+    """Custom decoder layer for differential Llama model."""
+
+    def __init__(
+        self, config: LlamaDifferentialConfig, layer_idx: int, attention_class
+    ):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = attention_class(config, layer_idx)
+        self.mlp = LlamaMLP(config)
+        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = LlamaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+        use_cache: Optional[bool] = False,
+        **kwargs,
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
+        """
+        Layer forward pass with differential attention.
+        """
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            **kwargs,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weights,)  # type: ignore
+
+        if use_cache:
+            outputs += (present_key_value,)  # type: ignore
+
+        return outputs  # type: ignore
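
Usage sketch (illustrative, not part of the patch): a minimal example of how the new
`LlamaDifferentialModel.from_llama` conversion path might be driven once this change
lands. The checkpoint path below is a placeholder, and passing the base config's dict
through `LlamaDifferentialConfig` is an assumption about how the two configs round-trip.

    import torch
    from transformers import LlamaModel

    from axolotl.integrations.diff_transformer.modeling_diff_attn import (
        LlamaDifferentialConfig,
        LlamaDifferentialModel,
    )

    # Load a standard Llama backbone (placeholder checkpoint path).
    base = LlamaModel.from_pretrained("path/to/llama-checkpoint")

    # Build a differential config from the base config; zero_init requests the
    # exact-equivalence initialization handled in from_llama above.
    config = LlamaDifferentialConfig(**base.config.to_dict(), zero_init=True)
    diff_model = LlamaDifferentialModel.from_llama(base, config)

    # Forward pass returns a BaseModelOutputWithPast, mirroring LlamaModel.
    input_ids = torch.tensor([[1, 2, 3, 4]])
    out = diff_model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
    print(out.last_hidden_state.shape)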