
Commit

fixes and cleanup
djsaunde committed Dec 28, 2024
1 parent 3c74e3e commit 566729e
Showing 6 changed files with 83 additions and 56 deletions.
6 changes: 6 additions & 0 deletions model-out/eval_summary.csv
@@ -0,0 +1,6 @@
metric,training,validation
loss,15.633337020874023,15.604033470153809
model_preparation_time,0.0058,0.0058
runtime,77.8124,8.4643
samples_per_second,23.133,23.629
steps_per_second,23.133,23.629
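
For quick inspection, a minimal sketch (not part of the commit) that reads the eval summary added above; it assumes pandas is installed and uses the path shown in the diff.

import pandas as pd

# Sketch only: load the new summary and compare training vs. validation metrics.
summary = pd.read_csv("model-out/eval_summary.csv", index_col="metric")
print(summary.loc["loss"])              # training and validation loss
print(summary["validation"].to_dict())  # every validation-split metric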
2 changes: 0 additions & 2 deletions src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -18,7 +18,6 @@
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
register_diff_attn,
)
from axolotl.utils.yaml import dump_yaml_preserved_order

@@ -51,7 +50,6 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):


def convert_diff_transformer(cfg, cli_args, config_path):
register_diff_attn()
debug_info = {}

# Load model and tokenizer
5 changes: 5 additions & 0 deletions src/axolotl/integrations/diff_transformer/__init__.py
@@ -10,5 +10,10 @@
class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl."""

def __init__(self):
from .modeling_diff_attn import register_diff_attn

register_diff_attn()

def get_input_args(self):
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
66 changes: 29 additions & 37 deletions src/axolotl/integrations/diff_transformer/diff_attn.py
@@ -46,27 +46,22 @@ class LlamaDifferentialAttentionBase(nn.Module):

def __init__(self, config: Any, layer_idx: int):
super().__init__()

self.config = config
self._init_config(config, layer_idx)
self._init_config(layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization(config)
self._init_normalization()

def _init_config(self, config: Any, layer_idx: int):
def _init_config(self, layer_idx: int):
"""Initialize configuration parameters."""
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // config.num_attention_heads
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.split_heads = config.split_heads

if config.split_heads:
if self.config.split_heads:
# Split heads mode - single projections
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even
@@ -81,31 +76,29 @@ def _init_config(self, config: Any, layer_idx: int):

def _init_projections(self):
"""Initialize Q, K, V projections."""
if self.split_heads:
if self.config.split_heads:
# Split heads mode - single projections
q_out_dim = self.hidden_size
k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
# Double projection mode
q_out_dim = self.hidden_size * 2
k_out_dim = (
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
)
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2

self.q_proj = nn.Linear(
self.hidden_size, q_out_dim, bias=self.config.attention_bias
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size, k_out_dim, bias=self.config.attention_bias
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
self.config.hidden_size,
self.head_dim * self.base_num_kv_heads,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.hidden_size,
self.config.hidden_size,
bias=self.config.attention_bias,
)

@@ -129,11 +122,11 @@ def _init_differential_params(self):
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

def _init_normalization(self, config):
def _init_normalization(self):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps)
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()

@@ -148,7 +141,6 @@ def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)

# Reshape
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
@@ -161,9 +153,7 @@ def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)

return q1, q2, k1, k2, v

@@ -198,6 +188,8 @@ def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.split_heads:
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)

return k1, k2, v

@@ -215,7 +207,7 @@ def _process_attention_output(self, attn, bsz, q_len):
"""Process and project attention output."""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)


@@ -255,7 +247,7 @@ def forward(
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)

dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)

@@ -318,7 +310,7 @@ def forward(
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0

if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
@@ -396,9 +388,9 @@ def forward(
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)

dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0

if self.split_heads:
if self.config.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
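To summarize what diff_attn.py computes, here is a simplified, self-contained sketch of the eager differential-attention math these hunks touch; names and shapes are illustrative (following the differential transformer formulation), not the module's exact internals, and masking/dropout are omitted.

import torch
import torch.nn.functional as F

def differential_attention(q1, q2, k1, k2, v, lambda_full, lambda_init):
    """Illustrative pass: two softmax maps, subtracted with a learned lambda."""
    # q1/q2, k1/k2: (bsz, heads, seq_len, head_dim); v: (bsz, heads, seq_len, value_head_dim)
    scale = q1.shape[-1] ** -0.5
    attn1 = F.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1, dtype=torch.float32).type_as(q1)
    attn2 = F.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1, dtype=torch.float32).type_as(q2)
    out = (attn1 - lambda_full * attn2) @ v  # differential map applied to values
    # The real module applies a sublayer RMSNorm (subln) before this scaling.
    return out * (1 - lambda_init)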
57 changes: 40 additions & 17 deletions src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -6,15 +6,10 @@
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel

from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialAttentionBase,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
@@ -46,17 +41,6 @@ def __init__(
}


class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
"""Base class for differential LLaMA models."""

config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"

def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)):
module.gradient_checkpointing = value


class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention."""

@@ -222,6 +206,37 @@ def __init__(self, config):
super().__init__(config)
self.model = LlamaDifferentialModel(config)

# pylint: disable=protected-access
@classmethod
def _autoset_attn_implementation(
cls, config, **kwargs
): # pylint: disable=unused-argument
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)

# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config

# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)

return config

@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None
@@ -257,3 +272,11 @@ def register_diff_attn():
# Register models
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)

from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2
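
A hypothetical end-to-end sketch of how the registrations in modeling_diff_attn.py might be exercised (the checkpoint path is a placeholder; only names visible in this diff are used):

from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialForCausalLM,
    register_diff_attn,
)

register_diff_attn()  # registers the config, model classes, and attention variants

base = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path
diff_model = LlamaDifferentialForCausalLM.from_llama(base)  # classmethod shown above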
@@ -130,6 +130,9 @@ def test_conversion_cli_repoduce_attentions(
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"

# Smallest model with an even number of attention heads
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True

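The pinned base model relates to the split-heads note earlier in the diff; a small illustration (the head count here is an assumption, not read from any config):

num_attention_heads = 32                        # assumed even head count
heads_per_component = num_attention_heads // 2  # each softmax map gets 16 heads

# With an odd head count, the floor division above silently drops a head
# (see the NOTE in _init_config), hence a base model with an even count.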
