convert-diff-transformer CLI command / codepath #2197

Draft: wants to merge 38 commits into base: main

Commits (38)
7a4b296
Basic evaluate CLI command / codepath (#2188)
djsaunde Dec 16, 2024
13cdffa
initial diff attn layer / model conversion implementation (support fo…
djsaunde Dec 11, 2024
7be0d74
Adding script for doing conversion; fixes and updates
djsaunde Dec 12, 2024
df1504a
adding CLI command for convert-diff-transformer
djsaunde Dec 12, 2024
e484ec7
training fixes, patching, minor cleanup
djsaunde Dec 13, 2024
849bc94
various improvements
djsaunde Dec 13, 2024
2f9fa4c
various improvements
djsaunde Dec 13, 2024
6665acf
fix model save / load logic
djsaunde Dec 17, 2024
4c050ce
pre-commit fix
djsaunde Dec 17, 2024
41ebd93
moving monkeypatch
djsaunde Dec 17, 2024
bda1eed
differential flash attention 2; cleanup
djsaunde Dec 17, 2024
63b8e42
duplicate code ignore
djsaunde Dec 17, 2024
d22e113
convert-differential-transformer test coverage
djsaunde Dec 17, 2024
ea07a70
plugin implementation
djsaunde Dec 18, 2024
0b382c8
fixes post-rebase
djsaunde Dec 18, 2024
505321a
isolating problematic test
djsaunde Dec 18, 2024
66176b3
adding split_heads argument for retaining original (Q, K) dimensionan…
djsaunde Dec 18, 2024
1d935f6
moving tests around for flash_attn install
djsaunde Dec 18, 2024
390cb57
removing extra pytest xdist args
djsaunde Dec 19, 2024
0d56582
adding yaml dumper preserving input config format
djsaunde Dec 20, 2024
fcbfa86
refactor and fixing test isolation issues
djsaunde Dec 21, 2024
5b90da0
added modeling code; cleanup + refactor
Dec 23, 2024
a3fd507
fix duplicate-code warnings
Dec 23, 2024
4ff3328
updated custom modeling code
Dec 24, 2024
eb6611d
progress on modeling code
djsaunde Dec 24, 2024
3bc568e
adding registration function
Dec 27, 2024
78e0ec0
changes
djsaunde Dec 27, 2024
e5fa842
update
djsaunde Dec 27, 2024
332ce0a
fixes and cleanup
djsaunde Dec 28, 2024
2a7f139
pre-commit fix
djsaunde Dec 28, 2024
70c4e6f
updates and cleanup
djsaunde Jan 6, 2025
443327c
CLI build_command bugfix
djsaunde Jan 8, 2025
4f804f6
adding diff attn callback, adding documentation
djsaunde Jan 10, 2025
7aca08f
adding guard statements
djsaunde Jan 10, 2025
6dd47ed
fire CLI fixes
djsaunde Jan 10, 2025
661d71a
adding diff attn negative component warmup (in progress)
djsaunde Jan 10, 2025
fd8ad6f
fixing negative component mixing
djsaunde Jan 13, 2025
2869421
inline comment change
djsaunde Jan 14, 2025
Files changed
3 changes: 3 additions & 0 deletions .gitignore
@@ -186,3 +186,6 @@ out/

# vim
*.swp
+
+# symlinked to axolotl-artifacts in docker containers
+outputs
1 change: 0 additions & 1 deletion cicd/cicd.sh
@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"

pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
-# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
4 changes: 2 additions & 2 deletions cicd/multigpu.py
@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code

import os
2 changes: 1 addition & 1 deletion src/axolotl/cli/__init__.py
@@ -202,7 +202,7 @@ def do_inference(
        )
    elif cfg.chat_template:
        chat_template_str = get_chat_template(cfg.chat_template)
-    elif cfg.datasets[0].type == "chat_template":
+    elif cfg.datasets and cfg.datasets[0].type == "chat_template":
        chat_template_str = get_chat_template_from_config(
            cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
        )
6 changes: 3 additions & 3 deletions src/axolotl/cli/evaluate.py
@@ -3,7 +3,7 @@
"""
import logging
from pathlib import Path
from typing import Union
from typing import Dict, Union

import fire
from dotenv import load_dotenv
@@ -23,7 +23,7 @@
LOG = logging.getLogger("axolotl.cli.evaluate")


-def do_evaluate(cfg, cli_args) -> None:
+def do_evaluate(cfg, cli_args) -> Dict[str, float]:
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
    else:
        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+    return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
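Since `do_evaluate` now returns the metrics dict instead of discarding it, callers can consume it directly. A minimal sketch of the new usage; the config path and the `eval_loss` key are illustrative, and `EvaluateCliArgs` is assumed to have defaults for all of its fields:

```python
from pathlib import Path

from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.common.cli import EvaluateCliArgs

cfg = load_cfg(Path("path/to/config.yml"))  # illustrative path
cli_args = EvaluateCliArgs()  # assumes all fields have defaults

metrics = do_evaluate(cfg, cli_args)  # now Dict[str, float] instead of None
print(metrics.get("eval_loss"))  # "eval_loss" is a hypothetical metric key
```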
Empty file.
208 changes: 208 additions & 0 deletions src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -0,0 +1,208 @@
"""CLI to convert a transformers model's attention layers to differential attention layers."""

import logging
import warnings
from pathlib import Path
from time import time
from typing import Union

import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order

LOG = logging.getLogger(__name__)


def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}

start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)

return elapsed, generated_text


def convert_diff_transformer(cfg, cli_args, config_path):
assert not (
cli_args.split_heads and cli_args.zero_init
), "Both `split_heads` and `zero_init` cannot be `True`"
assert not (
cli_args.zero_init and cli_args.mirror_weights
), "Both `zero_init` and `mirror_weights` cannot be `True`"

debug_info = {}

# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)

# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)

# Test original model
if cli_args.debug:
LOG.info("Testing original model...")
debug_info["orig_time"], debug_info["orig_text"] = test_inference(
model, tokenizer
)

try:
# Convert attention
LOG.info("Converting to differential attention...")

config = LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
mirror_weights=cli_args.mirror_weights,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise

# Test converted model
if cli_args.debug:
LOG.info("Testing converted model...")
debug_info["conv_time"], debug_info["conv_text"] = test_inference(
model, tokenizer
)

# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)

# Modify config to reflect new path / differential attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)

with open(config_path, "r", encoding="utf-8") as file:
modified_cfg = yaml.safe_load(file) or {}

modified_cfg["base_model"] = cfg.output_dir
modified_cfg["diff_attention"] = True
plugin_class = (
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
)
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]

# Write out the updated axolotl config while preserving original ordering / formatting
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")

if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+ f"Converted generation time: {debug_info['conv_time']:.2f}s"
+ Fore.RESET
)

if debug_info["orig_text"] == debug_info["conv_text"]:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
debug_info["generations_match"] = True
else:
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['conv_text']}\n"
+ "*" * 50
+ "\n"
)
debug_info["generations_match"] = False

if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
debug_info["match_expected"] = True
else:
LOG.info(
Fore.YELLOW
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
debug_info["match_expected"] = False

return model, debug_info


def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()

cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

convert_diff_transformer(cfg, cli_args, config)


if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)
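For reference, a minimal sketch of driving the conversion programmatically, mirroring what `do_cli` wires together; the config path, output dir, and flag values are illustrative, and `ConvertDiffTransformerCliArgs` is assumed to accept these fields as keyword arguments:

```python
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import (
    convert_diff_transformer,
)
from axolotl.common.cli import ConvertDiffTransformerCliArgs

config_path = "path/to/config.yml"  # illustrative path
cfg = load_cfg(config_path, output_dir="outputs/diff-transformer")

# zero_init + no sublayer norm is the combination the debug output treats
# as "generations should match"; debug=True runs the before/after
# inference comparison above.
cli_args = ConvertDiffTransformerCliArgs(
    debug=True, zero_init=True, sublayer_norm=False
)
model, debug_info = convert_diff_transformer(cfg, cli_args, config_path)
print(debug_info["generations_match"])
```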
23 changes: 22 additions & 1 deletion src/axolotl/cli/main.py
@@ -12,7 +12,12 @@
    build_command,
    fetch_from_github,
)
-from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.common.cli import (
+    ConvertDiffTransformerCliArgs,
+    EvaluateCliArgs,
+    PreprocessCliArgs,
+    TrainerCliArgs,
+)
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig

@@ -77,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -240,6 +248,19 @@ merge_lora(
    do_cli(config=config, **kwargs)


+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
+def convert_diff_transformer(config: str, **kwargs):
+    """Convert model attention layers to differential attention layers."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    from axolotl.cli.integrations.convert_diff_transformer import do_cli
+
+    do_cli(config=config, **kwargs)
+
+
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
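`set_pytorch_cuda_alloc_conf` is imported above but not defined in this diff. A plausible sketch, under the assumption that it opts into PyTorch's expandable-segment allocator via the standard env var:

```python
import os

import torch


def set_pytorch_cuda_alloc_conf() -> None:
    # Hypothetical implementation: not part of this diff. Expandable
    # segments let the CUDA caching allocator grow existing blocks instead
    # of fragmenting, reducing peak VRAM; the env var must be set before
    # the first CUDA allocation to take effect.
    if torch.cuda.is_available():
        os.environ.setdefault(
            "PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"
        )
```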
14 changes: 12 additions & 2 deletions src/axolotl/cli/utils.py
@@ -22,7 +22,6 @@ def decorator(function):
        # Process dataclass fields in reverse order for correct option ordering
        for field in reversed(dataclasses.fields(config_class)):
            field_type = field.type
-
            if get_origin(field_type) is Union and type(None) in get_args(field_type):
                field_type = next(
                    t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def decorator(function):
                    default=field.default,
                    help=field.metadata.get("description"),
                )(function)
+
        return function

    return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
    def decorator(function):
        # Process model fields in reverse order for correct option ordering
        for name, field in reversed(config_class.model_fields.items()):
-            if field.annotation == bool:
+            field_type = field.annotation
+            if get_origin(field_type) is Union and type(None) in get_args(field_type):
+                field_type = next(
+                    t for t in get_args(field_type) if not isinstance(t, NoneType)
+                )
+
+            # NOTE: defaults are handled by the pydantic model config classes.
+            if field_type == bool:
                field_name = name.replace("_", "-")
                option_name = f"--{field_name}/--no-{field_name}"
                function = click.option(
@@ -66,6 +73,7 @@ def decorator(function):
                function = click.option(
                    option_name, default=None, help=field.description
                )(function)
+
        return function

    return decorator
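The `Optional[bool]` unwrapping matters because pydantic config fields are typically declared `Optional[bool]`, which previously failed the `== bool` check and fell through to the plain-option branch. A hypothetical model to illustrate the generated flags:

```python
from typing import Optional

import click
from pydantic import BaseModel

from axolotl.cli.utils import add_options_from_config


class ExampleConfig(BaseModel):  # hypothetical config class
    flash_attention: Optional[bool] = None


@click.command()
@add_options_from_config(ExampleConfig)
def train(**kwargs):
    """Toy command; prints the parsed options."""
    print(kwargs)

# Before the fix: a single "--flash-attention" option expecting a value.
# After: a "--flash-attention/--no-flash-attention" boolean pair whose
# default stays None so the pydantic model supplies the real default.
```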
@@ -84,6 +92,8 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
        if isinstance(value, bool):
            if value:
                cmd.append(f"--{key}")
+            else:
+                cmd.append(f"--no{key}")
        else:
            cmd.extend([f"--{key}", str(value)])

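With the new branch, explicitly-false booleans are emitted as negative flags rather than silently dropped. Assuming no key rewriting happens elsewhere in `build_command` (the rest of the function is outside this hunk), a call expands like this:

```python
from axolotl.cli.utils import build_command

cmd = build_command(
    ["accelerate", "launch", "-m", "axolotl.cli.evaluate"],
    {"config": "config.yml", "debug": True, "flash-attention": False},
)
# -> ["accelerate", "launch", "-m", "axolotl.cli.evaluate",
#     "--config", "config.yml", "--debug", "--noflash-attention"]
# Note: the hunk writes f"--no{key}", so no dash separates "no" and the key.
```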