
Add script to write Llama's HF-formatted config.json for vLLM #936

Merged (7 commits) on Jan 7, 2025
8 changes: 5 additions & 3 deletions src/fairseq2/models/llama/archs.py
@@ -7,7 +7,7 @@
from __future__ import annotations

from fairseq2.data import VocabularyInfo
from fairseq2.models.llama.factory import LLaMAConfig, llama_arch
from fairseq2.models.llama.factory import LLaMAConfig, RopeScaling, llama_arch


@llama_arch("7b")
@@ -121,7 +121,7 @@ def _llama3_1_8b() -> LLaMAConfig:
config = _llama3_8b()

config.max_seq_len = 131_072
config.use_scaled_rope = True
config.rope_scaling = RopeScaling()

return config

@@ -131,7 +131,7 @@ def _llama3_1_70b() -> LLaMAConfig:
config = _llama3_70b()

config.max_seq_len = 131_072
config.use_scaled_rope = True
config.rope_scaling = RopeScaling()

return config

@@ -146,6 +146,7 @@ def _llama3_2_3b() -> LLaMAConfig:
config.num_attn_heads = 24
config.num_key_value_heads = 8
config.num_layers = 28
config.rope_scaling = RopeScaling(factor=32.0)

return config

@@ -160,5 +161,6 @@ def _llama3_2_1b() -> LLaMAConfig:
config.num_attn_heads = 32
config.num_key_value_heads = 8
config.num_layers = 16
config.rope_scaling = RopeScaling(factor=32.0)

return config
51 changes: 38 additions & 13 deletions src/fairseq2/models/llama/factory.py
@@ -6,9 +6,10 @@

from __future__ import annotations

import functools
import math
from dataclasses import dataclass, field
from typing import Final
from typing import Final, final

import torch
from torch import Tensor
@@ -85,13 +86,35 @@ class LLaMAConfig:
rope_theta: float = 10_000.0
"""The coefficient of the long-term decay of the Rotary position encoder."""

use_scaled_rope: bool = False
"""If ``True``, scales Rotary encoding frequencies to LLaMA 3.1 context length."""
rope_scaling: RopeScaling | None = None
"""If specified, provides scaling parameters for RoPE frequencies,
aiming to increase the context length."""

dropout_p: float = 0.1
"""The dropout probability on outputs of Transformer layers."""


@final
@dataclass
class RopeScaling:
"""Holds the configuration for RoPE (Rotary Position Embedding)
scaling in Llama 3 models.
"""

factor: float = 8.0
"""Ratio between the intended max context length and the model’s
original max context length."""

low_freq_factor: float = 1.0
"""Factor used to define low frequencies."""

high_freq_factor: float = 4.0
"""Factor used to define high frequencies."""

original_context_length: int = 8192
"""Original context length. Defaults to LLaMA 3's context length."""


llama_archs = ConfigRegistry[LLaMAConfig]()

llama_arch = llama_archs.decorator
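
With the defaults above, the two wavelength cut-offs that the frequency-scaling code computes further down work out as follows (illustrative arithmetic only; the band behavior described in the comments follows the publicly documented Llama 3.1 scheme and is stated here as an assumption):

# Rotary components with wavelengths longer than this cut-off are divided by `factor`.
l_freq_wavelen = 8192 / 1.0   # original_context_length / low_freq_factor = 8192.0

# Components with wavelengths shorter than this cut-off are left unscaled;
# the band in between is interpolated between the two regimes.
h_freq_wavelen = 8192 / 4.0   # original_context_length / high_freq_factor = 2048.0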
@@ -217,8 +240,11 @@ def build_attention(
sdpa = create_default_sdpa(attn_dropout_p=self._config.dropout_p)

if self._pos_encoder is None:
if self._config.use_scaled_rope:
freqs_init_fn = self._init_scaled_freqs
if self._config.rope_scaling is not None:
freqs_init_fn = functools.partial(
self._init_scaled_freqs,
rope_scaling=self._config.rope_scaling,
)
else:
freqs_init_fn = None

@@ -265,8 +291,14 @@ def build_layer_norm(
return RMSNorm(model_dim, bias=False, device=device, dtype=dtype)

@staticmethod
def _init_scaled_freqs(pos_encoder: RotaryEncoder) -> Tensor:
def _init_scaled_freqs(
pos_encoder: RotaryEncoder, rope_scaling: RopeScaling
) -> Tensor:
device = pos_encoder.freqs.device
scale_factor = rope_scaling.factor
l_freq_factor = rope_scaling.low_freq_factor
h_freq_factor = rope_scaling.high_freq_factor
old_context_len = rope_scaling.original_context_length

# (E / 2)
indices = torch.arange(
@@ -278,13 +310,6 @@ def _init_scaled_freqs(pos_encoder: RotaryEncoder) -> Tensor:
if device.type == "meta":
return freqs # type: ignore[no-any-return]

old_context_len = 8192 # The context length of LLaMA 3.

scale_factor = 8.0

l_freq_factor = 1
h_freq_factor = 5

l_freq_wavelen = old_context_len / l_freq_factor
h_freq_wavelen = old_context_len / h_freq_factor

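The diff truncates the remainder of `_init_scaled_freqs` here. For context, the wavelength cut-offs computed above feed a frequency remapping of roughly the following shape. This is a standalone sketch of the publicly documented Llama 3.1 scheme, written as an assumption rather than copied from the hidden lines; `scale_freqs` is an illustrative name, not part of the PR:

import math

import torch
from torch import Tensor


def scale_freqs(
    freqs: Tensor,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    original_context_length: int = 8192,
) -> Tensor:
    # Sketch of the Llama 3.1-style remapping; not copied from the collapsed diff lines.
    low_freq_wavelen = original_context_length / low_freq_factor
    high_freq_wavelen = original_context_length / high_freq_factor

    wavelens = 2 * math.pi / freqs

    # Long-wavelength (low-frequency) components are slowed down by `factor`;
    # short-wavelength components are kept as-is.
    scaled = torch.where(wavelens > low_freq_wavelen, freqs / factor, freqs)

    # The band between the two cut-offs is interpolated smoothly between the
    # scaled and unscaled regimes.
    smooth = (original_context_length / wavelens - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    mid_band = (wavelens <= low_freq_wavelen) & (wavelens >= high_freq_wavelen)

    return torch.where(
        mid_band, (1.0 - smooth) * freqs / factor + smooth * freqs, scaled
    )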
62 changes: 62 additions & 0 deletions src/fairseq2/models/llama/integ.py
@@ -8,9 +8,24 @@

from typing import Any

from fairseq2.models.llama.factory import LLaMAConfig
from fairseq2.models.utils.checkpoint import convert_model_state_dict


def get_ffn_dim_multipliers(architecture: str) -> float:
ffn_dim_multipliers = {
"llama2_70b": 1.3,
"llama3_8b": 1.3,
"llama3_70b": 1.3,
"llama3_1_8b": 1.3,
"llama3_1_70b": 1.3,
"llama3_1_405b": 1.2,
"llama3_2_1b": 1.5,
}

return ffn_dim_multipliers.get(architecture, 1.0)


def convert_to_reference_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]:
"""Convert a fairseq2 LLaMA checkpoint to the reference format."""
try:
@@ -38,3 +53,50 @@ def convert_to_reference_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any
}

return convert_model_state_dict(state_dict, key_map)


def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]:
"""Convert Llama's config to a dict mirroring Huggingface's format"""

def compute_intermediate_size(
n: int, ffn_dim_multiplier: float = 1, multiple_of: int = 256
) -> int:
"""From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171"""
return multiple_of * (
(int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of
)

if config.rope_scaling is not None:
rope_scaling = {
"factor": config.rope_scaling.factor,
"low_freq_factor": config.rope_scaling.low_freq_factor,
"high_freq_factor": config.rope_scaling.high_freq_factor,
"original_max_position_embeddings": config.rope_scaling.original_context_length,
"rope_type": "llama3",
}
else:
rope_scaling = None

# we only specify the parameters made explicit in the Huggingface converter
# https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384
return {
"architectures": ["Fairseq2LlamaForCausalLM"],
"bos_token_id": config.vocab_info.bos_idx,
"eos_token_id": config.vocab_info.eos_idx,
"hidden_size": config.model_dim,
"intermediate_size": compute_intermediate_size(
config.model_dim,
get_ffn_dim_multipliers(arch),
config.ffn_inner_dim_to_multiple,
),
"max_position_embeddings": config.max_seq_len,
"model_type": "llama",
"num_attention_heads": config.num_attn_heads,
"num_hidden_layers": config.num_layers,
"num_key_value_heads": config.num_key_value_heads,
"rms_norm_eps": 1e-5,
"rope_scaling": rope_scaling,
"rope_theta": config.rope_theta,
"tie_word_embeddings": False,
"vocab_size": config.vocab_info.size,
}
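
As a quick sanity check of `compute_intermediate_size`, the arithmetic for a Llama 3 8B-sized model works out as follows. The concrete inputs 4096, 1.3, and 1024 are illustrative assumptions rather than values read from this diff, and the helper is restated so the check is runnable on its own:

def compute_intermediate_size(
    n: int, ffn_dim_multiplier: float = 1, multiple_of: int = 256
) -> int:
    # Same formula as the nested helper above.
    return multiple_of * (
        (int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of
    )


# int(8 * 4096 / 3) = 10922; int(1.3 * 10922) = 14198; rounding 14198 up to a
# multiple of 1024 gives 14 * 1024 = 14336, which matches the intermediate_size
# commonly published for Llama 3 8B. (Inputs are assumed for illustration.)
assert compute_intermediate_size(4096, 1.3, 1024) == 14336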
7 changes: 7 additions & 0 deletions src/fairseq2/recipes/llama/__init__.py
@@ -8,6 +8,7 @@

from fairseq2.recipes.cli import Cli
from fairseq2.recipes.llama.convert_checkpoint import ConvertCheckpointCommandHandler
from fairseq2.recipes.llama.write_hf_config import WriteHfConfigCommandHandler


def _setup_llama_cli(cli: Cli) -> None:
@@ -18,3 +19,9 @@ def _setup_llama_cli(cli: Cli) -> None:
handler=ConvertCheckpointCommandHandler(),
help="convert fairseq2 LLaMA checkpoints to reference checkpoints",
)

group.add_command(
name="write_hf_config",
handler=WriteHfConfigCommandHandler(),
help="write fairseq2 LLaMA config in Huggingface config format",
)

Contributor: I think this is fine, but would be great if we could also have a way to automatically dump this config.json in LLM finetuning recipes.

Contributor Author: Agreed!
20 changes: 7 additions & 13 deletions src/fairseq2/recipes/llama/convert_checkpoint.py
@@ -20,7 +20,10 @@
from fairseq2.assets import default_asset_store
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_reference_checkpoint
from fairseq2.models.llama.integ import (
convert_to_reference_checkpoint,
get_ffn_dim_multipliers,
)
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.console import get_error_console
from fairseq2.setup import setup_fairseq2
@@ -160,19 +163,10 @@ def run(self, args: Namespace) -> int:
if model_config.num_attn_heads != model_config.num_key_value_heads:
params["model"]["n_kv_heads"] = model_config.num_key_value_heads

# we only specify archs where multiplier != 1.0
ffn_dim_multipliers = {
"llama2_70b": 1.3,
"llama3_8b": 1.3,
"llama3_70b": 1.3,
"llama3_1_8b": 1.3,
"llama3_1_70b": 1.3,
"llama3_1_405b": 1.2,
"llama3_2_1b": 1.5,
}
ffn_dim_multiplier = get_ffn_dim_multipliers(arch)

if arch in ffn_dim_multipliers:
params["model"]["ffn_dim_multiplier"] = ffn_dim_multipliers[arch]
if ffn_dim_multiplier != 1.0:
params["model"]["ffn_dim_multiplier"] = ffn_dim_multiplier

try:
with args.output_dir.joinpath("params.json").open("w") as fp:
82 changes: 82 additions & 0 deletions src/fairseq2/recipes/llama/write_hf_config.py
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import final

from typing_extensions import override

from fairseq2.assets import default_asset_store
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_huggingface_config
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.setup import setup_fairseq2

log = get_log_writer(__name__)


@final
class WriteHfConfigCommandHandler(CliCommandHandler):
"""Writes fairseq2 LLaMA config files in Huggingface format."""

@override
def init_parser(self, parser: ArgumentParser) -> None:
parser.add_argument(
"--model",
metavar="ARCH_NAME",
help="model name to fetch architecture to generate config.json",
)

parser.add_argument(
"output_dir",
type=Path,
help="output directory to store reference checkpoint",
)

@override
def run(self, args: Namespace) -> int:
setup_fairseq2()

arch = (
default_asset_store.retrieve_card(args.model).field("model_arch").as_(str)
)

if arch:
model_config = load_llama_config(args.model)
else:
model_config = None

if model_config is None:
log.error("Config could not be retrieved for model {}", args.model)

sys.exit(1)

args.output_dir.mkdir(parents=True, exist_ok=True)

# Convert and write the config
log.info("Writing config...")

config = convert_to_huggingface_config(arch, model_config)

json_file = args.output_dir.joinpath("config.json")

try:
with json_file.open("w") as fp:
json.dump(config, fp, indent=2, sort_keys=True)
except OSError as ex:
raise RuntimeError(
f"The file {json_file} cannot be saved. See the nested exception for details."
) from ex

log.info("Config converted and saved in {}", json_file)

return 0
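
Assuming fairseq2's usual console entry point, the new subcommand would presumably be invoked along the lines of `fairseq2 llama write_hf_config --model <model_name> <output_dir>`, producing a `config.json` in `<output_dir>` that vLLM and other Huggingface-compatible tooling can read alongside a converted checkpoint.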