Add YAML dumper preserving input config format
djsaunde committed Dec 20, 2024
1 parent e0adf11 commit 2717b97
Showing 17 changed files with 579 additions and 707 deletions.
src/axolotl/cli/integrations/convert_diff_transformer.py

@@ -14,7 +14,8 @@

 from axolotl.cli import load_cfg, print_axolotl_text_art
 from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
+from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn
+from axolotl.utils.yaml import dump_yaml_preserved_order

 LOG = logging.getLogger(__name__)
@@ -51,7 +52,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
         raise


-def convert_differential_transformer(cfg, cli_args, config_path):
+def convert_diff_transformer(cfg, cli_args, config_path):
     debug_info = {}

     # Load model and tokenizer
@@ -114,16 +115,23 @@ def convert_differential_transformer(cfg, cli_args, config_path):
         LOG.info("Saving updated config to %s", output_config_path)

         with open(config_path, "r", encoding="utf-8") as file:
-            data = yaml.safe_load(file) or {}
+            modified_cfg = yaml.safe_load(file) or {}

-        data["base_model"] = cfg.output_dir
-        data["differential_attention"] = True
-        data["plugins"] = [
-            "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
-        ]
+        modified_cfg["base_model"] = cfg.output_dir
+        modified_cfg["diff_attention"] = True
+        plugin_class = (
+            "axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
+        )
+        if "plugins" in modified_cfg:
+            modified_cfg["plugins"].append(plugin_class)
+        else:
+            modified_cfg["plugins"] = [plugin_class]

         with open(output_config_path, "w", encoding="utf-8") as file:
-            yaml.dump(data, file)
+            dump_yaml_preserved_order(
+                data=modified_cfg,
+                reference_yaml_path=config_path,
+                output_path=output_config_path,
+            )
     else:
         LOG.info("Not saving converted model to disk")
         LOG.info("Pass --output-dir path/to/save to save model")
@@ -191,7 +199,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
     parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
     cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)

-    convert_differential_transformer(cfg, cli_args, config)
+    convert_diff_transformer(cfg, cli_args, config)


 if __name__ == "__main__":
4 changes: 2 additions & 2 deletions src/axolotl/cli/main.py
@@ -252,11 +252,11 @@ def merge_lora(
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @add_options_from_dataclass(ConvertDiffTransformerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-def convert_differential_transformer(config: str, **kwargs):
+def convert_diff_transformer(config: str, **kwargs):
     """Convert model attention layers to differential attention layers."""
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

-    from axolotl.cli.integrations.convert_differential_transformer import do_cli
+    from axolotl.cli.integrations.convert_diff_transformer import do_cli

     do_cli(config=config, **kwargs)
2 changes: 1 addition & 1 deletion src/axolotl/common/cli.py
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
 @dataclass
 class ConvertDiffTransformerCliArgs:
     """
-    dataclass with arguments for convert-differential-transformer CLI
+    dataclass with arguments for convert-diff-transformer CLI
     """

     debug: bool = field(default=False)
10 changes: 10 additions & 0 deletions src/axolotl/integrations/diff_transformer/README.md
@@ -0,0 +1,10 @@
+# Differential Transformer
+
+### Usage
+
+```yaml
+plugins:
+  - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
+
+diff_attention: true
+```
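Beyond the YAML config above, the renamed entry point can also be driven programmatically via `do_cli`, or from the command line as `convert-diff-transformer` (the dashed name is confirmed by the dataclass docstring in `src/axolotl/common/cli.py` above). A hypothetical invocation, where the config path is a placeholder and forwarding `output_dir` as a config override is an assumption based on `do_cli(config, **kwargs)` accepting arbitrary keyword arguments:

```python
# Hypothetical usage; "path/to/config.yaml" and output_dir are placeholders.
from axolotl.cli.integrations.convert_diff_transformer import do_cli

do_cli(
    config="path/to/config.yaml",    # config whose attention layers get converted
    output_dir="./converted-model",  # assumed to override cfg.output_dir
)
```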
@@ -13,11 +13,11 @@ class DifferentialTransformerPlugin(BasePlugin):
     """

     def get_input_args(self):
-        return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
+        return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"

     def pre_model_load(self, cfg):
         """Apply differential attention patch before model loading if enabled."""
-        if cfg.differential_attention:
+        if cfg.diff_attention:
             from axolotl.monkeypatch.attention.differential import (
                 patch_llama_attention_classes,
             )
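`patch_llama_attention_classes` itself is not part of the visible hunks. A speculative sketch, assuming it rebinds the entries of transformers' `LLAMA_ATTENTION_CLASSES` registry (present in transformers releases of this period) to the differential variants, with the `diff_attn` module path inferred from the import hunk below:

```python
# Speculative sketch; the real implementation lives in
# axolotl.monkeypatch.attention.differential and is not shown in this diff.
from transformers.models.llama import modeling_llama

from axolotl.monkeypatch.attention.diff_attn import (  # assumed module path
    LlamaDifferentialAttention,
    LlamaDifferentialFlashAttention2,
    LlamaDifferentialSdpaAttention,
)


def patch_llama_attention_classes():
    """Route each Llama attention implementation to its differential variant."""
    modeling_llama.LLAMA_ATTENTION_CLASSES.update(
        {
            "eager": LlamaDifferentialAttention,
            "flash_attention_2": LlamaDifferentialFlashAttention2,
            "sdpa": LlamaDifferentialSdpaAttention,
        }
    )
```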
src/axolotl/integrations/diff_transformer/args.py

@@ -11,4 +11,4 @@
 class DifferentialTransformerArgs(BaseModel):
     """Input args for differential transformer."""

-    differential_attention: Optional[bool] = None
+    diff_attention: Optional[bool] = None
@@ -11,7 +11,7 @@
     LlamaSdpaAttention,
 )

-from .differential_attention import (
+from .diff_attn import (
     LlamaDifferentialAttention,
     LlamaDifferentialFlashAttention2,
     LlamaDifferentialSdpaAttention,
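For context on what the renamed `LlamaDifferential*` classes compute: differential attention (Ye et al., "Differential Transformer", 2024) splits queries and keys in two, forms two softmax attention maps, and subtracts one from the other scaled by a learned lambda, so that attention noise common to both maps cancels. A rough sketch of the core operation, independent of axolotl's actual classes:

```python
# Illustrative core of differential attention, not axolotl's implementation.
import torch.nn.functional as F


def differential_attention(q1, k1, q2, k2, v, lambda_: float):
    """q*, k*: (batch, heads, seq, head_dim); v: (batch, heads, seq, value_dim)."""
    scale = q1.size(-1) ** -0.5
    attn1 = F.softmax(q1 @ k1.transpose(-2, -1) * scale, dim=-1)
    attn2 = F.softmax(q2 @ k2.transpose(-2, -1) * scale, dim=-1)
    # Subtracting the second map cancels noise shared by both maps.
    return (attn1 - lambda_ * attn2) @ v
```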
(Remaining changed files not shown.)