Add script to write HF-format config.json for Llama
MartinGleize committed Dec 23, 2024
1 parent 9a89641 commit bce9212
Showing 6 changed files with 191 additions and 14 deletions.
36 changes: 36 additions & 0 deletions src/fairseq2/checkpoint/manager.py
@@ -10,6 +10,7 @@
from abc import ABC, abstractmethod
from collections.abc import Iterator, Mapping, Set
from contextlib import AbstractContextManager, nullcontext
import json
from pathlib import Path
from shutil import rmtree
from typing import final
@@ -94,7 +95,21 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None:
:param metadata:
The metadata to save. Must be pickleable.
"""

@abstractmethod
def save_json_dict(
self,
output_name: str,
json_dict: Mapping[str, object],
) -> None:
"""Save a collection of key-values in JSON format, associated with the checkpoint.
:param output_name:
The name of the output json artifact.
:param json_dict:
The key-values to save. Must be json.dumps-able.
"""

@abstractmethod
def save_score(self, score: float | None) -> None:
"""Save the score of the checkpoint."""
@@ -444,6 +459,27 @@ def save_metadata(self, metadata: Mapping[str, object]) -> None:
) from ex

self._root_gang.barrier()

@override
def save_json_dict(
self,
output_name: str,
json_dict: Mapping[str, object],
) -> None:
to_write = json.dumps(json_dict, indent=2, sort_keys=True) + "\n"

if self._root_gang.rank == 0:
json_file = self._checkpoint_dir.joinpath(f"{output_name}")

try:
with json_file.open("w") as fp:
fp.write(to_write)
except OSError as ex:
raise CheckpointError(
f"The JSON file named {output_name} cannot be saved at training step {step_nr}. See the nested exception for details."

Check failure on line 479 in src/fairseq2/checkpoint/manager.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

undefined name 'step_nr'

Check failure on line 479 in src/fairseq2/checkpoint/manager.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

Name "step_nr" is not defined
) from ex

self._root_gang.barrier()

@override
def save_score(self, score: float | None) -> None:
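For context, a minimal sketch of how a training recipe might call the new method; the checkpoint_manager variable and the example dict are assumptions for illustration, not part of this diff:

# Hypothetical usage of save_json_dict; checkpoint_manager is assumed to be a
# checkpoint manager instance already constructed by the training recipe.
hf_config = {"model_type": "llama", "hidden_size": 4096}
checkpoint_manager.save_json_dict("config.json", hf_config)

Only rank 0 performs the write, and the trailing barrier keeps the other ranks in step.
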
46 changes: 45 additions & 1 deletion src/fairseq2/models/llama/factory.py
@@ -8,14 +8,15 @@

import math
from dataclasses import dataclass, field
from typing import Final
from typing import Any, Final

import torch
from torch import Tensor

from fairseq2.config_registry import ConfigRegistry
from fairseq2.data import VocabularyInfo
from fairseq2.models.factory import model_factories
from fairseq2.models.llama.integ import get_ffn_dim_multipliers
from fairseq2.models.transformer import (
TransformerDecoderModel,
TransformerEmbeddingFrontend,
@@ -324,3 +325,46 @@ def get_llama_lora_config() -> LoRAConfig:
dropout_p=0.05,
keys=[".*decoder.layers.*.self_attn.*(q_proj|v_proj)$"],
)


def convert_to_huggingface_config(arch: str, config: LLaMAConfig) -> dict[str, Any]:

def compute_intermediate_size(n: int, ffn_dim_multiplier: float = 1.0, multiple_of: int = 256) -> int:
"""From: https://github.com/huggingface/transformers/blob/82fcac0a7e40dc6cc5e3121d714b9b16775293ad/src/transformers/models/llama/convert_llama_weights_to_hf.py#L171"""
return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)

if config.use_scaled_rope:
rope_scaling = {
"factor": 32.0 if "3_2" in arch else 8.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_max_position_embeddings": 8192,
"rope_type": "llama3",
}
else:
# mgleize: I'm not sure what the JSON serialization behavior is if rope_scaling is ever None
rope_scaling = None

# we only specify the parameters made explicit in the Huggingface converter
# https://github.com/huggingface/transformers/blob/93aafdc620d39b9ec714ffecf015a085ea221282/src/transformers/models/llama/convert_llama_weights_to_hf.py#L384
return {
"architectures": ["Fairseq2LlamaForCausalLM"],
"bos_token_id": config.vocab_info.bos_idx,
"eos_token_id": config.vocab_info.eos_idx,
"hidden_size": config.model_dim,
"intermediate_size": compute_intermediate_size(
config.model_dim,
get_ffn_dim_multipliers(arch),
config.ffn_inner_dim_to_multiple,
),
"max_position_embeddings": config.max_seq_len,
"model_type": "llama",
"num_attention_heads": config.num_attn_heads,
"num_hidden_layers": config.num_layers,
"num_key_value_heads": config.num_key_value_heads,
"rms_norm_eps": 1e-5,
"rope_scaling": rope_scaling,
"rope_theta": config.rope_theta,
"tie_word_embeddings": "3_2" in arch,
"vocab_size": config.vocab_info.size,
}
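As a quick sanity check of the intermediate-size formula above, plugging in values assumed from the reference LLaMA 3 8B configuration (model_dim 4096, multiplier 1.3, multiple_of 1024) reproduces the intermediate_size that Huggingface reports for that model:

# Hypothetical check of the helper above; the 4096 / 1.3 / 1024 inputs are
# assumed from the reference LLaMA 3 8B config, not taken from this diff.
def compute_intermediate_size(n: int, ffn_dim_multiplier: float = 1.0, multiple_of: int = 256) -> int:
    return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of)

print(compute_intermediate_size(4096, 1.3, 1024))  # 14336
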
15 changes: 15 additions & 0 deletions src/fairseq2/models/llama/integ.py
@@ -11,6 +11,21 @@
from fairseq2.models.utils.checkpoint import convert_model_state_dict


def get_ffn_dim_multipliers(architecture: str) -> float:
# we only specify archs where multiplier != 1.0
ffn_dim_multipliers = {
"llama2_70b": 1.3,
"llama3_8b": 1.3,
"llama3_70b": 1.3,
"llama3_1_8b": 1.3,
"llama3_1_70b": 1.3,
"llama3_1_405b": 1.2,
"llama3_2_1b": 1.5,
}

return ffn_dim_multipliers.get(architecture, 1.0)


def convert_to_reference_checkpoint(checkpoint: dict[str, Any]) -> dict[str, Any]:
"""Convert a fairseq2 LLaMA checkpoint to the reference format."""
try:
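A short illustrative check of the lookup above; the architecture strings are the fairseq2 arch names from the table, and anything not listed falls back to 1.0:

# Hypothetical quick check of get_ffn_dim_multipliers; arch names are
# fairseq2 architecture strings, the second one used only to show the default.
print(get_ffn_dim_multipliers("llama3_1_405b"))  # 1.2
print(get_ffn_dim_multipliers("llama2_7b"))      # 1.0 (default for archs not in the table)
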
7 changes: 7 additions & 0 deletions src/fairseq2/recipes/llama/__init__.py
@@ -8,6 +8,7 @@

from fairseq2.recipes.cli import Cli
from fairseq2.recipes.llama.convert_checkpoint import ConvertCheckpointCommandHandler
from fairseq2.recipes.llama.write_hf_config import WriteHfConfigCommandHandler


def _setup_llama_cli(cli: Cli) -> None:
@@ -18,3 +19,9 @@ def _setup_llama_cli(cli: Cli) -> None:
handler=ConvertCheckpointCommandHandler(),
help="convert fairseq2 LLaMA checkpoints to reference checkpoints",
)

group.add_command(
name="write_hf_config",
handler=WriteHfConfigCommandHandler(),
help="write fairseq2 LLaMA config in Huggingface config format",
)
17 changes: 4 additions & 13 deletions src/fairseq2/recipes/llama/convert_checkpoint.py
@@ -20,7 +20,7 @@
from fairseq2.assets import default_asset_store
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.integ import convert_to_reference_checkpoint
from fairseq2.models.llama.integ import convert_to_reference_checkpoint, get_ffn_dim_multipliers
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.console import get_error_console
from fairseq2.setup import setup_fairseq2
@@ -160,19 +160,10 @@ def run(self, args: Namespace) -> int:
if model_config.num_attn_heads != model_config.num_key_value_heads:
params["model"]["n_kv_heads"] = model_config.num_key_value_heads

# we only specify archs where multiplier != 1.0
ffn_dim_multipliers = {
"llama2_70b": 1.3,
"llama3_8b": 1.3,
"llama3_70b": 1.3,
"llama3_1_8b": 1.3,
"llama3_1_70b": 1.3,
"llama3_1_405b": 1.2,
"llama3_2_1b": 1.5,
}
ffn_dim_multiplier = get_ffn_dim_multipliers(arch)

if arch in ffn_dim_multipliers:
params["model"]["ffn_dim_multiplier"] = ffn_dim_multipliers[arch]
if ffn_dim_multiplier != 1.0:
params["model"]["ffn_dim_multiplier"] = ffn_dim_multiplier

try:
with args.output_dir.joinpath("params.json").open("w") as fp:
84 changes: 84 additions & 0 deletions src/fairseq2/recipes/llama/write_hf_config.py
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
import sys
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import final

from typing_extensions import override

from fairseq2.assets import default_asset_store
from fairseq2.logging import get_log_writer
from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.factory import convert_to_huggingface_config
from fairseq2.recipes.cli import CliCommandHandler
from fairseq2.recipes.console import get_error_console
from fairseq2.setup import setup_fairseq2

log = get_log_writer(__name__)


@final
class WriteHfConfigCommandHandler(CliCommandHandler):
"""Writes fairseq2 LLaMA config files in Huggingface format."""

@override
def init_parser(self, parser: ArgumentParser) -> None:
parser.add_argument(
"--model",
metavar="ARCH_NAME",
help="model name to fetch architecture to generate config.json",
)

parser.add_argument(
"output_dir",
type=Path,
help="output directory to store reference checkpoint",
)

@override
def run(self, args: Namespace) -> int:
setup_fairseq2()

arch = (
default_asset_store.retrieve_card(args.model).field("model_arch").as_(str)
)

if arch:
model_config = load_llama_config(args.model)
else:
model_config = None

if model_config is None:
log.error("Model config could not be retrieved for model {}", args.model)

sys.exit(1)

args.output_dir.mkdir(parents=True, exist_ok=True)

# Convert and write the config
with get_error_console().status("[bold green]Writing config..."):
config = convert_to_huggingface_config(arch, model_config)
to_write = json.dumps(config, indent=2, sort_keys=True) + "\n"

json_file = args.output_dir.joinpath("config.json")

try:
with json_file.open("w") as fp:
fp.write(to_write)
except OSError as ex:
raise RuntimeError(
f"The file config.json cannot be saved. See the nested exception for details."

Check failure on line 79 in src/fairseq2/recipes/llama/write_hf_config.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

f-string is missing placeholders
) from ex

log.info("Config converted and saved in {}", json_file)

return 0

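The same conversion can also be driven directly from Python. A minimal sketch, where the "llama3_1_8b" card name and the output path are placeholder assumptions and the card's model_arch is assumed to coincide with the model name:

# Minimal sketch mirroring WriteHfConfigCommandHandler.run above; the model
# name and output path are placeholders, not values taken from this diff.
import json
from pathlib import Path

from fairseq2.models.llama import load_llama_config
from fairseq2.models.llama.factory import convert_to_huggingface_config
from fairseq2.setup import setup_fairseq2

setup_fairseq2()

model_config = load_llama_config("llama3_1_8b")
hf_config = convert_to_huggingface_config("llama3_1_8b", model_config)

Path("config.json").write_text(json.dumps(hf_config, indent=2, sort_keys=True) + "\n")
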