added modeling code; cleanup + refactor
Dan Saunders committed Dec 23, 2024
1 parent a1a3f1d commit c6def27
Showing 3 changed files with 439 additions and 34 deletions.
3 changes: 2 additions & 1 deletion src/axolotl/integrations/diff_transformer/convert.py
@@ -50,7 +50,7 @@ def copy_attention_weights(
new_attn.q_proj.weight.data.copy_(new_q)

# For K projection (K1 and K2)
old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads
old_kv_size = old_attn.k_proj.weight.data.size(0)
new_k = torch.empty_like(new_attn.k_proj.weight.data)
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
if zero_init:
@@ -99,6 +99,7 @@ def convert_module(module):
# Iterate through module children, convert any attn layers to diff attn
for name, child in module.named_children():
child_class_name = type(child).__name__

if child_class_name in [k.__name__ for k in ATTENTION_MAPPING]:
# Find matching attention class by name
for orig_class, diff_class in ATTENTION_MAPPING.items():
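For orientation, the conversion in convert.py comes down to two steps: copy the original projection weights into the first half of the doubled K (and Q) projections, zeroing the second half when zero_init is set so the subtracted attention map starts out inert, and recursively swap attention submodules according to a class mapping. A rough, self-contained sketch of that pattern follows; helper names and constructor signatures here are illustrative assumptions, not the exact axolotl API.

import torch
from torch import nn

def copy_k_projection(old_k: nn.Linear, new_k: nn.Linear) -> None:
    # The new K projection is twice as wide: K1 stacked on top of K2 along dim 0.
    old_kv_size = old_k.weight.data.size(0)
    new_weight = torch.empty_like(new_k.weight.data)
    new_weight[:old_kv_size] = old_k.weight.data  # K1: reuse the original weights
    # Assumption: with zero_init, K2 starts at zero so the subtracted attention
    # map contributes nothing and the converted model initially matches the original.
    new_weight[old_kv_size:] = 0.0
    new_k.weight.data.copy_(new_weight)

def swap_attention_layers(module: nn.Module, mapping: dict) -> None:
    # Recursively replace attention submodules whose class appears in `mapping`.
    for name, child in module.named_children():
        child_class_name = type(child).__name__
        if child_class_name in [k.__name__ for k in mapping]:
            for orig_class, diff_class in mapping.items():
                if child_class_name == orig_class.__name__:
                    # Hypothetical constructor call; the differential attention
                    # classes in this commit take (config, layer_idx).
                    setattr(module, name, diff_class(child.config, child.layer_idx))
                    break
        else:
            swap_attention_layers(child, mapping)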
100 changes: 67 additions & 33 deletions src/axolotl/integrations/diff_transformer/diff_attn.py
@@ -7,7 +7,6 @@

import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
@@ -17,7 +16,14 @@
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)

try:
from flash_attn.flash_attn_interface import flash_attn_func

FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand All @@ -35,11 +41,12 @@ def lambda_init_fn(depth):
return 0.8 - 0.6 * math.exp(-0.3 * depth)


class DifferentialAttentionBase(nn.Module):
class LlamaDifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations."""

def __init__(self, config: Any, layer_idx: int):
super().__init__()
self.config = config
self._init_config(config, layer_idx)
self._init_projections()
self._init_differential_params()
@@ -59,9 +66,9 @@ def _init_config(self, config: Any, layer_idx: int):

if config.split_heads:
# Split heads mode - single projections
self.head_dim = config.hidden_size // config.num_attention_heads // 2
self.head_dim = config.hidden_size // config.num_attention_heads
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even.
# implementation, which asserts `self.base_num_heads` is even
self.heads_per_component = self.base_num_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
@@ -110,36 +117,43 @@ def _init_differential_params(self):
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(
self.max_position_embeddings, self.head_dim, self.rope_theta
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

def _init_normalization(self, config):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
self.subln = (
LlamaRMSNorm(self.value_head_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps)
else:
self.subln = nn.Identity()

def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
"""Prepare inputs for attention computation."""
bsz, q_len, _ = hidden_states.size()

# Project and split
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)

# Reshape
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k1 = k1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k2 = k2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.heads_per_component, self.value_head_dim).transpose(
1, 2
)

return q1, q2, k1, k2, v

@@ -148,16 +162,16 @@ def _apply_rotary_embeddings(
):
"""Apply rotary embeddings to queries and keys."""
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q1.size(-2), device=q1.device)
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings

if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)

q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)

@@ -195,7 +209,7 @@ def _process_attention_output(self, attn, bsz, q_len):
return self.o_proj(attn)


class LlamaDifferentialAttention(DifferentialAttentionBase):
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""Standard implementation of differential attention."""

def forward(
@@ -237,15 +251,16 @@ def forward(

lambda_full = self._compute_lambda(q1)
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)

attn = self._process_attention_output(attn, bsz, q_len)

if output_attentions:
return attn, attn1 - lambda_full * attn2, past_key_value
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
return attn, attn_weights, past_key_value
return attn, None, past_key_value


class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""SDPA-based implementation of differential attention."""

# pylint: disable=duplicate-code
@@ -262,6 +277,11 @@ def forward(
**kwargs, # pylint: disable=unused-argument
):
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -309,9 +329,18 @@ def forward(
return attn, None, past_key_value


class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""Flash Attention 2-based implementation of differential attention."""

def __init__(self, *args, **kwargs):
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
"Please install with `pip install flash-attn --no-build-isolation`"
)

super().__init__(*args, **kwargs)

# pylint: disable=duplicate-code
def forward(
self,
@@ -326,6 +355,11 @@ def forward(
**kwargs, # pylint: disable=unused-argument
):
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
return LlamaDifferentialAttention.forward(
self,
hidden_states,
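The heart of diff_attn.py is the combination attn1 @ V - lambda * (attn2 @ V) visible in LlamaDifferentialAttention.forward above. Below is a standalone sketch of that computation, assuming the lambda reparameterization from the Differential Transformer paper (exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init); the actual `_compute_lambda` body is not shown in this diff, and the causal mask is omitted for brevity.

import math
import torch
import torch.nn.functional as F

def lambda_init_fn(depth: int) -> float:
    # Same depth-dependent initialization as in diff_attn.py.
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

def differential_attention(q1, q2, k1, k2, v, lambda_q1, lambda_k1, lambda_q2, lambda_k2, depth=0):
    # q1/q2/k1/k2: (batch, heads, seq, head_dim); v: (batch, heads, seq, 2 * head_dim).
    scale = 1.0 / math.sqrt(q1.size(-1))
    attn1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    attn2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    # Assumed reparameterization of the learned scalar lambda.
    lambda_full = (
        torch.exp(torch.dot(lambda_q1, lambda_k1))
        - torch.exp(torch.dot(lambda_q2, lambda_k2))
        + lambda_init_fn(depth)
    )
    # Difference of the two attention maps, both applied to the same values.
    return attn1 @ v - lambda_full * (attn2 @ v)

# Tiny usage example with random tensors.
bsz, heads, seq, hd = 1, 4, 8, 16
q1, q2, k1, k2 = (torch.randn(bsz, heads, seq, hd) for _ in range(4))
v = torch.randn(bsz, heads, seq, 2 * hd)
lq1, lk1, lq2, lk2 = (0.1 * torch.randn(hd) for _ in range(4))
out = differential_attention(q1, q2, k1, k2, v, lq1, lk1, lq2, lk2, depth=0)
print(out.shape)  # torch.Size([1, 4, 8, 32])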
