From fbba9bb7e8a08a16e6a99650b1f8ff239579427b Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Tue, 17 Dec 2024 18:44:47 +0000
Subject: [PATCH] differential flash attention 2; cleanup

---
 model-out/eval_summary.csv                    |   6 +
 outputs                                       |   1 +
 .../integrations/convert_diff_transformer.py  |  66 +++--
 src/axolotl/cli/utils.py                      |  12 +-
 src/axolotl/common/cli.py                     |   1 +
 .../__init__.py                               |   0
 .../convert.py                                |  46 ++--
 .../differential_attention.py}                | 236 ++++++++++++++----
 .../monkeypatch/attention/differential.py     |   7 +-
 9 files changed, 269 insertions(+), 106 deletions(-)
 create mode 100644 model-out/eval_summary.csv
 create mode 120000 outputs
 rename src/axolotl/integrations/{diff_transformer => differential_transformer}/__init__.py (100%)
 rename src/axolotl/integrations/{diff_transformer => differential_transformer}/convert.py (78%)
 rename src/axolotl/integrations/{diff_transformer/multihead_diffattn.py => differential_transformer/differential_attention.py} (64%)

diff --git a/model-out/eval_summary.csv b/model-out/eval_summary.csv
new file mode 100644
index 0000000000..ccbe73358c
--- /dev/null
+++ b/model-out/eval_summary.csv
@@ -0,0 +1,6 @@
+metric,training,validation
+loss,1.8773103952407837,1.915901780128479
+model_preparation_time,0.0051,0.0051
+runtime,89.7635,8.9565
+samples_per_second,20.053,22.33
+steps_per_second,20.053,22.33
diff --git a/outputs b/outputs
new file mode 120000
index 0000000000..be3c4a823f
--- /dev/null
+++ b/outputs
@@ -0,0 +1 @@
+/workspace/data/axolotl-artifacts
\ No newline at end of file
diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index 1cbf619c82..6eb00452b1 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -14,7 +14,9 @@
 
 from axolotl.cli import load_cfg, print_axolotl_text_art
 from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
+from axolotl.integrations.differential_transformer.convert import (
+    convert_to_diff_attention,
+)
 
 LOG = logging.getLogger("axolotl.cli.convert_attention")
 
@@ -74,7 +76,11 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         # Convert attention
         LOG.info("Converting to differential attention...")
         try:
-            model = convert_to_diff_attention(model, cli_args.zero_init)
+            model = convert_to_diff_attention(
+                model=model,
+                zero_init=cli_args.zero_init,
+                sublayer_norm=cli_args.sublayer_norm,
+            )
             model.to(cfg.device, dtype=cfg.torch_dtype)
         except Exception as exc:
             LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
@@ -130,43 +136,35 @@ def convert_diff_transformer(cfg, cli_args, config_path):
                 + Fore.RESET
             )
         else:
-            if cli_args.zero_init:
-                LOG.info(
-                    Fore.RED
-                    + "Generations do not match.\n"
-                    + "Original generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{orig_text}\n"
-                    + "*" * 50
-                    + "\n"
-                    + "Converted generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{conv_text}\n"
-                    + "*" * 50
-                    + "\n"
-                    + Fore.RESET
-                )
+            message = (
+                "Generations do not match.\n"
+                + "Original generation:\n"
+                + "*" * 50
+                + "\n"
+                + f"{orig_text}\n"
+                + "*" * 50
+                + "\n"
+                + "Converted generation:\n"
+                + "*" * 50
+                + "\n"
+                + f"{conv_text}\n"
+                + "*" * 50
+                + "\n"
+            )
+
+            if cli_args.zero_init and not cli_args.sublayer_norm:
+                LOG.info(Fore.RED + message + Fore.RESET)
             else:
                 LOG.info(
                     Fore.YELLOW
-                    + "Generations do not match.\n"
-                    + "Original generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{orig_text}\n"
-                    + "*" * 50
-                    + "\n"
-                    + "Converted generation:\n"
-                    + "*" * 50
-                    + "\n"
-                    + f"{conv_text}\n"
-                    + "*" * 50
-                    + "\n"
-                    + "However, this is expected since --zero-init was not passed."
+                    + message
+                    + "However, this is expected since --zero-init"
+                    + " and --no-sublayer-norm were not passed."
                     + Fore.RESET
                 )
+
+        return model
+
     except Exception as exc:
         LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
         raise
diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py
index f0e2573f72..a228ee92a3 100644
--- a/src/axolotl/cli/utils.py
+++ b/src/axolotl/cli/utils.py
@@ -22,7 +22,6 @@ def decorator(function):
         # Process dataclass fields in reverse order for correct option ordering
         for field in reversed(dataclasses.fields(config_class)):
             field_type = field.type
-
             if get_origin(field_type) is Union and type(None) in get_args(field_type):
                 field_type = next(
                     t for t in get_args(field_type) if not isinstance(t, NoneType)
                 )
@@ -44,6 +43,7 @@ def decorator(function):
                     default=field.default,
                     help=field.metadata.get("description"),
                 )(function)
+
         return function
 
     return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
     def decorator(function):
         # Process model fields in reverse order for correct option ordering
         for name, field in reversed(config_class.model_fields.items()):
-            if field.annotation == bool:
+            field_type = field.annotation
+            if get_origin(field_type) is Union and type(None) in get_args(field_type):
+                field_type = next(
+                    t for t in get_args(field_type) if not isinstance(t, NoneType)
+                )
+
+            # NOTE: defaults are handled by the pydantic model config classes.
+            if field_type == bool:
                 field_name = name.replace("_", "-")
                 option_name = f"--{field_name}/--no-{field_name}"
                 function = click.option(
@@ -66,6 +73,7 @@ def decorator(function):
                 function = click.option(
                     option_name, default=None, help=field.description
                 )(function)
+
         return function
 
     return decorator
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index 2b25b7f395..9c921e5640 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -62,6 +62,7 @@ class ConvertDiffTransformerCliArgs:
 
     debug: bool = field(default=False)
     zero_init: bool = field(default=False)
+    sublayer_norm: bool = field(default=True)
 
 
 def load_model_and_tokenizer(
diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/differential_transformer/__init__.py
similarity index 100%
rename from src/axolotl/integrations/diff_transformer/__init__.py
rename to src/axolotl/integrations/differential_transformer/__init__.py
diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/differential_transformer/convert.py
similarity index 78%
rename from src/axolotl/integrations/diff_transformer/convert.py
rename to src/axolotl/integrations/differential_transformer/convert.py
index bd688fadbe..5620ad1995 100644
--- a/src/axolotl/integrations/diff_transformer/convert.py
+++ b/src/axolotl/integrations/differential_transformer/convert.py
@@ -5,22 +5,35 @@
 import torch
 from torch import nn
 from transformers import PreTrainedModel
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
-from transformers.models.mistral.modeling_mistral import MistralAttention
-from transformers.models.mixtral.modeling_mixtral import MixtralAttention
+from transformers.models.llama.modeling_llama import (
+    LlamaAttention,
+    LlamaFlashAttention2,
+    LlamaSdpaAttention,
+)
 
-from .multihead_diffattn import (
+from .differential_attention import (
     LlamaDifferentialAttention,
+    LlamaDifferentialFlashAttention2,
     LlamaDifferentialSdpaAttention,
 )
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+ATTENTION_MAPPING = {
+    LlamaAttention: LlamaDifferentialAttention,
+    LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
+    LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
+}
+
 
 def copy_attention_weights(
-    old_attn: Union[LlamaAttention, LlamaSdpaAttention],
-    new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
+    old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
+    new_attn: Union[
+        LlamaDifferentialAttention,
+        LlamaDifferentialSdpaAttention,
+        LlamaDifferentialFlashAttention2,
+    ],
     zero_init: bool = False,
 ) -> None:
     """
@@ -69,31 +82,24 @@ def copy_attention_weights(
 
 
 def convert_to_diff_attention(
-    model: PreTrainedModel, zero_init: bool
+    model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
 ) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
-    attention_patterns = (
-        LlamaAttention,
-        LlamaSdpaAttention,
-        MistralAttention,
-        MixtralAttention,
-    )
     layer_idx = 0
 
+    # Set sublayer norm as config on the model.
+    model.config.sublayer_norm = sublayer_norm
+
     def convert_module(module):
         nonlocal layer_idx
 
         # Iterate through module children, convert any attn layers to diff attn
         for name, child in module.named_children():
-            if isinstance(child, attention_patterns):
-                layer_type = type(child).__name__
-
+            if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
                 # Choose appropriate differential attention class
-                if isinstance(child, LlamaSdpaAttention):
-                    attention_class = LlamaDifferentialSdpaAttention
-                else:
-                    attention_class = LlamaDifferentialAttention
+                attention_class = ATTENTION_MAPPING[type(child)]
 
+                layer_type = type(child).__name__
                 logger.info(
                     f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
                 )
diff --git a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py b/src/axolotl/integrations/differential_transformer/differential_attention.py
similarity index 64%
rename from src/axolotl/integrations/diff_transformer/multihead_diffattn.py
rename to src/axolotl/integrations/differential_transformer/differential_attention.py
index 4735564452..2046f08bcf 100644
--- a/src/axolotl/integrations/diff_transformer/multihead_diffattn.py
+++ b/src/axolotl/integrations/differential_transformer/differential_attention.py
@@ -7,9 +7,11 @@
 import torch
 import torch.nn.functional as F
 import transformers
+from flash_attn.flash_attn_interface import flash_attn_func
 from torch import nn
 from transformers.cache_utils import Cache
 from transformers.models.llama.modeling_llama import (
+    LlamaRMSNorm,
     LlamaRotaryEmbedding,
     apply_rotary_pos_emb,
 )
@@ -75,14 +77,11 @@ def __init__(
         self.rope_theta = config.rope_theta
         self.is_causal = True
 
-        dtype = torch.float32
-
         # For Q1 and Q2
         self.q_proj = nn.Linear(
             self.hidden_size,
             self.hidden_size * 2,
             bias=False,
-            dtype=dtype,
         )
 
         # For K1 and K2
@@ -90,7 +89,6 @@ def __init__(
             self.hidden_size,
             self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
             bias=False,
-            dtype=dtype,
         )
 
         # Single V projection
@@ -98,7 +96,6 @@ def __init__(
             self.hidden_size,
             self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
             bias=False,
-            dtype=dtype,
         )
 
         # Output projection
@@ -106,28 +103,33 @@ def __init__(
             self.hidden_size,
             self.hidden_size,
             bias=False,
-            dtype=dtype,
         )
 
         # Initialize differential attention parameters
         self.lambda_init = nn.Parameter(
-            torch.full((), lambda_init_fn(self.layer_idx), dtype=dtype),
+            torch.full((), lambda_init_fn(self.layer_idx)),
             requires_grad=False,
         )
         self.lambda_q1 = nn.Parameter(
-            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
         )
         self.lambda_k1 = nn.Parameter(
-            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
         self.lambda_q2 = nn.Parameter(
-            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
        )
         self.lambda_k2 = nn.Parameter(
-            torch.zeros(self.head_dim, dtype=dtype).normal_(mean=0, std=0.1)
+            torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
         )
 
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
+        sublayer_norm = getattr(config, "sublayer_norm", True)
+        self.subln = (
+            LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
+            if sublayer_norm
+            else nn.Identity()
+        )
 
     def forward(
         self,
@@ -192,39 +194,21 @@ def forward(
         # Calculate attention scores for both parts
         # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
         # instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
-        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
-            self.head_dim
-        )
-        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
-            self.head_dim
-        )
-
-        # Add this debug step right after computing attention weights in the forward pass
-        attn_weights1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(
-            self.head_dim
-        )
-        attn_weights2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(
-            self.head_dim
-        )
+        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
+        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
 
         if attention_mask is not None:
             causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
-            attn_weights1 = attn_weights1 + causal_mask
-            attn_weights2 = attn_weights2 + causal_mask
+            attn1 = attn1 + causal_mask
+            attn2 = attn2 + causal_mask
 
-        # Apply softmax separately as per paper
-        attn_weights1 = F.softmax(attn_weights1, dim=-1, dtype=torch.float32).type_as(
-            attn_weights1
-        )
-        attn_weights2 = F.softmax(attn_weights2, dim=-1, dtype=torch.float32).type_as(
-            attn_weights2
-        )
-        attn_weights1 = F.dropout(
-            attn_weights1, p=self.attention_dropout, training=self.training
-        )
-        attn_weights2 = F.dropout(
-            attn_weights2, p=self.attention_dropout, training=self.training
-        )
+        # Apply softmax
+        attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
+        attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
+
+        # Apply dropout
+        attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training)
+        attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training)
 
         # Calculate lambda
         lambda_1 = torch.exp(
@@ -236,15 +220,13 @@ def forward(
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
         # Compute differential attention (following paper's formula)
-        attn_weights = attn_weights1 - lambda_full * attn_weights2
+        attn_weights = attn1 - lambda_full * attn2
 
         # Apply attention weights to values
         attn = torch.matmul(attn_weights, v)
 
         # Apply sublayer norm and scaling
-        # NOTE(Dan): The differential transformers paper applies sublayer normalization at this
-        # point, but this is typically done outside of the attention layer. It would look something
-        # like: `attn = self.subln(attn).type_as(attn)`, using `LlamaRMSNorm` or similar.
+        attn = self.subln(attn)
         attn = attn * (1 - self.lambda_init)
 
         # Reshape to output
@@ -368,20 +350,21 @@ def forward(
 
         # Calculate attention using SDPA
         is_causal = attention_mask is None and q_len > 1
-        attn_output1 = F.scaled_dot_product_attention(
+        dropout_p = self.attention_dropout if self.training else 0.0
+        attn1 = F.scaled_dot_product_attention(
             q1,
             k1,
             v,
             attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
+            dropout_p=dropout_p,
             is_causal=is_causal,
         )
-        attn_output2 = F.scaled_dot_product_attention(
+        attn2 = F.scaled_dot_product_attention(
             q2,
             k2,
             v,
             attn_mask=causal_mask,
-            dropout_p=self.attention_dropout if self.training else 0.0,
+            dropout_p=dropout_p,
             is_causal=is_causal,
         )
 
@@ -395,9 +378,10 @@ def forward(
         lambda_full = lambda_1 - lambda_2 + self.lambda_init
 
         # Combine the attention outputs
-        attn = attn_output1 - lambda_full * attn_output2
+        attn = attn1 - lambda_full * attn2
 
         # Apply sublayer norm and scaling
+        attn = self.subln(attn)
         attn = attn * (1 - self.lambda_init)
 
         # Reshape to output
@@ -411,3 +395,157 @@ def forward(
             past_key_value,
         )  # Note: can't return attn_weights with SDPA
         return attn, None, past_key_value
+
+
+class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
+    """Differential Attention implementation using Flash Attention 2.
+
+    This implements the same logic as `LlamaDifferentialAttention`, but uses
+    Flash Attention 2 for more efficient computation.
+
+    This implements a modified attention mechanism that computes the difference between
+    two attention patterns, scaled by learned lambda parameters. The mechanism helps
+    reduce noise in the attention weights for irrelevant / less relevant tokens.
+
+    Key components:
+    - Split head dimension for differential computation
+    - Learned lambda parameters that control attention scaling
+    - Sublayer normalization on the attention output
+    - Flash Attention 2 for efficient attention computation
+
+    See:
+    - https://arxiv.org/abs/2410.05258
+    - https://github.com/microsoft/unilm/tree/master/Diff-Transformer
+
+    Args:
+        config: Model configuration object containing hidden size, number of heads etc.
+        layer_idx: Index of this layer in the transformer stack
+        dtype: Data type for the layer parameters
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,  # [bsz, seq_len, hidden_size]
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> tuple[
+        torch.Tensor,
+        Optional[torch.Tensor],
+        Optional[tuple[torch.Tensor, torch.Tensor]],
+    ]:
+        if output_attentions:
+            transformers.logger.warning_once(
+                "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. "
+                "Falling back to the manual attention implementation."
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # Project to Q1,Q2 and K1,K2
+        qp = self.q_proj(hidden_states)
+        kp = self.k_proj(hidden_states)
+        v = self.v_proj(hidden_states)
+
+        # Split into Q1,Q2 and K1,K2
+        q1, q2 = qp.chunk(2, dim=-1)
+        k1, k2 = kp.chunk(2, dim=-1)
+
+        # Reshape Q1,Q2 for attention
+        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+
+        # Reshape K1,K2 for attention
+        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+
+        # Reshape V
+        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+
+        # Apply rotary embeddings
+        if position_embeddings is None:
+            if position_ids is None:
+                position_ids = torch.arange(q_len, device=q1.device)
+            cos, sin = self.rotary_emb(q1, position_ids)
+        else:
+            cos, sin = position_embeddings
+
+        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
+        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
+
+        if past_key_value is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            k = torch.stack([k1, k2], dim=1)
+            k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
+            k1, k2 = k.unbind(dim=1)
+
+        # Repeat KV heads to match Q heads
+        k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
+        k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
+        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
+
+        q1 = q1.transpose(1, 2)
+        q2 = q2.transpose(1, 2)
+        k1 = k1.transpose(1, 2)
+        k2 = k2.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Calculate attention using Flash Attention
+        dropout_p = self.attention_dropout if self.training else 0.0
+        attn1 = flash_attn_func(
+            q1,
+            k1,
+            v,
+            dropout_p=dropout_p,
+            causal=True,
+        )
+        attn2 = flash_attn_func(
+            q2,
+            k2,
+            v,
+            dropout_p=dropout_p,
+            causal=True,
+        )
+
+        attn1 = attn1.transpose(1, 2)
+        attn2 = attn2.transpose(1, 2)
+
+        # Calculate lambda
+        lambda_1 = torch.exp(
+            torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
+        ).type_as(q1)
+        lambda_2 = torch.exp(
+            torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
+        ).type_as(q1)
+        lambda_full = lambda_1 - lambda_2 + self.lambda_init
+
+        # Combine the attention outputs
+        attn = attn1 - lambda_full * attn2
+
+        # Apply sublayer norm and scaling
+        attn = self.subln(attn)
+        attn = attn * (1 - self.lambda_init)
+
+        # Reshape to output
+        attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
+        attn = self.o_proj(attn)
+
+        if output_attentions:
+            return (
+                attn,
+                None,
+                past_key_value,
+            )  # Note: can't return attn_weights with Flash Attention
+        return attn, None, past_key_value
diff --git a/src/axolotl/monkeypatch/attention/differential.py b/src/axolotl/monkeypatch/attention/differential.py
index 037a6f0bd2..36e3821af6 100644
--- a/src/axolotl/monkeypatch/attention/differential.py
+++ b/src/axolotl/monkeypatch/attention/differential.py
@@ -3,8 +3,9 @@
 from transformers import PreTrainedModel
 from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
 
-from axolotl.integrations.diff_transformer.multihead_diffattn import (
+from axolotl.integrations.differential_transformer.differential_attention import (
     LlamaDifferentialAttention,
+    LlamaDifferentialFlashAttention2,
     LlamaDifferentialSdpaAttention,
 )
 
@@ -15,6 +16,9 @@ def patch_llama_attention_classes():
     # Add our attention class to the registry
     LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
     LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
+    LLAMA_ATTENTION_CLASSES[
+        "differential_flash_attention_2"
+    ] = LlamaDifferentialFlashAttention2
 
     @classmethod
     def new_autoset(_, config, **kwargs):  # pylint: disable=unused-argument
@@ -28,6 +32,7 @@ def new_autoset(_, config, **kwargs):  # pylint: disable=unused-argument
         "flash_attention_2",
         "differential_eager",
         "differential_sdpa",
+        "differential_flash_attention_2",
     ]
     if attn_implementation not in valid_impls:
         message = (
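Illustrative usage sketch (not part of the patch): the snippet below shows how the converted entry point and the new keyword arguments introduced above might be exercised programmatically. The checkpoint name and loading details are assumptions; only convert_to_diff_attention() and its zero_init / sublayer_norm keywords come from this change.

    import torch
    from transformers import AutoModelForCausalLM

    from axolotl.integrations.differential_transformer.convert import (
        convert_to_diff_attention,
    )

    # Hypothetical base checkpoint; loading with flash_attention_2 should mean the
    # model's LlamaFlashAttention2 modules map to LlamaDifferentialFlashAttention2
    # via ATTENTION_MAPPING during conversion.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-3.2-1B",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    # Per the CLI logic above, zero_init=True combined with sublayer_norm=False is the
    # configuration in which the converted model's generations are expected to match
    # the original model's output.
    model = convert_to_diff_attention(model=model, zero_init=True, sublayer_norm=False)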