differential flash attention 2; cleanup
djsaunde committed Dec 17, 2024
1 parent 3e6e9fc commit fbba9bb
Showing 9 changed files with 269 additions and 106 deletions.
6 changes: 6 additions & 0 deletions model-out/eval_summary.csv
@@ -0,0 +1,6 @@
metric,training,validation
loss,1.8773103952407837,1.915901780128479
model_preparation_time,0.0051,0.0051
runtime,89.7635,8.9565
samples_per_second,20.053,22.33
steps_per_second,20.053,22.33
1 change: 1 addition & 0 deletions outputs
66 changes: 32 additions & 34 deletions src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -14,7 +14,9 @@

from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attention
from axolotl.integrations.differential_transformer.convert import (
convert_to_diff_attention,
)

LOG = logging.getLogger("axolotl.cli.convert_attention")

@@ -74,7 +76,11 @@ def convert_diff_transformer(cfg, cli_args, config_path):
# Convert attention
LOG.info("Converting to differential attention...")
try:
model = convert_to_diff_attention(model, cli_args.zero_init)
model = convert_to_diff_attention(
model=model,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
@@ -130,43 +136,35 @@ def convert_diff_transformer(cfg, cli_args, config_path):
+ Fore.RESET
)
else:
if cli_args.zero_init:
LOG.info(
Fore.RED
+ "Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
)

if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
else:
LOG.info(
Fore.YELLOW
+ "Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{orig_text}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{conv_text}\n"
+ "*" * 50
+ "\n"
+ "However, this is expected since --zero-init was not passed."
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)

return model

except Exception as exc:
LOG.error(Fore.RED + "Process failed: %s" + Fore.RESET, str(exc))
raise
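
A minimal sketch of driving the updated entrypoint programmatically (the config path is hypothetical, and it assumes load_cfg accepts a path to a YAML config as elsewhere in the CLI):

# Sketch only: exercise convert_diff_transformer() with the new flags.
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs

config_path = "examples/llama-3/qlora.yml"  # hypothetical config path
cfg = load_cfg(config_path)

# With zero_init=True and sublayer_norm=False, the converted model is expected
# to reproduce the original generations; mismatches are logged in red above.
cli_args = ConvertDiffTransformerCliArgs(debug=True, zero_init=True, sublayer_norm=False)
model = convert_diff_transformer(cfg, cli_args, config_path)
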
12 changes: 10 additions & 2 deletions src/axolotl/cli/utils.py
@@ -22,7 +22,6 @@ def decorator(function):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type

if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def decorator(function):
default=field.default,
help=field.metadata.get("description"),
)(function)

return function

return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)

# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def decorator(function):
function = click.option(
option_name, default=None, help=field.description
)(function)

return function

return decorator
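
A rough illustration of the Optional[bool] handling added above; DummyConfig and its field are made up for the example (only the field's type shape matters):

from typing import Optional

import click
from pydantic import BaseModel

from axolotl.cli.utils import add_options_from_config


class DummyConfig(BaseModel):
    """Hypothetical config model used only for this sketch."""

    sublayer_norm: Optional[bool] = None


@click.command()
@add_options_from_config(DummyConfig)
def cli(**kwargs):
    # Optional[bool] unwraps to bool, so the field becomes a paired
    # --sublayer-norm/--no-sublayer-norm flag defaulting to None; the pydantic
    # default applies when neither variant is passed.
    click.echo(kwargs["sublayer_norm"])
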
1 change: 1 addition & 0 deletions src/axolotl/common/cli.py
@@ -62,6 +62,7 @@ class ConvertDiffTransformerCliArgs:

debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)


def load_model_and_tokenizer(
src/axolotl/integrations/differential_transformer/convert.py
@@ -5,22 +5,35 @@
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaSdpaAttention
from transformers.models.mistral.modeling_mistral import MistralAttention
from transformers.models.mixtral.modeling_mixtral import MixtralAttention
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
)

from .multihead_diffattn import (
from .differential_attention import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

ATTENTION_MAPPING = {
LlamaAttention: LlamaDifferentialAttention,
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}


def copy_attention_weights(
old_attn: Union[LlamaAttention, LlamaSdpaAttention],
new_attn: Union[LlamaDifferentialAttention, LlamaDifferentialSdpaAttention],
old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
new_attn: Union[
LlamaDifferentialAttention,
LlamaDifferentialSdpaAttention,
LlamaDifferentialFlashAttention2,
],
zero_init: bool = False,
) -> None:
"""
@@ -69,31 +82,24 @@ def copy_attention_weights(


def convert_to_diff_attention(
model: PreTrainedModel, zero_init: bool
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
attention_patterns = (
LlamaAttention,
LlamaSdpaAttention,
MistralAttention,
MixtralAttention,
)
layer_idx = 0

# Set sublayer norm as config on the model.
model.config.sublayer_norm = sublayer_norm

def convert_module(module):
nonlocal layer_idx

# Iterate through module children, convert any attn layers to diff attn
for name, child in module.named_children():
if isinstance(child, attention_patterns):
layer_type = type(child).__name__

if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
# Choose appropriate differential attention class
if isinstance(child, LlamaSdpaAttention):
attention_class = LlamaDifferentialSdpaAttention
else:
attention_class = LlamaDifferentialAttention
attention_class = ATTENTION_MAPPING[type(child)]

layer_type = type(child).__name__
logger.info(
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
)
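
To exercise the new flash-attention mapping end to end, something like the following sketch should work (checkpoint name and dtype are assumptions; flash-attn must be installed for attn_implementation="flash_attention_2"):

import torch
from transformers import AutoModelForCausalLM

from axolotl.integrations.differential_transformer.convert import convert_to_diff_attention

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",                # hypothetical checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # layers load as LlamaFlashAttention2
)

# Each LlamaFlashAttention2 layer is swapped for LlamaDifferentialFlashAttention2
# via ATTENTION_MAPPING; sublayer_norm is recorded on model.config.
model = convert_to_diff_attention(model, zero_init=True, sublayer_norm=False)
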
