diff --git a/.gitignore b/.gitignore index 7b604d88c..4d7ba15a1 100644 --- a/.gitignore +++ b/.gitignore @@ -186,3 +186,6 @@ out/ # vim *.swp + +# symlinked to axolotl-artifacts in docker containers +outputs diff --git a/cicd/cicd.sh b/cicd/cicd.sh index 91926127f..b01846e6e 100755 --- a/cicd/cicd.sh +++ b/cicd/cicd.sh @@ -4,7 +4,6 @@ set -e python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ -# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ diff --git a/cicd/multigpu.py b/cicd/multigpu.py index f9bad386a..f92464630 100644 --- a/cicd/multigpu.py +++ b/cicd/multigpu.py @@ -1,6 +1,6 @@ """ - modal application to run axolotl gpu tests in Modal - """ +modal application to run axolotl gpu tests in Modal +""" # pylint: disable=duplicate-code import os diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index d07b10ce3..278a67474 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -202,7 +202,7 @@ def do_inference( ) elif cfg.chat_template: chat_template_str = get_chat_template(cfg.chat_template) - elif cfg.datasets[0].type == "chat_template": + elif cfg.datasets and cfg.datasets[0].type == "chat_template": chat_template_str = get_chat_template_from_config( cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer ) diff --git a/src/axolotl/cli/evaluate.py b/src/axolotl/cli/evaluate.py index 8e99d6f4b..655f3782f 100644 --- a/src/axolotl/cli/evaluate.py +++ b/src/axolotl/cli/evaluate.py @@ -3,7 +3,7 @@ """ import logging from pathlib import Path -from typing import Union +from typing import Dict, Union import fire from dotenv import load_dotenv @@ -23,7 +23,7 @@ LOG = logging.getLogger("axolotl.cli.evaluate") -def do_evaluate(cfg, cli_args) -> None: +def do_evaluate(cfg, cli_args) -> Dict[str, float]: # pylint: disable=duplicate-code print_axolotl_text_art() check_accelerate_default_config() @@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: diff --git a/src/axolotl/cli/integrations/__init__.py b/src/axolotl/cli/integrations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py new file mode 100644 index 000000000..3b0f16ca9 --- /dev/null +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -0,0 +1,208 @@ +"""CLI to convert a transformers model's attention layers to differential attention layers.""" + +import logging +import warnings +from pathlib import Path +from time import time +from typing import Union + +import fire +import torch +import yaml +from colorama import Fore +from dotenv import load_dotenv +from transformers import HfArgumentParser + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer +from 
axolotl.integrations.diff_transformer.modeling_diff_attn import ( + LlamaDifferentialConfig, + LlamaDifferentialForCausalLM, +) +from axolotl.utils.yaml import dump_yaml_preserved_order + +LOG = logging.getLogger(__name__) + + +def test_inference(model, tokenizer, prompt="The quick brown fox"): + """Run test inference and return generation time""" + inputs = tokenizer(prompt, return_tensors="pt") + inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()} + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) + elapsed = time() - start + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + LOG.info("Prompt: %s", prompt) + LOG.info("Generated: %s", generated_text) + LOG.info("Generation time: %.2fs", elapsed) + + return elapsed, generated_text + + +def convert_diff_transformer(cfg, cli_args, config_path): + assert not ( + cli_args.split_heads and cli_args.zero_init + ), "Both `split_heads` and `zero_init` cannot be `True`" + assert not ( + cli_args.zero_init and cli_args.mirror_weights + ), "Both `zero_init` and `mirror_weights` cannot be `True`" + + debug_info = {} + + # Load model and tokenizer + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model.to(cfg.device, dtype=cfg.torch_dtype) + + # Log original model info + LOG.info( + "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d", + model.config.hidden_size, + model.config.num_attention_heads, + ) + + # Test original model + if cli_args.debug: + LOG.info("Testing original model...") + debug_info["orig_time"], debug_info["orig_text"] = test_inference( + model, tokenizer + ) + + try: + # Convert attention + LOG.info("Converting to differential attention...") + + config = LlamaDifferentialConfig( + **model.config.__dict__, + zero_init=cli_args.zero_init, + sublayer_norm=cli_args.sublayer_norm, + split_heads=cli_args.split_heads, + mirror_weights=cli_args.mirror_weights, + ) + model = LlamaDifferentialForCausalLM.from_llama(model, config) + model.to(cfg.device, dtype=cfg.torch_dtype) + except Exception as exc: + LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) + raise + + # Test converted model + if cli_args.debug: + LOG.info("Testing converted model...") + debug_info["conv_time"], debug_info["conv_text"] = test_inference( + model, tokenizer + ) + + # Save if requested + if cfg.output_dir: + # Save model and tokenizer + LOG.info("Saving converted model to %s", cfg.output_dir) + model.save_pretrained(cfg.output_dir) + tokenizer.save_pretrained(cfg.output_dir) + + # Modify config to reflect new path / differential attention + output_config_path = Path(cfg.output_dir) / "axolotl_config.yml" + LOG.info("Saving updated config to %s", output_config_path) + + with open(config_path, "r", encoding="utf-8") as file: + modified_cfg = yaml.safe_load(file) or {} + + modified_cfg["base_model"] = cfg.output_dir + modified_cfg["diff_attention"] = True + plugin_class = ( + "axolotl.integrations.diff_transformer.DifferentialTransformerPlugin" + ) + if "plugins" in modified_cfg: + modified_cfg["plugins"].append(plugin_class) + else: + modified_cfg["plugins"] = [plugin_class] + + # Write out the updated axolotl config while preserving original ordering / formatting + dump_yaml_preserved_order( + data=modified_cfg, + 
reference_yaml_path=config_path, + output_path=output_config_path, + ) + else: + LOG.info("Not saving converted model to disk") + LOG.info("Pass --output-dir path/to/save to save model") + + if cli_args.debug: + LOG.info( + Fore.GREEN + + "Conversion successful!\n" + + f"Original generation time: {debug_info['orig_time']:.2f}s\n" + + f"Converted generation time: {debug_info['conv_time']:.2f}s" + + Fore.RESET + ) + + if debug_info["orig_text"] == debug_info["conv_text"]: + LOG.info( + Fore.GREEN + + "Generations match!\n" + + "Model generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + Fore.RESET + ) + debug_info["generations_match"] = True + else: + message = ( + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['conv_text']}\n" + + "*" * 50 + + "\n" + ) + debug_info["generations_match"] = False + + if cli_args.zero_init and not cli_args.sublayer_norm: + LOG.info(Fore.RED + message + Fore.RESET) + debug_info["match_expected"] = True + else: + LOG.info( + Fore.YELLOW + + message + + "However, this is expected since --zero-init" + + " and --no-sublayer-norm were not passed." + + Fore.RESET + ) + debug_info["match_expected"] = False + + return model, debug_info + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + print_axolotl_text_art() + + cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(ConvertDiffTransformerCliArgs) + cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + convert_diff_transformer(cfg, cli_args, config) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 14803e43b..d9d3a2135 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -12,7 +12,12 @@ build_command, fetch_from_github, ) -from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs +from axolotl.common.cli import ( + ConvertDiffTransformerCliArgs, + EvaluateCliArgs, + PreprocessCliArgs, + TrainerCliArgs, +) from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -77,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs): """Evaluate a model.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} + # Enable expandable segments for cuda allocation to improve VRAM usage + set_pytorch_cuda_alloc_conf() + if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] if config: @@ -240,6 +248,19 @@ def merge_lora( do_cli(config=config, **kwargs) +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@add_options_from_dataclass(ConvertDiffTransformerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def convert_diff_transformer(config: str, **kwargs): + """Convert model attention layers to differential attention layers.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + from axolotl.cli.integrations.convert_diff_transformer import do_cli + + do_cli(config=config, **kwargs) + + @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index f0e2573f7..c9b609049 100644 --- a/src/axolotl/cli/utils.py +++ 
b/src/axolotl/cli/utils.py @@ -22,7 +22,6 @@ def decorator(function): # Process dataclass fields in reverse order for correct option ordering for field in reversed(dataclasses.fields(config_class)): field_type = field.type - if get_origin(field_type) is Union and type(None) in get_args(field_type): field_type = next( t for t in get_args(field_type) if not isinstance(t, NoneType) @@ -44,6 +43,7 @@ def decorator(function): default=field.default, help=field.metadata.get("description"), )(function) + return function return decorator @@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]): def decorator(function): # Process model fields in reverse order for correct option ordering for name, field in reversed(config_class.model_fields.items()): - if field.annotation == bool: + field_type = field.annotation + if get_origin(field_type) is Union and type(None) in get_args(field_type): + field_type = next( + t for t in get_args(field_type) if not isinstance(t, NoneType) + ) + + # NOTE: defaults are handled by the pydantic model config classes. + if field_type == bool: field_name = name.replace("_", "-") option_name = f"--{field_name}/--no-{field_name}" function = click.option( @@ -66,6 +73,7 @@ def decorator(function): function = click.option( option_name, default=None, help=field.description )(function) + return function return decorator @@ -84,6 +92,8 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]: if isinstance(value, bool): if value: cmd.append(f"--{key}") + else: + cmd.append(f"--no{key}") else: cmd.extend([f"--{key}", str(value)]) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index 02ad9201b..8b31b52b5 100644 --- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -4,7 +4,7 @@ import logging from dataclasses import dataclass, field -from typing import Optional +from typing import Optional, Union import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 from axolotl.logging_config import configure_logging @@ -12,14 +12,12 @@ from axolotl.utils.models import load_model, load_tokenizer configure_logging() -LOG = logging.getLogger("axolotl.common.cli") +LOG = logging.getLogger(__name__) @dataclass class PreprocessCliArgs: - """ - dataclass representing arguments for preprocessing only - """ + """dataclass with arguments for preprocessing only""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -30,9 +28,7 @@ class PreprocessCliArgs: @dataclass class TrainerCliArgs: - """ - dataclass representing the various non-training arguments - """ + """dataclass with various non-training arguments""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) @@ -45,19 +41,28 @@ class TrainerCliArgs: @dataclass class EvaluateCliArgs: - """ - dataclass representing the various evaluation arguments - """ + """dataclass with various evaluation arguments""" debug: bool = field(default=False) debug_text_only: bool = field(default=False) debug_num_examples: int = field(default=0) +@dataclass +class ConvertDiffTransformerCliArgs: + """dataclass with arguments for convert-diff-transformer CLI""" + + debug: bool = field(default=False) + zero_init: bool = field(default=False) + sublayer_norm: bool = field(default=True) + split_heads: bool = field(default=False) + mirror_weights: bool = field(default=False) + + def load_model_and_tokenizer( *, cfg: DictDefault, - cli_args: TrainerCliArgs, + cli_args: Union[TrainerCliArgs, EvaluateCliArgs, 
ConvertDiffTransformerCliArgs], ): LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") tokenizer = load_tokenizer(cfg) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 176ce4174..3c6a7026a 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): """ Training arguments for Causal trainer - This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value + This code is duplicated due to HF TrainingArguments not setting output_dir with a default value so it can't be used as a mixin. """ diff --git a/src/axolotl/evaluate.py b/src/axolotl/evaluate.py index acf15e3fc..1c62fc6ab 100644 --- a/src/axolotl/evaluate.py +++ b/src/axolotl/evaluate.py @@ -9,12 +9,11 @@ import torch from accelerate.logging import get_logger -from axolotl.common.cli import TrainerCliArgs +from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer from axolotl.logging_config import configure_logging from axolotl.train import TrainDatasetMeta -from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_processor, load_tokenizer +from axolotl.utils.models import load_processor from axolotl.utils.trainer import setup_trainer project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -62,8 +61,9 @@ def evaluate_dataset( return metrics +# pylint: disable=duplicate-code def evaluate( - *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta + *, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta ) -> Dict[str, float]: """ Evaluate a model on training and validation datasets @@ -79,16 +79,11 @@ def evaluate( - The tokenizer - Dictionary of evaluation metrics """ - # pylint: disable=duplicate-code - # Enable expandable segments for cuda allocation to improve VRAM usage - set_pytorch_cuda_alloc_conf() - - # Load tokenizer - LOG.debug( - f"loading tokenizer... 
{cfg.tokenizer_config or cfg.base_model_config}", - main_process_only=True, - ) - tokenizer = load_tokenizer(cfg) + # Load model + LOG.debug("loading model for evaluation...") + + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model = model.to(cfg.device, dtype=cfg.torch_dtype) # Load processor for multimodal models if needed processor = None @@ -100,12 +95,6 @@ def evaluate( eval_dataset = dataset_meta.eval_dataset total_num_steps = dataset_meta.total_num_steps - # Load model - LOG.debug("loading model for evaluation...") - model, _ = load_model( - cfg, tokenizer, processor=processor, inference=cli_args.inference - ) - # Set up trainer trainer = setup_trainer( cfg, diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b4ffd6758..f7d35fcf8 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -43,10 +43,12 @@ def merge_input_args(): input_args: List[str] = plugin_manager.get_input_args() plugin_classes = [] dynamic_input = "" + for plugin_args in input_args: plugin_module, plugin_cls = plugin_args.rsplit(".", 1) dynamic_input += f"from {plugin_module} import {plugin_cls}\n" plugin_classes.append(plugin_cls) + if dynamic_input: dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" @@ -62,4 +64,5 @@ def merge_input_args(): "AxolotlConfigWCapabilities" ] return AxolotlConfigWCapabilities, AxolotlInputConfig + return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md new file mode 100644 index 000000000..efba1fc39 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/README.md @@ -0,0 +1,12 @@ +# Differential Transformer + +### Usage + +**Note:** The following with be set in the model config output by the `axolotl convert-diff-transformer` command. + +```yaml +plugins: + - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin + +diff_attention: true +``` diff --git a/src/axolotl/integrations/diff_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py new file mode 100644 index 000000000..3b98ae246 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -0,0 +1,67 @@ +"""Definition of differential transformer plugin.""" + +import logging +from typing import List + +from transformers import PreTrainedModel, TrainerCallback + +from axolotl.integrations.base import BasePlugin +from axolotl.utils.callbacks.diff_attn import ( + DifferentialAttentionMixingCallback, + DifferentialAttentionMonitorCallback, +) +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger(__name__) + + +class DifferentialTransformerPlugin(BasePlugin): + """Plugin for differential transformer integration with Axolotl.""" + + def __init__(self) -> None: + """ + Constructor for differential transformers plugin. Calls `register_diff_attn` + to register differential attention custom modeling implementation to `AutoConfig` + and `AutoModel`. 
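As a rough sketch of what this registration enables (assuming a checkpoint already produced by `axolotl convert-diff-transformer` and saved to a placeholder `./converted-model` directory), the converted model can be loaded back through the standard `transformers` auto classes once `register_diff_attn()` has run:

```python
# Sketch only: "./converted-model" is a placeholder output directory from
# `axolotl convert-diff-transformer`.
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from axolotl.integrations.diff_transformer.modeling_diff_attn import register_diff_attn

# Registers LlamaDifferentialConfig / LlamaDifferentialForCausalLM with the
# Auto* classes (the plugin constructor does this automatically).
register_diff_attn()

config = AutoConfig.from_pretrained("./converted-model")
assert config.model_type == "llama-differential"

model = AutoModelForCausalLM.from_pretrained("./converted-model")
tokenizer = AutoTokenizer.from_pretrained("./converted-model")
```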
+ """ + from .modeling_diff_attn import register_diff_attn + + register_diff_attn() + + def get_input_args(self) -> str: + """Returns module path to diff transformer plugin args for `axolotl` config.""" + return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" + + # pylint: disable=unused-argument + def add_callbacks_pre_trainer( + self, cfg: DictDefault, model: PreTrainedModel + ) -> List[TrainerCallback]: + """ + Returns `DifferentialAttentionMonitorCallback` to be added to the list of + callbacks for the `axolotl` trainer if wandb usage is enabled. + + Parameters: + cfg: Dictionary mapping `axolotl` config keys to values. + model: The loaded mfodel. + + Returns: + A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`. + """ + callbacks = [] + if cfg.use_wandb: + callbacks.append( + DifferentialAttentionMonitorCallback( + log_every=cfg.diff_attn_log_every, + num_monitor_layers=cfg.diff_attn_num_monitor_layers, + warmup_steps=cfg.diff_attn_warmup_steps, + ) + ) + + if cfg.diff_attn_warmup_steps: + callbacks.append( + DifferentialAttentionMixingCallback( + warmup_steps=cfg.diff_attn_warmup_steps + ) + ) + + return callbacks diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py new file mode 100644 index 000000000..ebd4d03a1 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/args.py @@ -0,0 +1,27 @@ +"""Module for handling differential transfomer input arguments.""" + +import logging +from typing import Optional + +from pydantic import BaseModel + +LOG = logging.getLogger(__name__) + + +class DifferentialTransformerArgs(BaseModel): + """ + Input args for differential transformer. + + Attributes: + diff_attention: Whether to use differential attention layers. + diff_attn_log_every: How often to log differential attention statistics. + diff_attn_num_monitor_layers: Number of layers to monitor for attention stats. + diff_attn_warmup_steps: Number of steps to linearly increase negative attention + mixing weight from 0 to 1. If specified, will reach full mixing at this + step. If `None`, negative attention has full weight from the start. + """ + + diff_attention: Optional[bool] = None + diff_attn_log_every: Optional[int] = 100 + diff_attn_num_monitor_layers: Optional[int] = 3 + diff_attn_warmup_steps: Optional[int] = None diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py new file mode 100644 index 000000000..6ee043d8c --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -0,0 +1,694 @@ +"""Re-implemention of differential attention from the Differential Transformer paper +(https://arxiv.org/abs/2410.05258).""" +# pylint: disable=invalid-name + +import logging +import math +from typing import Any + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) + +logging.basicConfig(level=logging.INFO) +LOG = logging.getLogger(__name__) + +try: + from flash_attn.flash_attn_interface import flash_attn_func + + FLASH_ATTENTION_AVAILABLE = True +except ImportError: + FLASH_ATTENTION_AVAILABLE = False + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + Repeats key/value heads to match the number of query heads in multi-head attention. 
+ + Args: + x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`. + n_rep: Number of times to repeat each head. + + Returns: + Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep, + seq_len, head_dim)`. + If `n_rep` is 1, returns the input tensor unchanged. + """ + batch_size, n_kv_heads, slen, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, None, :, :] + .expand(batch_size, n_kv_heads, n_rep, slen, head_dim) + .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim) + ) + + +def lambda_init_fn(depth: int) -> float: + """ + Lambda mixing parameter init function from the "Differential Transformer" paper. + + Args: + depth: Index of layer to init lambda parameter. + + Returns: + Lambda initialization value (decreasing with `depth`). + """ + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + +class LlamaDifferentialAttentionBase(nn.Module): + """ + Base class for differential attention implementations. + + This class implements the core differential attention mechanism used in Llama models. + It supports both split heads and double projection modes for attention computation. + """ + + def __init__(self, config: Any, layer_idx: int): + """ + Initializes the differential attention module. + + Args: + config: Model configuration object containing hyperparameters, including: + - hidden_size: The size of hidden states. + - num_attention_heads: Number of attention heads. + - num_key_value_heads: Number of key/value heads. + - attention_bias: Whether to use bias in attention projections. + - split_heads: Whether to use split heads mode. + - rms_norm_eps: Epsilon for RMS normalization. + layer_idx: The index of this layer in the model. + + Note: + The initialization process consists of four steps: + 1. Configuration initialization (`_init_config`) + 2. Projection layers initialization (`_init_projections`) + 3. Differential parameters initialization (`_init_differential_params`) + 4. Normalization layers initialization (`_init_normalization`) + """ + super().__init__() + + self.config = config + self._init_config(layer_idx) + self._init_projections() + self._init_differential_params() + self._init_normalization() + + # For logging + self.attn1 = None + self.attn2 = None + self.lambda_full = None + + def _init_config(self, layer_idx: int) -> None: + """ + Initializes configuration parameters for the attention layer. Sets up various + dimension sizes and head counts based on the provided config. Handles both + split heads and double projection modes. + + In split heads mode, the number of heads is divided by 2 (rounding down), which + differs from the original implementation that required an even number. + + Args: + layer_idx: Index of the current layer. + """ + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + self.base_num_heads = self.config.num_attention_heads + self.base_num_kv_heads = self.config.num_key_value_heads + self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads + self.layer_idx = layer_idx + + if self.config.split_heads: + self.heads_per_component = self.base_num_heads // 2 + self.kv_heads_per_component = self.base_num_kv_heads // 2 + self.value_head_dim = 2 * self.head_dim + else: + self.heads_per_component = self.base_num_heads + self.kv_heads_per_component = self.base_num_kv_heads + self.value_head_dim = self.head_dim + + def _init_projections(self) -> None: + """ + Initializes the query, key, value, and output projection layers. 
+ + Creates linear transformations for Q, K, V projections with dimensions + depending on whether split heads or double projection mode is used. + The output projection combines the attention heads back to model dimension. + """ + if self.config.split_heads: + q_out_dim = self.config.hidden_size + k_out_dim = self.head_dim * self.base_num_kv_heads + else: + q_out_dim = self.config.hidden_size * 2 + k_out_dim = self.head_dim * self.base_num_kv_heads * 2 + + self.q_proj = nn.Linear( + self.config.hidden_size, q_out_dim, bias=self.config.attention_bias + ) + self.k_proj = nn.Linear( + self.config.hidden_size, k_out_dim, bias=self.config.attention_bias + ) + self.v_proj = nn.Linear( + self.config.hidden_size, + self.head_dim * self.base_num_kv_heads, + bias=self.config.attention_bias, + ) + self.o_proj = nn.Linear( + self.base_num_heads * self.head_dim, + self.config.hidden_size, + bias=self.config.attention_bias, + ) + + def _init_differential_params(self) -> None: + """ + Initializes parameters specific to differential attention. + + Creates learnable parameters for the differential attention mechanism: + - Mixing parameter for negative attention component warmup phase. + - Lambda parameters for queries and keys. + - Initial lambda value based on layer index. + - Rotary position embedding layer. + """ + self.diff_attn_mix = 1.0 # Default to full mixing + + self.lambda_init = nn.Parameter( + torch.full((), lambda_init_fn(self.layer_idx)), + requires_grad=False, + ) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + + self.rotary_emb = LlamaRotaryEmbedding(config=self.config) + + def _init_normalization(self) -> None: + """ + Initializes normalization layers for the attention mechanism. + + Sets up either RMS normalization or identity transformation based on config. + The normalization is applied to the sublayer output if enabled. + """ + sublayer_norm = getattr(self.config, "sublayer_norm", True) + if sublayer_norm: + self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps) + else: + self.subln = nn.Identity() + + def _prepare_attention_inputs( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Prepares input tensors for attention computation. + + Projects input hidden states to query, key, and value spaces, then reshapes + them for multi-head attention processing. + + Args: + hidden_states: Input tensor of shape `(batch_size, seq_len, + hidden_size)`. 
+ + Returns: + tuple: Tuple containing: + - q1: Positive attention query component + - q2: Negative attention query component + - k1: Positive attention key component + - k2: Negative attention key component + - v: Value tensor + """ + bsz, q_len, _ = hidden_states.size() + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q1, q2 = q.chunk(2, dim=-1) + k1, k2 = k.chunk(2, dim=-1) + + q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose( + 1, 2 + ) + k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose( + 1, 2 + ) + k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose( + 1, 2 + ) + v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) + + return q1, q2, k1, k2, v + + def _apply_rotary_embeddings( + self, + q1: torch.Tensor, + q2: torch.Tensor, + k1: torch.Tensor, + k2: torch.Tensor, + position_ids: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + """ + Applies rotary positional embeddings to queries and keys. + + Args: + q1: Positive attention query component. + q2: Negative attention query component. + k1: Positive attention key component. + k2: Negative attention key component. + position_ids: Token position indices. + position_embeddings: Pre-computed rotary embeddings (cos, sin). + + Returns: + tuple: Tuple containing: + - q1: Positive attention query with positional encoding. + - q2: Negative attention query with positional encoding. + - k1: Positive attention key with positional encoding. + - k2: Negative attention key with positional encoding. + - cos: Cosine part of rotary embeddings. + - sin: Sine part of rotary embeddings. + """ + if position_embeddings is None: + LOG.warning( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(q1, position_ids) + else: + cos, sin = position_embeddings + + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) + q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) + + return q1, q2, k1, k2, cos, sin + + def _handle_cache( + self, + k1: torch.Tensor, + k2: torch.Tensor, + v: torch.Tensor, + past_key_value: Cache | None, + cache_kwargs: dict, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Handles key-value caching for autoregressive generation and the repetition of + key-value heads to match the number of query heads. + + Args: + k1: Positive attention key component. + k2: Negative attention key component. + v: Value tensor. + past_key_value: Cache object for storing previous key-value pairs. + cache_kwargs: Additional arguments for cache handling. + + Returns: + tuple: Tuple containing: + - k1: Processed positive attention key component. + - k2: Processed negative attention key component. + - v: Processed value tensor. 
+ """ + if past_key_value is not None: + k = torch.stack([k1, k2], dim=1) + k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + k1, k2 = k.unbind(dim=1) + + k1 = repeat_kv(k1, self.num_key_value_groups) + k2 = repeat_kv(k2, self.num_key_value_groups) + v = repeat_kv(v, self.num_key_value_groups) + if self.config.split_heads: + v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1) + + return k1, k2, v + + def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor: + """ + Computes lambda values for differential attention. + + The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed + from the learned parameters. `diff_attn_mix` is multiplied through the result + for negative attention component warmup phase (if applicable). + + Args: + q1: Positive attention query component, used for type casting. + + Returns: + Computed lambda value for differential attention. + """ + lambda_1 = torch.exp( + torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() + ).type_as(q1) + lambda_2 = torch.exp( + torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() + ).type_as(q1) + lambda_full = lambda_1 - lambda_2 + self.lambda_init + + return self.diff_attn_mix * lambda_full + + def _process_attention_output( + self, attn: torch.Tensor, bsz: int, q_len: int + ) -> torch.Tensor: + """ + Processes and projects the attention output. Applies sublayer normalization, + scales by (1 - λ_init), and projects back to model dimension. + + Args: + attn: Raw attention output. + bsz: Batch size. + q_len: Query sequence length. + + Returns: + Processed attention output of shape (batch_size, seq_len, hidden_size) + """ + attn = self.subln(attn) + # NOTE: this may need to be added back in, but doesn't interact well with + # `diff_attn_mix`, and doesn't allow us to preserve the original model output. + # attn = attn * self.diff_attn_mix * (1 - self.lambda_init) + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size) + + return self.o_proj(attn) + + +class LlamaDifferentialAttention(LlamaDifferentialAttentionBase): + """ + Standard implementation of differential attention. + + This class implements the standard differential attention mechanism using + explicit matrix multiplications for the attention computation. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + Computes differential attention using standard matrix multiplication operations. + + Args: + hidden_states: Input tensor containing sequence to attend to. + attention_mask: Mask to avoid attention on padding tokens. + position_ids: Indices of positions for positional embeddings. + past_key_value: Cached key and value tensors for autoregressive decoding. + output_attentions: Whether to return attention weights. + use_cache: Whether to use cached key/value states. + cache_position: Position indices for cached states. + position_embeddings: Pre-computed positional embeddings. + **kwargs: Additional arguments passed to the forward call. + + Returns: + tuple containing: + - Output tensor after attention computation. + - Attention weights if output_attentions is True, else None. 
+ - Updated key-value cache if use_cache is True, else None. + """ + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # Standard attention computation + attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) + attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : k1.shape[-2]] + attn1 = attn1 + causal_mask + attn2 = attn2 + causal_mask + + attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) + attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) + + dropout_p = self.config.attention_dropout if self.training else 0.0 + attn1 = F.dropout(attn1, p=dropout_p, training=self.training) + attn2 = F.dropout(attn2, p=dropout_p, training=self.training) + + lambda_full = self._compute_lambda(q1) + attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v) + attn = self._process_attention_output(attn, bsz, q_len) + + # Save for logging + self.attn1 = attn1 + self.attn2 = attn2 + self.lambda_full = lambda_full + + if output_attentions: + attn_weights = attn1 - lambda_full * attn2 + attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1) + return attn, attn_weights, past_key_value + return attn, None, past_key_value + + +class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase): + """ + SDPA-based implementation of differential attention. + + This class implements differential attention using PyTorch's scaled_dot_product_attention + for improved performance on supported hardware. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + Computes differential attention using PyTorch's scaled dot product attention. + + Args: + hidden_states: Input tensor containing sequence to attend to. + attention_mask: Mask to avoid attention on padding tokens. + position_ids: Indices of positions for positional embeddings. + past_key_value: Cached key and value tensors for autoregressive decoding. + output_attentions: Whether to return attention weights. + use_cache: Whether to use cached key/value states. + cache_position: Position indices for cached states. + position_embeddings: Pre-computed positional embeddings. + **kwargs: Additional arguments passed to the forward call. + + Returns: + tuple containing: + - Output tensor after attention computation. + - None for attention weights (SDPA doesn't support output_attentions). + - Updated key-value cache if use_cache is True, else None. + """ + if output_attentions: + LOG.warning( + "LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but " + + "`torch.nn.functional.scaled_dot_product_attention` does not support " + + "`output_attentions=True`. Falling back to the eager attention implementation." 
+ ) + + # pylint: disable=duplicate-code + return LlamaDifferentialAttention.forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # SDPA-specific attention computation + causal_mask = ( + None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]] + ) + is_causal = attention_mask is None and q_len > 1 + dropout_p = self.config.attention_dropout if self.training else 0.0 + + if q1.device.type == "cuda" and causal_mask is not None: + q1, q2 = q1.contiguous(), q2.contiguous() + k1, k2 = k1.contiguous(), k2.contiguous() + v = v.contiguous() + + attn1 = F.scaled_dot_product_attention( + q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal + ) + attn2 = F.scaled_dot_product_attention( + q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal + ) + + lambda_full = self._compute_lambda(q1) + attn = attn1 - lambda_full * attn2 + attn = self._process_attention_output(attn, bsz, q_len) + + # Save for logging + self.attn1 = attn1 + self.attn2 = attn2 + self.lambda_full = lambda_full + + return attn, None, past_key_value + + +class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase): + """ + Flash Attention 2-based implementation of differential attention. + + This class implements differential attention using Flash Attention 2 for maximum + performance on supported hardware. + """ + + def __init__(self, *args, **kwargs): + """ + Initializes the Flash Attention 2 differential attention module. + + Args: + *args: Positional arguments passed to parent class. + **kwargs: Keyword arguments passed to parent class. + + Raises: + ImportError: If flash-attn library is not installed. + """ + if not FLASH_ATTENTION_AVAILABLE: + raise ImportError( + "LlamaDifferentialFlashAttention2 requires flash-attn library. " + "Please install with `pip install flash-attn --no-build-isolation`" + ) + + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_value: Cache | None = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: torch.LongTensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + Computes differential attention using Flash Attention 2. + + Args: + hidden_states: Input tensor containing sequence to attend to. + attention_mask: Mask to avoid attention on padding tokens. + position_ids: Indices of positions for positional embeddings. + past_key_value: Cached key and value tensors for autoregressive decoding. + output_attentions: Whether to return attention weights. + use_cache: Whether to use cached key/value states. + cache_position: Position indices for cached states. + position_embeddings: Pre-computed positional embeddings. + **kwargs: Additional arguments passed to the forward call. + + Returns: + tuple containing: + - Output tensor after attention computation. 
+ - None for attention weights (Flash Attention doesn't support output_attentions). + - Updated key-value cache if use_cache is True, else None. + """ + if output_attentions: + LOG.warning( + "LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but " + + "flash attenion does not support `output_attentions=True`. Falling back " + + "to the eager attention implementation." + ) + + # pylint: disable=duplicate-code + return LlamaDifferentialAttention.forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # Flash Attention specific processing + q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2) + k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2) + v = v.transpose(1, 2) + + dropout_p = self.config.attention_dropout if self.training else 0.0 + + if self.config.split_heads: + v1, v2 = v.chunk(2, dim=-1) + attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True) + attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True) + attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True) + attn2 = torch.cat([attn21, attn22], dim=-1) + else: + attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True) + attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True) + + attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2) + + lambda_full = self._compute_lambda(q1) + attn = attn1 - lambda_full * attn2 + attn = self._process_attention_output(attn, bsz, q_len) + + # Save for logging + self.attn1 = attn1 + self.attn2 = attn2 + self.lambda_full = lambda_full + + return attn, None, past_key_value diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py new file mode 100644 index 000000000..90e5d838b --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py @@ -0,0 +1,401 @@ +""" +Modeling for differential transformers. + +This module implements differential attention variants of the LLaMA model, +providing various attention implementations for improved performance. +""" + +import logging + +import torch +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel + +from .diff_attn import ( + LlamaDifferentialAttention, + LlamaDifferentialFlashAttention2, + LlamaDifferentialSdpaAttention, +) + +logger = logging.getLogger(__name__) + + +class LlamaDifferentialConfig(LlamaConfig): + """ + Configuration class for Differential LLaMA model. + + Extends the base LLaMA configuration with additional parameters for differential + attention mechanisms. 
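As a usage sketch, the differential config can be constructed directly with the options above (everything else falls back to the standard `LlamaConfig` defaults), and the standard `attn_implementation` names are remapped to their differential counterparts:

```python
# Sketch: constructing the config directly; only the four options shown here
# are specific to this subclass.
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
)

config = LlamaDifferentialConfig(
    split_heads=False,     # double-projection mode: Q/K projections are twice as wide
    sublayer_norm=True,    # RMSNorm on the attention sublayer output
    zero_init=False,       # leave the negative-attention weights at their default init
    mirror_weights=False,  # don't copy positive-path weights into the negative path
)

# Standard attention implementations map to the differential variants.
assert config._attn_implementations["sdpa"] == "differential_sdpa"
assert config._attn_implementations["flash_attention_2"] == "differential_flash_attention_2"
```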
+ """ + + model_type = "llama-differential" + + def __init__( + self, + split_heads: bool = False, + sublayer_norm: bool = True, + zero_init: bool = False, + mirror_weights: bool = False, + **kwargs, + ): + """ + Initialize differential LLaMA configuration. + + Args: + split_heads: Whether to use split heads mode for attention computation. + sublayer_norm: Whether to apply normalization to sublayers. + zero_init: Whether to initialize new weights to zero. + mirror_weights: Whether to copy the positive attention component weights to + the negative attention component. + **kwargs: Additional arguments passed to LlamaConfig. + """ + super().__init__(**kwargs) + self.split_heads = split_heads + self.sublayer_norm = sublayer_norm + self.zero_init = zero_init + self.mirror_weights = mirror_weights + self.architectures = ["LlamaDifferentialModel"] + self._attn_implementations = { + "eager": "differential_eager", + "sdpa": "differential_sdpa", + "flash_attention_2": "differential_flash_attention_2", + } + + +class LlamaDifferentialModel(LlamaModel): + """ + LlamaModel with differential attention. + + This class extends the base LLaMA model by replacing standard attention with + differential attention mechanisms. + """ + + config_class = LlamaDifferentialConfig + base_model_prefix = "llama_differential" + + def __init__(self, config: LlamaDifferentialConfig): + """ + Initialize a differential LLaMA model. + + Args: + config: Configuration object for the model. + + Raises: + ValueError: If specified attention implementation is not supported. + """ + super().__init__(config) + + # Handle attention implementation + attn_impl = config._attn_implementation or "eager" + if attn_impl in config._attn_implementations: + attn_impl = config._attn_implementations[attn_impl] + + # Validate attention implementation + valid_impls = [ + None, + "differential_eager", + "differential_sdpa", + "differential_flash_attention_2", + ] + if attn_impl not in valid_impls: + raise ValueError(f"Invalid attention implementation: {attn_impl}") + + # Replace standard attention with differential attention in each layer + attn_classes = { + "differential_eager": LlamaDifferentialAttention, + "differential_sdpa": LlamaDifferentialSdpaAttention, + "differential_flash_attention_2": LlamaDifferentialFlashAttention2, + } + attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention) + + for idx, layer in enumerate(self.layers): + layer.self_attn = attn_class(config, idx) + + @classmethod + # pylint: disable=protected-access + def _autoset_attn_implementation( + cls, + config: LlamaDifferentialConfig, + **kwargs, # pylint: disable=unused-argument + ) -> LlamaDifferentialConfig: + """ + Automatically set the attention implementation based on config. + + Args: + config: Model configuration object. + **kwargs: Additional arguments (unused). + + Returns: + Updated configuration object. + + Raises: + ValueError: If specified attention implementation is not supported. 
+ """ + config._attn_implementation_autoset = True + attn_implementation = getattr(config, "_attn_implementation", None) + + # Map standard types to differential types if mapping exists + if attn_implementation in config._attn_implementations: + config._attn_implementation = config._attn_implementations[ + attn_implementation + ] + return config + + # If no mapping, validate it's a valid differential type + valid_impls = [ + None, + "differential_eager", + "differential_sdpa", + "differential_flash_attention_2", + ] + if attn_implementation not in valid_impls: + message = ( + f"Specified `attn_implementation={attn_implementation}` is not supported. " + f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}" + ) + raise ValueError(message) + + return config + + @classmethod + def from_llama( + cls, + model: LlamaModel | LlamaForCausalLM, + config: LlamaDifferentialConfig | None = None, + ) -> "LlamaDifferentialModel": + """ + Convert a `LlamaModel` to use differential attention. + + Args: + model: Base LLaMA model to convert. + config: Configuration for differential attention. If `None`, created from + base model config. + + Returns: + Converted model with differential attention. + + Raises: + ValueError: If number of heads is not even when using `split_heads` mode. + """ + logger.info(f"Converting {type(model).__name__} to {cls.__name__}") + + # Handle LlamaForCausalLM + if isinstance(model, LlamaForCausalLM): + model = model.model + + if config is None: + config = LlamaDifferentialConfig(**model.config.__dict__) + logger.debug(f"Created config: {config}") + + # Validate head counts if using split heads mode + if config.split_heads: + if config.num_attention_heads % 2 != 0: + raise ValueError( + f"Number of attention heads ({config.num_attention_heads}) must be even " + "when using split_heads=True" + ) + if config.num_key_value_heads % 2 != 0: + raise ValueError( + f"Number of key/value heads ({config.num_key_value_heads}) must be even " + "when using split_heads=True" + ) + + new_model = cls(config) + + # Copy all weights except attention + logger.debug("Copying embeddings and norm") + new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict()) + new_model.norm.load_state_dict(model.norm.state_dict()) + + logger.debug("Copying layer weights") + for layer_idx, (new_layer, old_layer) in enumerate( + zip(new_model.layers, model.layers) + ): + # Copy everything except attention weights + new_layer.mlp.load_state_dict(old_layer.mlp.state_dict()) + new_layer.input_layernorm.load_state_dict( + old_layer.input_layernorm.state_dict() + ) + new_layer.post_attention_layernorm.load_state_dict( + old_layer.post_attention_layernorm.state_dict() + ) + + # Handle attention weights + new_layer.self_attn.v_proj.load_state_dict( + old_layer.self_attn.v_proj.state_dict() + ) + new_layer.self_attn.o_proj.load_state_dict( + old_layer.self_attn.o_proj.state_dict() + ) + + # Get the original projection sizes + old_q_size = old_layer.self_attn.q_proj.weight.size(0) + old_k_size = old_layer.self_attn.k_proj.weight.size(0) + + if not config.split_heads: + logger.debug( + f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}" + ) + new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_( + old_layer.self_attn.q_proj.weight.data + ) + new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_( + old_layer.self_attn.k_proj.weight.data + ) + + if config.zero_init: + logger.debug(f"Layer {layer_idx}: Zero initializing") + with torch.no_grad(): + 
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_() + new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_() + new_layer.self_attn.lambda_q1.zero_() + new_layer.self_attn.lambda_k1.zero_() + new_layer.self_attn.lambda_q2.zero_() + new_layer.self_attn.lambda_k2.zero_() + new_layer.self_attn.lambda_init.zero_() + elif config.mirror_weights: + # Mirror weights for second component + new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_( + old_layer.self_attn.q_proj.weight.data + ) + new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_( + old_layer.self_attn.k_proj.weight.data + ) + + logger.info("Conversion complete") + + return new_model + + +class LlamaDifferentialForCausalLM(LlamaForCausalLM): + """ + `LlamaForCausalLM` with differential attention. + + This class extends the base LLaMA causal language model by incorporating + differential attention mechanisms. + """ + + config_class = LlamaDifferentialConfig + base_model_prefix = "llama_differential" + + def __init__(self, config: LlamaDifferentialConfig): + """ + Initialize a differential LLaMA model for causal language modeling. + + Args: + config: Configuration object for the model. + """ + super().__init__(config) + self.model = LlamaDifferentialModel(config) + + @classmethod + # pylint: disable=protected-access + def _autoset_attn_implementation( + cls, + config: LlamaDifferentialConfig, + **kwargs, # pylint: disable=unused-argument + ) -> LlamaDifferentialConfig: + """ + Automatically set the attention implementation based on config. + + Args: + config: Model configuration object. + **kwargs: Additional arguments (unused). + + Returns: + Updated configuration object. + + Raises: + ValueError: If specified attention implementation is not supported. + """ + config._attn_implementation_autoset = True + attn_implementation = getattr(config, "_attn_implementation", None) + + # Map standard types to differential types if mapping exists + if attn_implementation in config._attn_implementations: + config._attn_implementation = config._attn_implementations[ + attn_implementation + ] + + return config + + # If no mapping, validate it's a valid differential type + valid_impls = [ + None, + "differential_eager", + "differential_sdpa", + "differential_flash_attention_2", + ] + if attn_implementation not in valid_impls: + message = ( + f"Specified `attn_implementation={attn_implementation}` is not supported. " + f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}" + ) + raise ValueError(message) + + return config + + @classmethod + def from_llama( + cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None + ) -> "LlamaDifferentialForCausalLM": + """ + Convert a `LlamaForCausalLM` to use differential attention. + + Args: + model: Base LLaMA model to convert. + config: Configuration for differential attention. If `None`, created from + base model config. + + Returns: + Converted model with differential attention. + + Raises: + ValueError: If number of heads is not even when using `split_heads` mode. 
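A condensed sketch of the conversion path the `convert-diff-transformer` CLI follows; the model id and output directory are placeholders:

```python
# Sketch: model id and output directory are placeholders.
from transformers import AutoTokenizer, LlamaForCausalLM

from axolotl.integrations.diff_transformer.modeling_diff_attn import (
    LlamaDifferentialConfig,
    LlamaDifferentialForCausalLM,
)

model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# zero_init + no sublayer norm is the combination the CLI reports as preserving
# the original model's generations immediately after conversion.
config = LlamaDifferentialConfig(
    **model.config.__dict__, zero_init=True, sublayer_norm=False
)
diff_model = LlamaDifferentialForCausalLM.from_llama(model, config)

diff_model.save_pretrained("./converted-model")
tokenizer.save_pretrained("./converted-model")
```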
+ """ + if config is None: + config = LlamaDifferentialConfig(**model.config.__dict__) + + # Validate head counts if using split heads mode + if config.split_heads: + if config.num_attention_heads % 2 != 0: + raise ValueError( + f"Number of attention heads ({config.num_attention_heads}) must be even " + "when using split_heads=True" + ) + if config.num_key_value_heads % 2 != 0: + raise ValueError( + f"Number of key/value heads ({config.num_key_value_heads}) must be even " + "when using split_heads=True" + ) + + new_model = cls(config) + new_model.model = LlamaDifferentialModel.from_llama(model.model, config) + new_model.lm_head.load_state_dict(model.lm_head.state_dict()) + + return new_model + + +def register_diff_attn() -> None: + """ + Register differential attention components with the transformers library. + + This function registers the differential attention configurations and model classes + with the Auto* classes from `transformers`, making them available through the + standard model loading pipeline. + """ + # Register configs + AutoConfig.register("llama-differential", LlamaDifferentialConfig) + + # Register models + AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel) + AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM) + + from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES + + LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention + LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention + LLAMA_ATTENTION_CLASSES[ + "differential_flash_attention_2" + ] = LlamaDifferentialFlashAttention2 diff --git a/src/axolotl/utils/callbacks/diff_attn.py b/src/axolotl/utils/callbacks/diff_attn.py new file mode 100644 index 000000000..3e99e7d5f --- /dev/null +++ b/src/axolotl/utils/callbacks/diff_attn.py @@ -0,0 +1,234 @@ +""" +Monitor and log differential attention components during training. + +This module provides a callback for tracking the behavior of differential attention +mechanisms, including lambda parameters and attention statistics. +""" + +from typing import Any + +import torch +import wandb +from torch import nn +from transformers import TrainerCallback + +from axolotl.utils.distributed import is_main_process + + +class DifferentialAttentionMonitorCallback(TrainerCallback): + """ + Callback to monitor differential attention components and lambda parameters. + + This callback tracks attention statistics across all layers and provides detailed + monitoring for a specified number of layers evenly spaced through the model. + """ + + def __init__( + self, + log_every: int = 250, + num_monitor_layers: int = 3, + warmup_steps: int | None = None, + ): + """ + Initialize the differential attention monitor. + + Args: + log_every: Number of steps between logging events. + num_monitor_layers: Number of individual layers to monitor in detail. + warmup_steps: Optional parameter for negative attention component warmup. + """ + self.log_every = log_every + self.num_monitor_layers = num_monitor_layers + self.warmup_steps = warmup_steps + self.monitor_layers: list[int] | None = None # Will be set in on_train_begin + + # pylint: disable=unused-argument + def on_train_begin( + self, + args: Any, + state: Any, + control: Any, + model: torch.nn.Module, + **kwargs, + ) -> None: + """ + Set up layer monitoring at the start of training. + + Args: + args: Training arguments. + state: Training state. + control: Training control object. + model: The model being trained. 
+ **kwargs: Additional arguments passed by the trainer. + """ + if is_main_process(): + num_layers = len(model.model.layers) + self.num_monitor_layers = min(self.num_monitor_layers, num_layers) + + stride = ( + (num_layers - 1) / (self.num_monitor_layers - 1) + if self.num_monitor_layers > 1 + else 0 + ) + self.monitor_layers = [ + round(i * stride) for i in range(self.num_monitor_layers) + ] + print(f"Monitoring layers {self.monitor_layers} in detail") + + # pylint: disable=unused-argument + def on_step_end( + self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs + ) -> None: + """ + Log attention metrics at the end of each step. + + Collects and logs: + - Lambda parameter norms and values. + - Attention statistics (mean and std). + - Both per-layer and aggregate metrics. + + Args: + args: Training arguments. + state: Training state. + control: Training control object. + model: The model being trained. + **kwargs: Additional arguments passed by the trainer. + """ + if not is_main_process() or state.global_step % self.log_every != 0: + return + + assert self.monitor_layers is not None + + # Aggregate stats across all layers + all_q1_norms = [] + all_q2_norms = [] + all_k1_norms = [] + all_k2_norms = [] + all_lambda1 = [] + all_lambda2 = [] + all_lambda_full = [] + + metrics = {} + for layer_idx, layer in enumerate(model.model.layers): + attn = layer.self_attn + + # Collect stats for aggregation + all_q1_norms.append(attn.lambda_q1.norm().item()) + all_q2_norms.append(attn.lambda_q2.norm().item()) + all_k1_norms.append(attn.lambda_k1.norm().item()) + all_k2_norms.append(attn.lambda_k2.norm().item()) + + lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item() + lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item() + all_lambda1.append(lambda1) + all_lambda2.append(lambda2) + all_lambda_full.append(attn.lambda_full) + + # Log detailed metrics for monitored layers + if layer_idx in self.monitor_layers: + metrics.update( + { + f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(), + f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(), + f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(), + f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(), + f"layer_{layer_idx}/lambda1": lambda1, + f"layer_{layer_idx}/lambda2": lambda2, + f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(), + f"layer_{layer_idx}/lambda_full": lambda1 + - lambda2 + + attn.lambda_init.item(), + f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(), + f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(), + f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(), + f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(), + } + ) + + # Add aggregate metrics + metrics.update( + { + "aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms) + .mean() + .item(), + "aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(), + "aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms) + .mean() + .item(), + "aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(), + "aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms) + .mean() + .item(), + "aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(), + "aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms) + .mean() + .item(), + "aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(), + "aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(), + "aggregate/lambda1_std": 
torch.tensor(all_lambda1).std().item(), + "aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(), + "aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(), + "aggregate/lambda_full_mean": torch.tensor(all_lambda_full) + .mean() + .item(), + "aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(), + } + ) + + if self.warmup_steps: + metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix + + wandb.log(metrics, step=state.global_step) + + +class DifferentialAttentionMixingCallback(TrainerCallback): + """ + Callback to gradually increase the weight of negative attention components during + training. + """ + + def __init__(self, warmup_steps: int): + """ + Args: + warmup_steps: Number of steps to linearly increase negative attention + weight from 0 to 1. If `None`, negative attention has full weight from + start. + """ + self.warmup_steps = warmup_steps + self.diff_attention_layers: list[nn.Module] | None = None + + # pylint: disable=unused-argument + def on_train_begin( + self, + args: Any, + state: Any, + control: Any, + model: torch.nn.Module, + **kwargs, + ) -> None: + """Cache the differential attention layers at the start of training.""" + if model is not None: + # Get the actual model if it's wrapped + if hasattr(model, "module"): + model = model.module + + # Cache all differential attention layers + self.diff_attention_layers = [ + module for module in model.modules() if hasattr(module, "diff_attn_mix") + ] + + def on_step_begin( + self, + args: Any, + state: Any, + control: Any, + model: torch.nn.Module = None, + **kwargs, + ) -> None: + if self.diff_attention_layers and self.warmup_steps: + # Calculate mixing parameter (0 to 1) + mix = min(1.0, state.global_step / self.warmup_steps) + + # Update cached layers + for layer in self.diff_attention_layers: + layer.diff_attn_mix = mix diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 523fd76fe..d16db7613 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -713,19 +713,45 @@ def set_attention_config(self) -> None: if self.cfg.flash_attention: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - self.model_kwargs["attn_implementation"] = "flash_attention_2" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "flash_attention_2" - ) + + if self.cfg.diff_attention: + self.model_kwargs[ + "attn_implementation" + ] = "differential_flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "differential_flash_attention_2" + ) + else: + self.model_kwargs["attn_implementation"] = "flash_attention_2" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "flash_attention_2" + ) elif self.cfg.sdp_attention: - self.model_kwargs["attn_implementation"] = "sdpa" - self.model_config._attn_implementation = ( # pylint: disable=protected-access - "sdpa" - ) + if self.cfg.diff_attention: + self.model_kwargs["attn_implementation"] = "differential_sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "differential_sdpa" + ) + else: + self.model_kwargs["attn_implementation"] = "sdpa" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "sdpa" + ) elif self.cfg.eager_attention: - self.model_kwargs["attn_implementation"] = "eager" + if self.cfg.diff_attention: + self.model_kwargs["attn_implementation"] = "differential_eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + 
"differential_eager" + ) + else: + self.model_kwargs["attn_implementation"] = "eager" + self.model_config._attn_implementation = ( # pylint: disable=protected-access + "eager" + ) + elif self.cfg.diff_attention: + self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access - "eager" + "differential_eager" ) if self.cfg.low_cpu_mem_usage: @@ -816,6 +842,7 @@ def _configure_zero3_memory_efficient_loading(): if self.cfg.is_multimodal: self.model_config.text_config = self.text_model_config + self.model = self.AutoModelLoader.from_pretrained( self.base_model, config=self.model_config, diff --git a/src/axolotl/utils/yaml.py b/src/axolotl/utils/yaml.py new file mode 100644 index 000000000..c5c9e74ae --- /dev/null +++ b/src/axolotl/utils/yaml.py @@ -0,0 +1,157 @@ +"""Utilities for YAML files.""" + +from collections import OrderedDict +from typing import Any, Dict, List, Set, Tuple, Union + +import yaml + + +class YAMLOrderTracker: + """Tracks the order of keys and section breaks in YAML files.""" + + def __init__(self, yaml_path: str): + self.yaml_path = yaml_path + self.structure, self.needs_break = self._parse_yaml_structure() + + def _get_indentation_level(self, line: str) -> int: + """Get the indentation level of a line.""" + return len(line) - len(line.lstrip()) + + def _parse_yaml_structure( + self, + ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]: + """Parse the YAML file to extract structure and identify section breaks.""" + with open(self.yaml_path, "r", encoding="utf-8") as file: + contents = file.readlines() + + structure: OrderedDict = OrderedDict() + needs_break = set() # Track which keys should have a break before them + current_path = [] + last_indentation = -1 + had_empty_line = False + + for line in contents: + # Track empty lines and comments + if not line.strip() or line.strip().startswith("#"): + had_empty_line = True + continue + + # Get indentation level and content + indentation = self._get_indentation_level(line) + content = line.strip() + + # Skip lines that don't define keys + if ":" not in content: + continue + + # Extract key + key = content.split(":")[0].strip() + + # If this is a top-level key and we had an empty line, mark it + if indentation == 0: + if had_empty_line: + needs_break.add(key) + had_empty_line = False + + # Handle indentation changes + if indentation > last_indentation: + current_path.append(key) + elif indentation < last_indentation: + levels_up = (last_indentation - indentation) // 2 + current_path = current_path[:-levels_up] + current_path[-1] = key + else: + if current_path: + current_path[-1] = key + + # Update structure + current_dict = structure + for path_key in current_path[:-1]: + if path_key not in current_dict: + current_dict[path_key] = OrderedDict() + current_dict = current_dict[path_key] + + if current_path: + if current_path[-1] not in current_dict: + current_dict[current_path[-1]] = OrderedDict() + + last_indentation = indentation + + return structure, needs_break + + +class OrderedDumper(yaml.SafeDumper): + """Custom YAML dumper that maintains dictionary order.""" + + +def represent_none(self, _): + """Represent None values as empty fields.""" + return self.represent_scalar("tag:yaml.org,2002:null", "") + + +def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any: + """Custom representer for dictionaries that maintains order.""" + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def reorder_dict(data: 
Dict, reference_structure: Dict) -> OrderedDict: + """Reorder a dictionary based on a reference structure.""" + ordered = OrderedDict() + + # First add keys that are in the reference order + for key in reference_structure: + if key in data: + if isinstance(reference_structure[key], dict) and isinstance( + data[key], dict + ): + ordered[key] = reorder_dict(data[key], reference_structure[key]) + else: + ordered[key] = data[key] + + # Then add any remaining keys that weren't in the reference + for key in data: + if key not in ordered: + ordered[key] = data[key] + + return ordered + + +def dump_yaml_preserved_order( + data: Dict, reference_yaml_path: str, output_path: str +) -> None: + """Dump YAML file while preserving nested order and normalized spacing.""" + # Get reference structure and spacing + tracker = YAMLOrderTracker(reference_yaml_path) + + # Reorder the data + ordered_data = reorder_dict(data, tracker.structure) + + # Register the custom representers + OrderedDumper.add_representer(type(None), represent_none) + OrderedDumper.add_representer(dict, ordered_dict_representer) + OrderedDumper.add_representer(OrderedDict, ordered_dict_representer) + + # First dump to string + yaml_str = yaml.dump( + ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False + ) + + # Add spacing according to reference + lines = yaml_str.split("\n") + result_lines: List[str] = [] + current_line = 0 + + while current_line < len(lines): + line = lines[current_line] + if line.strip() and ":" in line and not line.startswith(" "): # Top-level key + key = line.split(":")[0].strip() + if key in tracker.needs_break: + # Add single empty line before this key + if result_lines and result_lines[-1] != "": + result_lines.append("") + result_lines.append(line) + current_line += 1 + + # Write the final result + with open(output_path, "w", encoding="utf-8") as file: + file.write("\n".join(result_lines)) diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 78b090e19..d360e29d6 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,4 +1,5 @@ """Shared pytest fixtures for cli module.""" + import pytest from click.testing import CliRunner diff --git a/tests/cli/test_cli_base.py b/tests/cli/test_cli_base.py index 6dbae045f..f8f1edfa3 100644 --- a/tests/cli/test_cli_base.py +++ b/tests/cli/test_cli_base.py @@ -43,14 +43,12 @@ def _test_basic_execution( result = cli_runner.invoke(cli, [command, str(config_path)]) assert mock.called - assert mock.call_args.args[0] == [ + assert mock.call_args.args[0][:5] == [ "accelerate", "launch", "-m", f"axolotl.cli.{command}", str(config_path), - "--debug-num-examples", - "0", ] assert mock.call_args.kwargs == {"check": True} assert result.exit_code == 0 diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py index 0df87b029..f06f06717 100644 --- a/tests/cli/test_cli_fetch.py +++ b/tests/cli/test_cli_fetch.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI fetch command.""" + from unittest.mock import patch from axolotl.cli.main import fetch diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index 7cb163d25..b8effa3d2 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI inference command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index ed8335b76..935fb85b8 100644 --- a/tests/cli/test_cli_interface.py +++ 
b/tests/cli/test_cli_interface.py @@ -1,4 +1,5 @@ """General pytest tests for axolotl.cli.main interface.""" + from axolotl.cli.main import build_command, cli @@ -22,6 +23,7 @@ def test_build_command(): "--batch-size", "8", "--debug", + "--nouse-fp16", ] diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py index 165a64e98..aac016760 100644 --- a/tests/cli/test_cli_merge_lora.py +++ b/tests/cli/test_cli_merge_lora.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_lora command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index cff0f3b77..420c28b9e 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index 4719461aa..e2dd3a6c3 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI preprocess command.""" + import shutil from pathlib import Path from unittest.mock import patch diff --git a/tests/cli/test_cli_shard.py b/tests/cli/test_cli_shard.py index 505a2a737..86766e17f 100644 --- a/tests/cli/test_cli_shard.py +++ b/tests/cli/test_cli_shard.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI shard command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli @@ -11,14 +12,12 @@ def test_shard_with_accelerate(cli_runner, config_path): result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"]) assert mock.called - assert mock.call_args.args[0] == [ + assert mock.call_args.args[0][:5] == [ "accelerate", "launch", "-m", "axolotl.cli.shard", str(config_path), - "--debug-num-examples", - "0", ] assert mock.call_args.kwargs == {"check": True} assert result.exit_code == 0 diff --git a/tests/cli/test_cli_version.py b/tests/cli/test_cli_version.py index 819780e94..533dd5c0e 100644 --- a/tests/cli/test_cli_version.py +++ b/tests/cli/test_cli_version.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI --version""" + from axolotl.cli.main import cli diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index b88e4ac72..ecb0025e4 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI utils.""" # pylint: disable=redefined-outer-name + import json from unittest.mock import Mock, patch diff --git a/tests/e2e/integrations/convert_diff_transformer/__init__.py b/tests/e2e/integrations/convert_diff_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/e2e/integrations/convert_diff_transformer/conftest.py b/tests/e2e/integrations/convert_diff_transformer/conftest.py new file mode 100644 index 000000000..3964df052 --- /dev/null +++ b/tests/e2e/integrations/convert_diff_transformer/conftest.py @@ -0,0 +1,31 @@ +"""Shared fixtures for differential transformer conversion tests.""" + +import pytest +from click.testing import CliRunner + + +@pytest.fixture(scope="class") +def base_config(): + """Basic config for testing.""" + return { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "datasets": [ + { + "path": "axolotl-ai-co/alpaca_100_test", + "type": "alpaca", + }, + ], + 
"gradient_accumulation_steps": 1, + "learning_rate": 1e-4, + "val_set_size": 0.1, + "micro_batch_size": 1, + "sequence_len": 2048, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + } + + +@pytest.fixture(scope="class") +def cli_runner(): + return CliRunner() diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py new file mode 100644 index 000000000..d5915f8a5 --- /dev/null +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py @@ -0,0 +1,51 @@ +"""End-to-end tests for differential transformer conversion and evaluation.""" +# pylint: disable=duplicate-code + +from pathlib import Path + +import yaml +from pytest import approx + +from axolotl.cli import load_cfg +from axolotl.cli.evaluate import do_evaluate +from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer +from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs + + +def test_conversion_and_eval_cli(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + eval_cfg = load_cfg(str(output_dir)) + eval_cli_args = EvaluateCliArgs() + all_metrics = do_evaluate(eval_cfg, eval_cli_args) + + assert list(all_metrics.keys()) == [ + "train_loss", + "train_model_preparation_time", + "train_runtime", + "train_samples_per_second", + "train_steps_per_second", + "eval_loss", + "eval_model_preparation_time", + "eval_runtime", + "eval_samples_per_second", + "eval_steps_per_second", + ] + assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4) + assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4) diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py new file mode 100644 index 000000000..e1ad31fdd --- /dev/null +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -0,0 +1,150 @@ +"""End-to-end tests for differential transformer conversion.""" +# pylint: disable=redefined-outer-name +# pylint: disable=duplicate-code + +from pathlib import Path +from typing import Optional +from unittest.mock import patch + +import pytest +import yaml + +from axolotl.cli import load_cfg +from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer +from axolotl.cli.main import cli +from axolotl.common.cli import ConvertDiffTransformerCliArgs + + +def test_cli_validation(cli_runner): + # Test missing config file + result = cli_runner.invoke(cli, ["convert-diff-transformer"]) + assert result.exit_code != 0 + assert "Error: Missing argument 'CONFIG'." 
in result.output + + # Test non-existent config file + result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"]) + assert result.exit_code != 0 + assert "Error: Invalid value for 'CONFIG'" in result.output + + + def test_basic_execution(cli_runner, tmp_path: Path, base_config): + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + with patch( + "axolotl.cli.integrations.convert_diff_transformer.do_cli" + ) as mock_do_cli: + result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)]) + assert result.exit_code == 0 + + mock_do_cli.assert_called_once() + assert mock_do_cli.call_args.kwargs["config"] == str(config_path) + + + def test_conversion_cli_basic(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs() + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert not debug_info + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + + def test_conversion_cli_debug(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert not debug_info["generations_match"] + assert not debug_info["match_expected"] + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + + def test_conversion_cli_reproduce(tmp_path: Path, base_config): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + + @pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] + ) + def test_conversion_cli_reproduce_attentions( + tmp_path: Path, base_config, attention: Optional[str] + ): + output_dir = tmp_path / "converted" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs( + debug=True, zero_init=True, sublayer_norm=False + ) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is True + assert (output_dir / "model.safetensors").exists() + assert (output_dir / 
"config.json").exists() + assert (output_dir / "axolotl_config.yml").exists() + + +@pytest.mark.parametrize( + "attention", ["eager_attention", "sdp_attention", "flash_attention"] +) +def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str): + output_dir = tmp_path / "converted" + + # Smallest model with an even number of attention heads + base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B" + base_config["output_dir"] = str(output_dir) + base_config[attention] = True + + config_path = tmp_path / "config.yml" + with open(config_path, "w", encoding="utf-8") as file: + yaml.dump(base_config, file) + + cfg = load_cfg(str(config_path)) + cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) + + assert debug_info["generations_match"] is False + assert (output_dir / "model.safetensors").exists() + assert (output_dir / "config.json").exists() + assert (output_dir / "axolotl_config.yml").exists()