convert-diff-transformer CLI command / codepath #2197
base: main
Conversation
def dump_yaml_preserved_order(
    data: Dict, reference_yaml_path: str, output_path: str
) -> None:
    """Dump YAML file while preserving nested order and normalized spacing."""
🔥
We could similarly have a function to normalize any config yaml file to have some expected ordering / formatting.
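A helper along those lines might look roughly like the sketch below. This is hypothetical, not code from this PR: the name `normalize_yaml_order` and the choice to reorder only top-level keys against a reference file are assumptions for illustration.

```python
from typing import Dict

import yaml


def normalize_yaml_order(
    config_path: str, reference_yaml_path: str, output_path: str
) -> None:
    """Hypothetical helper: reorder a config YAML's top-level keys to match a
    reference YAML file, appending any keys the reference does not mention."""
    with open(config_path, encoding="utf-8") as file:
        data: Dict = yaml.safe_load(file)
    with open(reference_yaml_path, encoding="utf-8") as file:
        reference: Dict = yaml.safe_load(file)

    # Keys known to the reference come first, in the reference's order;
    # everything else keeps its original relative order at the end.
    ordered = {key: data[key] for key in reference if key in data}
    ordered.update({key: value for key, value in data.items() if key not in ordered})

    with open(output_path, "w", encoding="utf-8") as file:
        yaml.safe_dump(ordered, file, sort_keys=False, default_flow_style=False)
```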
I thought the differential transformer requires a model architecture change and modeling code change? Does this somehow automatically implement a modeling.py for the model?
Good question. I've implemented a monkeypatch in the `convert_diff_transformer.patches` module that swaps in the new attention classes. As for the architecture change, we have the differential attention layer implementations in `axolotl.integrations.diff_transformer`.
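For readers unfamiliar with this pattern, a monkeypatch of this kind could look roughly like the sketch below. This is illustrative only, not the PR's actual code: the `DifferentialLlamaSdpaAttention` class is a placeholder, and the sketch assumes a `transformers` release that still exposes the `LLAMA_ATTENTION_CLASSES` dispatch dict.

```python
from transformers.models.llama import modeling_llama
from transformers.models.llama.modeling_llama import LlamaSdpaAttention


class DifferentialLlamaSdpaAttention(LlamaSdpaAttention):
    """Placeholder for a differential-attention variant of Llama SDPA attention."""
    # A real implementation would double the q/k projections and subtract two
    # softmax attention maps, per the Differential Transformer paper.


def patch_llama_attention_classes() -> None:
    """Route attn_implementation="sdpa" to the differential attention class."""
    # Any Llama model constructed after this call picks up the patched class,
    # which is why the patch has to run inside the Axolotl process.
    modeling_llama.LLAMA_ATTENTION_CLASSES["sdpa"] = DifferentialLlamaSdpaAttention
```

Because the mapping is patched at runtime, anything outside the Axolotl process (plain `transformers`, TGI, vLLM) would load the stock attention unless the converted checkpoint ships its own modeling code, which is the concern raised below.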
The monkeypatch only works in the context of Axolotl, so we will need a modeling.py to make inference work properly in the wild (transformers, TGI, vLLM, etc.), right? (If I understand correctly.)
Description
This PR implements the differential attention layer from the Differential Transformer paper.
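For reference, the core operation from the paper, in a deliberately simplified form (single head, fixed λ, and none of the paper's λ reparameterization, GroupNorm, or output scaling), looks roughly like the sketch below; this is a paraphrase of the paper's formula, not the code added by this PR.

```python
import math

import torch


def differential_attention(q1, q2, k1, k2, v, lam: float = 0.5):
    """(softmax(q1 k1^T / sqrt(d)) - lam * softmax(q2 k2^T / sqrt(d))) @ v

    q1/q2 and k1/k2 are the positive/negative query and key projections with
    shape (batch, seq, head_dim); lam is a learnable scalar in the real layer
    (the default here is arbitrary, for illustration only).
    """
    scale = 1.0 / math.sqrt(q1.shape[-1])
    attn_pos = torch.softmax(q1 @ k1.transpose(-1, -2) * scale, dim=-1)
    attn_neg = torch.softmax(q2 @ k2.transpose(-1, -2) * scale, dim=-1)
    return (attn_pos - lam * attn_neg) @ v
```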
Motivation and Context
We wanted to add this attention implementation to `axolotl` so users can swap out the existing attention layers in their models for this more performant version. We matched the official implementation details as closely as possible, while adapting it to play nicely with the `transformers` attention implementations.

Since we were focused on being able to convert existing LLMs to use these differential attention layers, we wanted a way to avoid degrading the performance of the (possibly pre-trained) LLM while doing so.
To this end, the conversion process doubles the dimensionality of the query and key projections (since differential attention requires both a positive and a negative component of the attention) and, optionally (pass `--zero-init`), initializes the weights of the negative component to zero while copying the weights from the original attention modules into the positive component.

When doing this, the converted network computes the same function as the original (pass `--debug` to confirm this), but may suffer from a vanishing gradient problem. The default behavior is therefore to initialize the weights of the negative components of the differential attention layers to 0-centered, normally distributed values with a small variance.

Relevant links:
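To make the initialization scheme described above concrete, here is a loose sketch of the idea, not the code in this PR: the function name, the `zero_init` flag, and the `std=0.01` value are assumptions for illustration, and biases are ignored (Llama q/k projections are typically bias-free).

```python
import torch
import torch.nn as nn


def build_doubled_projection(old_proj: nn.Linear, zero_init: bool = False) -> nn.Linear:
    """Double a q/k projection: the positive half copies the original weights,
    the negative half is zeroed (exact-match conversion) or given small
    0-centered normal noise (the default described above)."""
    out_features = old_proj.out_features
    new_proj = nn.Linear(old_proj.in_features, 2 * out_features, bias=False)

    with torch.no_grad():
        # Positive component: reuse the pre-trained weights.
        new_proj.weight[:out_features].copy_(old_proj.weight)
        # Negative component: zeros or small-variance noise.
        if zero_init:
            new_proj.weight[out_features:].zero_()
        else:
            new_proj.weight[out_features:].normal_(mean=0.0, std=0.01)

    return new_proj
```

The converted layer would then split the doubled output into its positive and negative halves before computing the two attention maps.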
How has this been tested?
SmolLM2-135m on an A40 Runpod instance on this feature branch. The workflow was:
- Use the `--zero-init` and `--debug` flags for sanity checking exact model conversion (completions, logits, losses)
- Run the `axolotl evaluate` command on the small `mhenrichsen/alpaca_2k_test` dataset with both the original and converted model and check that their evaluation metrics match

For example:
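A hypothetical invocation, with placeholder config paths; the commands and flags are those named above, but the exact argument syntax is assumed rather than taken from this PR.

```bash
# Sanity-check an exact conversion, then compare evaluation metrics:
axolotl convert-diff-transformer config.yml --zero-init --debug
axolotl evaluate config.yml             # original model
axolotl evaluate converted/config.yml   # converted model
```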
Types of changes
- New `axolotl.integrations.diff_transformer` module, which implements the differential attention layers for the Llama LLM architecture and for various attention implementations (eager, SDPA, Flash Attention 2)
- New `axolotl.cli.integrations.convert_diff_transformer` module (and updates to `axolotl.cli.main`), which implements the `convert-diff-transformer` CLI command
- New `axolotl.cli.integrations.convert_diff_transformer.patches` (to be moved), for updating the `LLAMA_ATTENTION_CLASSES` constant in `transformers.models.llama.modeling_llama`
TODO