From 62f2092f77d5dadb5d665906ca9b81515c1ae231 Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 27 Dec 2024 21:24:16 +0000
Subject: [PATCH] Clean up differential transformer conversion and attention
 implementation handling

---
 .../integrations/convert_diff_transformer.py | 16 ++--
 .../diff_transformer/modeling_diff_attn.py   | 92 ++++++++++++++++---
 2 files changed, 89 insertions(+), 19 deletions(-)

diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index 28cc87bbd..11a43f6a8 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -18,6 +18,7 @@
 from axolotl.integrations.diff_transformer.modeling_diff_attn import (
     LlamaDifferentialConfig,
     LlamaDifferentialForCausalLM,
+    register_diff_attn,
 )
 from axolotl.utils.yaml import dump_yaml_preserved_order
 
@@ -50,6 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
 
 
 def convert_diff_transformer(cfg, cli_args, config_path):
+    register_diff_attn()
     debug_info = {}
 
     # Load model and tokenizer
@@ -82,15 +84,13 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         + Fore.RESET
     )
     try:
-        model = LlamaDifferentialForCausalLM.from_llama(
-            model,
-            LlamaDifferentialConfig(
-                **model.config.__dict__,
-                zero_init=cli_args.zero_init,
-                sublayer_norm=cli_args.sublayer_norm,
-                split_heads=cli_args.split_heads,
-            ),
+        config = LlamaDifferentialConfig(
+            **model.config.__dict__,
+            zero_init=cli_args.zero_init,
+            sublayer_norm=cli_args.sublayer_norm,
+            split_heads=cli_args.split_heads,
         )
+        model = LlamaDifferentialForCausalLM.from_llama(model, config)
         model.to(cfg.device, dtype=cfg.torch_dtype)
     except Exception as exc:
         LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
index cf75ff37a..0dc58ab2f 100644
--- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -1,6 +1,7 @@
 """Modeling for differential transformers."""
 
-from typing import Optional
+import logging
+from typing import Optional, Union
 
 import torch
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
@@ -18,6 +19,8 @@
     LlamaDifferentialSdpaAttention,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class LlamaDifferentialConfig(LlamaConfig):
     """Configuration class for Differential LLaMA model."""
@@ -55,26 +58,85 @@ def _set_gradient_checkpointing(self, module, value=False):
 class LlamaDifferentialModel(LlamaModel):
     """LlamaModel with differential attention."""
 
+    config_class = LlamaDifferentialConfig
+    base_model_prefix = "llama_differential"
+
     def __init__(self, config):
         super().__init__(config)
 
+        # Handle attention implementation
+        attn_impl = config._attn_implementation or "eager"
+        if attn_impl in config._attn_implementations:
+            attn_impl = config._attn_implementations[attn_impl]
+
+        # Validate attention implementation
+        valid_impls = [
+            None,
+            "differential_eager",
+            "differential_sdpa",
+            "differential_flash_attention_2",
+        ]
+        if attn_impl not in valid_impls:
+            raise ValueError(f"Invalid attention implementation: {attn_impl}")
+
+        # Replace standard attention with differential attention in each layer
+        attn_classes = {
+            "differential_eager": LlamaDifferentialAttention,
+            "differential_sdpa": LlamaDifferentialSdpaAttention,
+            "differential_flash_attention_2": LlamaDifferentialFlashAttention2,
+        }
+        attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
+
         for idx, layer in enumerate(self.layers):
-            attn_impl = config._attn_implementation or "eager"
-            if attn_impl == "eager":
-                layer.self_attn = LlamaDifferentialAttention(config, idx)
-            elif attn_impl == "sdpa":
-                layer.self_attn = LlamaDifferentialSdpaAttention(config, idx)
-            elif attn_impl == "flash_attention_2":
-                layer.self_attn = LlamaDifferentialFlashAttention2(config, idx)
+            layer.self_attn = attn_class(config, idx)
+
+    # pylint: disable=protected-access
+    @classmethod
+    def _autoset_attn_implementation(
+        cls, config, **kwargs
+    ):  # pylint: disable=unused-argument
+        config._attn_implementation_autoset = True
+        attn_implementation = getattr(config, "_attn_implementation", None)
+
+        # Map standard types to differential types if mapping exists
+        if attn_implementation in config._attn_implementations:
+            config._attn_implementation = config._attn_implementations[
+                attn_implementation
+            ]
+            return config
+
+        # If no mapping, validate it's a valid differential type
+        valid_impls = [
+            None,
+            "differential_eager",
+            "differential_sdpa",
+            "differential_flash_attention_2",
+        ]
+        if attn_implementation not in valid_impls:
+            message = (
+                f"Specified `attn_implementation={attn_implementation}` is not supported. "
+                f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
+            )
+            raise ValueError(message)
+
+        return config
 
     @classmethod
     def from_llama(
-        cls, model: LlamaModel, config: Optional[LlamaDifferentialConfig] = None
+        cls,
+        model: Union[LlamaModel, LlamaForCausalLM],
+        config: Optional[LlamaDifferentialConfig] = None,
     ) -> "LlamaDifferentialModel":
         """Convert a LlamaModel to use differential attention."""
+        logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
+
+        # Handle LlamaForCausalLM
+        if isinstance(model, LlamaForCausalLM):
+            model = model.model
+
         if config is None:
             config = LlamaDifferentialConfig(**model.config.__dict__)
+            logger.debug(f"Created config: {config}")
 
         # Validate head counts if using split heads mode
         if config.split_heads:
@@ -92,10 +154,14 @@ def from_llama(
         new_model = cls(config)
 
         # Copy all weights except attention
+        logger.debug("Copying embeddings and norm")
         new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
         new_model.norm.load_state_dict(model.norm.state_dict())
 
-        for new_layer, old_layer in zip(new_model.layers, model.layers):
+        logger.debug("Copying layer weights")
+        for layer_idx, (new_layer, old_layer) in enumerate(
+            zip(new_model.layers, model.layers)
+        ):
             # Copy everything except attention weights
             new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
             new_layer.input_layernorm.load_state_dict(
@@ -109,7 +175,6 @@ def from_llama(
             new_layer.self_attn.v_proj.load_state_dict(
                 old_layer.self_attn.v_proj.state_dict()
             )
-            print(old_layer.self_attn.o_proj.weight.shape)
             new_layer.self_attn.o_proj.load_state_dict(
                 old_layer.self_attn.o_proj.state_dict()
             )
@@ -119,6 +184,9 @@ def from_llama(
             old_k_size = old_layer.self_attn.k_proj.weight.size(0)
 
             if not config.split_heads:
+                logger.debug(
+                    f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
+                )
                 new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
                     old_layer.self_attn.q_proj.weight.data
                 )
@@ -127,6 +195,7 @@ def from_llama(
                 )
 
             if config.zero_init:
+                logger.debug(f"Layer {layer_idx}: Zero initializing")
                 # Zero out components as needed
                 with torch.no_grad():
                     new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
@@ -137,6 +206,7 @@ def from_llama(
                     new_layer.self_attn.lambda_k2.zero_()
                     new_layer.self_attn.lambda_init.zero_()
 
+        logger.info("Conversion complete")
         return new_model
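
Usage sketch (illustrative, not part of the diff): after this change the CLI builds the
LlamaDifferentialConfig first and then hands the already-loaded model to from_llama. The
same flow outside the CLI might look like the following; the checkpoint name and flag
values are placeholders only:

    from transformers import AutoModelForCausalLM

    from axolotl.integrations.diff_transformer.modeling_diff_attn import (
        LlamaDifferentialConfig,
        LlamaDifferentialForCausalLM,
        register_diff_attn,
    )

    # Register the differential attention classes, as convert_diff_transformer now does.
    register_diff_attn()

    # Example checkpoint only; any Llama-family causal LM should follow the same path.
    base = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

    config = LlamaDifferentialConfig(
        **base.config.__dict__,
        zero_init=True,       # zero the newly added parameters (see the zero_init branch in from_llama)
        sublayer_norm=True,
        split_heads=False,
    )
    model = LlamaDifferentialForCausalLM.from_llama(base, config)

Since LlamaDifferentialModel.from_llama now accepts Union[LlamaModel, LlamaForCausalLM],
the causal-LM wrapper can be passed directly; its inner .model is unwrapped before the
weights are copied.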
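
Implementation note (assumption, not shown in this diff): the new __init__ and
_autoset_attn_implementation hunks rely on the config exposing an _attn_implementations
mapping from the standard attention implementation names to their differential
counterparts. A sketch of what that mapping presumably looks like on
LlamaDifferentialConfig:

    # Assumed shape of LlamaDifferentialConfig._attn_implementations (defined elsewhere, not in this patch).
    _attn_implementations = {
        "eager": "differential_eager",
        "sdpa": "differential_sdpa",
        "flash_attention_2": "differential_flash_attention_2",
    }

With such a mapping in place, callers can keep requesting "eager", "sdpa", or
"flash_attention_2" and have the name rewritten to the differential variant, while any
value outside the mapping and the valid_impls list raises the new ValueError.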