From 566729ec13d2e05032d41128b41c18d61119c3bc Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Sat, 28 Dec 2024 01:10:56 +0000 Subject: [PATCH] fixes and cleanup --- model-out/eval_summary.csv | 6 ++ .../integrations/convert_diff_transformer.py | 2 - .../integrations/diff_transformer/__init__.py | 5 ++ .../diff_transformer/diff_attn.py | 66 ++++++++----------- .../diff_transformer/modeling_diff_attn.py | 57 +++++++++++----- .../test_convert_diff_transformer.py | 3 + 6 files changed, 83 insertions(+), 56 deletions(-) create mode 100644 model-out/eval_summary.csv diff --git a/model-out/eval_summary.csv b/model-out/eval_summary.csv new file mode 100644 index 000000000..6a8f78af7 --- /dev/null +++ b/model-out/eval_summary.csv @@ -0,0 +1,6 @@ +metric,training,validation +loss,15.633337020874023,15.604033470153809 +model_preparation_time,0.0058,0.0058 +runtime,77.8124,8.4643 +samples_per_second,23.133,23.629 +steps_per_second,23.133,23.629 diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py index 11a43f6a8..ecde82251 100644 --- a/src/axolotl/cli/integrations/convert_diff_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -18,7 +18,6 @@ from axolotl.integrations.diff_transformer.modeling_diff_attn import ( LlamaDifferentialConfig, LlamaDifferentialForCausalLM, - register_diff_attn, ) from axolotl.utils.yaml import dump_yaml_preserved_order @@ -51,7 +50,6 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"): def convert_diff_transformer(cfg, cli_args, config_path): - register_diff_attn() debug_info = {} # Load model and tokenizer diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py index 461ede4fd..1dbae22c4 100644 --- a/src/axolotl/integrations/diff_transformer/__init__.py +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -10,5 +10,10 @@ class DifferentialTransformerPlugin(BasePlugin): """Plugin for differential transformer integration with Axolotl.""" + def __init__(self): + from .modeling_diff_attn import register_diff_attn + + register_diff_attn() + def get_input_args(self): return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py index 5ae503464..a03a3fb00 100644 --- a/src/axolotl/integrations/diff_transformer/diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -46,27 +46,22 @@ class LlamaDifferentialAttentionBase(nn.Module): def __init__(self, config: Any, layer_idx: int): super().__init__() + self.config = config - self._init_config(config, layer_idx) + self._init_config(layer_idx) self._init_projections() self._init_differential_params() - self._init_normalization(config) + self._init_normalization() - def _init_config(self, config: Any, layer_idx: int): + def _init_config(self, layer_idx: int): """Initialize configuration parameters.""" - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.head_dim = config.hidden_size // config.num_attention_heads - self.base_num_heads = config.num_attention_heads - self.base_num_kv_heads = config.num_key_value_heads + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + self.base_num_heads = self.config.num_attention_heads + self.base_num_kv_heads = self.config.num_key_value_heads self.num_key_value_groups = 
self.base_num_heads // self.base_num_kv_heads self.layer_idx = layer_idx - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.split_heads = config.split_heads - if config.split_heads: + if self.config.split_heads: # Split heads mode - single projections # NOTE: This rounds down `base_num_heads / 2` as opposed to the original # implementation, which asserts `self.base_num_heads` is even @@ -81,31 +76,29 @@ def _init_config(self, config: Any, layer_idx: int): def _init_projections(self): """Initialize Q, K, V projections.""" - if self.split_heads: + if self.config.split_heads: # Split heads mode - single projections - q_out_dim = self.hidden_size - k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads + q_out_dim = self.config.hidden_size + k_out_dim = self.head_dim * self.base_num_kv_heads else: # Double projection mode - q_out_dim = self.hidden_size * 2 - k_out_dim = ( - self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2 - ) + q_out_dim = self.config.hidden_size * 2 + k_out_dim = self.head_dim * self.base_num_kv_heads * 2 self.q_proj = nn.Linear( - self.hidden_size, q_out_dim, bias=self.config.attention_bias + self.config.hidden_size, q_out_dim, bias=self.config.attention_bias ) self.k_proj = nn.Linear( - self.hidden_size, k_out_dim, bias=self.config.attention_bias + self.config.hidden_size, k_out_dim, bias=self.config.attention_bias ) self.v_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads, + self.config.hidden_size, + self.head_dim * self.base_num_kv_heads, bias=self.config.attention_bias, ) self.o_proj = nn.Linear( self.base_num_heads * self.head_dim, - self.hidden_size, + self.config.hidden_size, bias=self.config.attention_bias, ) @@ -129,11 +122,11 @@ def _init_differential_params(self): ) self.rotary_emb = LlamaRotaryEmbedding(config=self.config) - def _init_normalization(self, config): + def _init_normalization(self): """Initialize normalization layers.""" - sublayer_norm = getattr(config, "sublayer_norm", True) + sublayer_norm = getattr(self.config, "sublayer_norm", True) if sublayer_norm: - self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps) + self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps) else: self.subln = nn.Identity() @@ -148,7 +141,6 @@ def _prepare_attention_inputs(self, hidden_states: torch.Tensor): q1, q2 = q.chunk(2, dim=-1) k1, k2 = k.chunk(2, dim=-1) - # Reshape q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( 1, 2 ) @@ -161,9 +153,7 @@ def _prepare_attention_inputs(self, hidden_states: torch.Tensor): k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose( 1, 2 ) - v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose( - 1, 2 - ) + v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) return q1, q2, k1, k2, v @@ -198,6 +188,8 @@ def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs): k1 = repeat_kv(k1, self.num_key_value_groups) k2 = repeat_kv(k2, self.num_key_value_groups) v = repeat_kv(v, self.num_key_value_groups) + if self.config.split_heads: + v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1) return k1, k2, v @@ -215,7 +207,7 @@ def _process_attention_output(self, attn, bsz, q_len): """Process and project attention output.""" attn = self.subln(attn) attn = attn * (1 - self.lambda_init) - attn = attn.transpose(1, 2).reshape(bsz, q_len, 
self.hidden_size) + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size) return self.o_proj(attn) @@ -255,7 +247,7 @@ def forward( attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) - dropout_p = self.attention_dropout if self.training else 0.0 + dropout_p = self.config.attention_dropout if self.training else 0.0 attn1 = F.dropout(attn1, p=dropout_p, training=self.training) attn2 = F.dropout(attn2, p=dropout_p, training=self.training) @@ -318,7 +310,7 @@ def forward( None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]] ) is_causal = attention_mask is None and q_len > 1 - dropout_p = self.attention_dropout if self.training else 0.0 + dropout_p = self.config.attention_dropout if self.training else 0.0 if q1.device.type == "cuda" and causal_mask is not None: q1, q2 = q1.contiguous(), q2.contiguous() @@ -396,9 +388,9 @@ def forward( k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2) v = v.transpose(1, 2) - dropout_p = self.attention_dropout if self.training else 0.0 + dropout_p = self.config.attention_dropout if self.training else 0.0 - if self.split_heads: + if self.config.split_heads: v1, v2 = v.chunk(2, dim=-1) attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True) attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True) diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py index fe702403e..e41fd1fdb 100644 --- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -6,15 +6,10 @@ import torch from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import ( - LlamaForCausalLM, - LlamaModel, - LlamaPreTrainedModel, -) +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel from .diff_attn import ( LlamaDifferentialAttention, - LlamaDifferentialAttentionBase, LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, ) @@ -46,17 +41,6 @@ def __init__( } -class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel): - """Base class for differential LLaMA models.""" - - config_class = LlamaDifferentialConfig - base_model_prefix = "llama_differential" - - def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)): - module.gradient_checkpointing = value - - class LlamaDifferentialModel(LlamaModel): """LlamaModel with differential attention.""" @@ -222,6 +206,37 @@ def __init__(self, config): super().__init__(config) self.model = LlamaDifferentialModel(config) + # pylint: disable=protected-access + @classmethod + def _autoset_attn_implementation( + cls, config, **kwargs + ): # pylint: disable=unused-argument + config._attn_implementation_autoset = True + attn_implementation = getattr(config, "_attn_implementation", None) + + # Map standard types to differential types if mapping exists + if attn_implementation in config._attn_implementations: + config._attn_implementation = config._attn_implementations[ + attn_implementation + ] + return config + + # If no mapping, validate it's a valid differential type + valid_impls = [ + None, + "differential_eager", + "differential_sdpa", + "differential_flash_attention_2", + ] + if attn_implementation not 
in valid_impls: + message = ( + f"Specified `attn_implementation={attn_implementation}` is not supported. " + f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}" + ) + raise ValueError(message) + + return config + @classmethod def from_llama( cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None @@ -257,3 +272,11 @@ def register_diff_attn(): # Register models AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel) AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM) + + from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES + + LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention + LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention + LLAMA_ATTENTION_CLASSES[ + "differential_flash_attention_2" + ] = LlamaDifferentialFlashAttention2 diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py index e616a8ef1..e1ad31fdd 100644 --- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -130,6 +130,9 @@ def test_conversion_cli_repoduce_attentions( ) def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): output_dir = tmp_path / "converted" + + # Smallest model with an even number of attention heads + base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B" base_config["output_dir"] = str(output_dir) base_config[attention] = True
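
Note (illustrative, not part of the patch): the sketch below shows one way the registration added above could be exercised end to end once the plugin has called register_diff_attn(). The checkpoint directory and prompt are placeholders, and the remapping of "sdpa" to "differential_sdpa" assumes that pair is present in the config's _attn_implementations mapping, whose contents are outside this diff.

    # Usage sketch under the assumptions stated above; paths, prompt, and the exact
    # "sdpa" -> "differential_sdpa" mapping are illustrative, not guaranteed by this patch.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from axolotl.integrations.diff_transformer.modeling_diff_attn import register_diff_attn

    # Registers the differential config/model classes with the Auto* APIs and adds the
    # "differential_*" entries to LLAMA_ATTENTION_CLASSES, as register_diff_attn does above.
    register_diff_attn()

    model_dir = "path/to/converted-model"  # hypothetical output of convert_diff_transformer
    model = AutoModelForCausalLM.from_pretrained(model_dir, attn_implementation="sdpa")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)

    # _autoset_attn_implementation remaps the requested implementation to its differential
    # counterpart when the config's _attn_implementations mapping covers it.
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(output[0], skip_special_tokens=True))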