Commit: changes

djsaunde committed Dec 27, 2024
1 parent a1aecf9 commit 62f2092
Showing 2 changed files with 89 additions and 19 deletions.
16 changes: 8 additions & 8 deletions src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -18,6 +18,7 @@
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
register_diff_attn,
)
from axolotl.utils.yaml import dump_yaml_preserved_order

@@ -50,6 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):


def convert_diff_transformer(cfg, cli_args, config_path):
register_diff_attn()
debug_info = {}

# Load model and tokenizer
@@ -82,15 +84,13 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET
)
    try:
-        model = LlamaDifferentialForCausalLM.from_llama(
-            model,
-            LlamaDifferentialConfig(
-                **model.config.__dict__,
-                zero_init=cli_args.zero_init,
-                sublayer_norm=cli_args.sublayer_norm,
-                split_heads=cli_args.split_heads,
-            ),
+        config = LlamaDifferentialConfig(
+            **model.config.__dict__,
+            zero_init=cli_args.zero_init,
+            sublayer_norm=cli_args.sublayer_norm,
+            split_heads=cli_args.split_heads,
        )
+        model = LlamaDifferentialForCausalLM.from_llama(model, config)
        model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
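
For reference, a minimal sketch of the conversion pattern this file now follows, outside the CLI plumbing; the checkpoint id and flag values are placeholders rather than anything taken from this commit:

# Hedged sketch only: assumes an environment with axolotl and transformers installed.
from transformers import AutoModelForCausalLM

from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
    register_diff_attn,
)

register_diff_attn()  # register the differential config/model classes before converting

base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder checkpoint
config = LlamaDifferentialConfig(
    **base.config.__dict__,
    zero_init=True,       # stands in for cli_args.zero_init
    sublayer_norm=False,  # stands in for cli_args.sublayer_norm
    split_heads=False,    # stands in for cli_args.split_heads
)
model = LlamaDifferentialForCausalLM.from_llama(base, config)
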
92 changes: 81 additions & 11 deletions src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -1,6 +1,7 @@
"""Modeling for differential transformers."""

-from typing import Optional
+import logging
+from typing import Optional, Union

import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -18,6 +19,8 @@
LlamaDifferentialSdpaAttention,
)

logger = logging.getLogger(__name__)


class LlamaDifferentialConfig(LlamaConfig):
"""Configuration class for Differential LLaMA model."""
@@ -55,26 +58,85 @@ def _set_gradient_checkpointing(self, module, value=False):
class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention."""

config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"

def __init__(self, config):
super().__init__(config)

# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]

# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")

# Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)

        for idx, layer in enumerate(self.layers):
-            attn_impl = config._attn_implementation or "eager"
-            if attn_impl == "eager":
-                layer.self_attn = LlamaDifferentialAttention(config, idx)
-            elif attn_impl == "sdpa":
-                layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
-            elif attn_impl == "flash_attention_2":
-                layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
+            layer.self_attn = attn_class(config, idx)

# pylint: disable=protected-access
@classmethod
def _autoset_attn_implementation(
cls, config, **kwargs
): # pylint: disable=unused-argument
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)

# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config

# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)

return config

@classmethod
    def from_llama(
-        cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None
+        cls,
+        model: Union[LlamaModel, LlamaForCausalLM],
+        config: Optional[LlamaDifferentialConfig] = None,
    ) -> "LlamaDifferentialModel":
"""Convert a LlamaModel to use differential attention."""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")

# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model

if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")

# Validate head counts if using split heads mode
if config.split_heads:
@@ -92,10 +154,14 @@ def from_llama(
new_model = cls(config)

# Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())

-        for new_layer, old_layer in zip(new_model.layers, model.layers):
+        logger.debug("Copying layer weights")
+        for layer_idx, (new_layer, old_layer) in enumerate(
+            zip(new_model.layers, model.layers)
+        ):
# Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
@@ -109,7 +175,6 @@ def from_llama(
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
-            print(old_layer.self_attn.o_proj.weight.shape)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
@@ -119,6 +184,9 @@
old_k_size = old_layer.self_attn.k_proj.weight.size(0)

if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
@@ -127,6 +195,7 @@
)

if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
# Zero out components as needed
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
@@ -137,6 +206,7 @@
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()

logger.info("Conversion complete")
return new_model


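A hedged sketch of how the new _autoset_attn_implementation hook is expected to behave; it assumes LlamaDifferentialConfig._attn_implementations (not shown in this diff) maps standard names such as "sdpa" to their differential counterparts:

# Illustrative only; exercises the mapping/validation logic added above.
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
    LlamaDifferentialModel,
)

config = LlamaDifferentialConfig()    # default-sized config, placeholder values
config._attn_implementation = "sdpa"  # standard name, as a caller would request
config = LlamaDifferentialModel._autoset_attn_implementation(config)
print(config._attn_implementation)    # expected: "differential_sdpa"

try:
    config._attn_implementation = "not_an_impl"  # neither mapped nor a differential type
    LlamaDifferentialModel._autoset_attn_implementation(config)
except ValueError as exc:
    print(exc)  # unsupported implementations are rejected with the message built above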
