Download Hugging Face models into Hugging Face cache #1285

Open · wants to merge 4 commits into main
2 changes: 1 addition & 1 deletion install/requirements.txt
@@ -1,7 +1,7 @@
# Requires python >=3.10

# Hugging Face download
-huggingface_hub
+huggingface_hub[hf_transfer]

# GGUF import
gguf
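Note on the extra: hf_transfer is only used by huggingface_hub when the HF_HUB_ENABLE_HF_TRANSFER environment variable is set, so installing huggingface_hub[hf_transfer] by itself does not change download behavior. A minimal sketch of how a caller opts in (the repo id is purely illustrative and not something this PR touches):

# Sketch only: enabling the accelerated hf_transfer backend for downloads.
# The environment variable must be set before huggingface_hub is imported.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

# Illustrative repo id; torchchat resolves the real one from its model config.
snapshot_dir = snapshot_download("meta-llama/Meta-Llama-3-8B-Instruct")
print(snapshot_dir)  # path to the snapshot inside the Hugging Face cache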
10 changes: 5 additions & 5 deletions torchchat/cli/builder.py
@@ -30,6 +30,7 @@

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

+from torchchat.cli.download import get_model_dir
from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
device_sync,
@@ -73,7 +74,7 @@ def __post_init__(self):
or (self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError(
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
f"{self.checkpoint_path} is not a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
)

if self.dso_path and self.pte_path:
@@ -109,10 +110,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
model_config = resolve_model_config(args.model)

checkpoint_path = (
-Path(args.model_directory)
-/ model_config.name
+get_model_dir(model_config, args.model_directory)
/ model_config.checkpoint_file
)
print(f"Using checkpoint path: {checkpoint_path}")
# The transformers config is keyed on the last section
# of the name/path.
params_table = (
@@ -264,8 +265,7 @@ def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
elif args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)
tokenizer_path = (
-Path(args.model_directory)
-/ model_config.name
+get_model_dir(model_config, args.model_directory)
/ model_config.tokenizer_file
)
elif args.checkpoint_path:
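builder.py now asks get_model_dir (added in torchchat/cli/download.py, not shown in this diff) where a model's artifacts live instead of assuming <model-directory>/<name>. The helper's body is not part of this diff; the following is only a rough sketch of the behavior the call sites imply, assuming models distributed as Hugging Face snapshots resolve to the Hugging Face cache while everything else keeps the old layout (the field names on model_config are assumptions):

# Illustrative approximation of get_model_dir; the real implementation lives
# in torchchat/cli/download.py and may differ in detail.
from pathlib import Path

from huggingface_hub import snapshot_download


def get_model_dir(model_config, models_dir: Path) -> Path:
    # Assumed fields: model_config.distribution_channel / distribution_path.
    if getattr(model_config, "distribution_channel", None) == "HuggingFaceSnapshot":
        # Resolve to the already-downloaded snapshot inside the Hugging Face cache.
        return Path(
            snapshot_download(model_config.distribution_path, local_files_only=True)
        )
    # Non-HF models keep the old layout: <model-directory>/<model name>.
    return Path(models_dir) / model_config.name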
2 changes: 1 addition & 1 deletion torchchat/cli/cli.py
@@ -244,7 +244,7 @@ def _add_jit_downloading_args(parser) -> None:
"--model-directory",
type=Path,
default=default_model_dir,
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}. This is overriden by the huggingface cache directory if the model is downloaded from HuggingFace.",
)


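For context on the flag's new meaning: with this change, a model fetched from Hugging Face (e.g. python3 torchchat.py download llama3.1, command form taken from the torchchat README) is served out of the Hugging Face cache, and --model-directory only governs models obtained through other channels.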
74 changes: 18 additions & 56 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -3,7 +3,6 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import glob
import json
import os
import re
@@ -42,12 +41,7 @@ def convert_hf_checkpoint(
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
-model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-if len(model_map_json_matches):
-model_map_json = model_map_json_matches[0]
-else:
-model_map_json = model_dir / "pytorch_model.bin.index.json"
+model_map_json = model_dir / "pytorch_model.bin.index.json"

# If there is no weight mapping, check for a consolidated model and
# tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -62,9 +56,10 @@ def convert_hf_checkpoint(
str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
)
del loaded_result # No longer needed
print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
os.rename(consolidated_pth, model_dir / "model.pth")
os.rename(tokenizer_pth, model_dir / "tokenizer.model")
print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.")
consolidated_pth = os.path.realpath(consolidated_pth)
os.symlink(consolidated_pth, model_dir / "model.pth")
os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
print("Done.")
return
else:
@@ -81,17 +76,10 @@
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
"model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
"model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
"model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
"model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
"model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
@@ -100,43 +88,19 @@ def convert_hf_checkpoint(
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_heads):
+dim = config.dim
return (
-w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
+w.view(n_heads, 2, config.head_dim // 2, dim)
.transpose(1, 2)
-.reshape(w.shape)
+.reshape(config.head_dim * n_heads, dim)
)

merged_result = {}
for file in sorted(bin_files):

-# The state_dict can be loaded from either a torch zip file or
-# safetensors. We take our best guess from the name and try all
-# possibilities
-load_pt_mmap = lambda: torch.load(
+state_dict = torch.load(
str(file), map_location="cpu", mmap=True, weights_only=True
)
-load_pt_no_mmap = lambda: torch.load(
-str(file), map_location="cpu", mmap=False, weights_only=True
-)
-def load_safetensors():
-import safetensors.torch
-with open(file, "rb") as handle:
-return safetensors.torch.load(handle.read())
-if "safetensors" in str(file):
-loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
-else:
-loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
-
-state_dict = None
-for loader in loaders:
-try:
-state_dict = loader()
-break
-except Exception:
-continue
-assert state_dict is not None, f"Unable to load tensors from {file}"
merged_result.update(state_dict)

final_result = {}
for key, value in merged_result.items():
if "layers" in key:
@@ -152,18 +116,16 @@ def load_safetensors():
final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq.weight" in key or "wq.bias" in key:
wk_key = key.replace("wq", "wk")
wv_key = key.replace("wq", "wv")
if "wq" in key:
q = final_result[key]
k = final_result[wk_key]
v = final_result[wv_key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_heads)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[wk_key]
del final_result[wv_key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
torch.save(final_result, model_dir / "model.pth")
print("Done.")
@@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune(
consolidated_pth = model_dir / "original" / "consolidated.pth"
tokenizer_pth = model_dir / "original" / "tokenizer.model"
if consolidated_pth.is_file() and tokenizer_pth.is_file():
print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
os.rename(consolidated_pth, model_dir / "model.pth")
print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
os.rename(tokenizer_pth, model_dir / "tokenizer.model")
print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.")
os.symlink(consolidated_pth, model_dir / "model.pth")
print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.")
os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
print("Done.")
else:
raise RuntimeError(f"Could not find {consolidated_pth}")
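Why symlink plus realpath instead of rename: files inside a Hugging Face snapshot are normally themselves symlinks into the cache's blobs/ store, so renaming them would pull the blob out of the cache, while resolving with os.path.realpath and then symlinking leaves the cache untouched and copies no weights. A small self-contained illustration of the pattern (all paths fabricated in a temp directory, not a real cache):

# Stand-alone illustration of the symlink-instead-of-rename pattern used above;
# the directory layout is fabricated in a temp dir, not a real Hugging Face cache.
import os
import tempfile
from pathlib import Path

root = Path(tempfile.mkdtemp())
model_dir = root / "snapshots" / "abc123"      # stand-in for an HF snapshot dir
(model_dir / "original").mkdir(parents=True)
blob = root / "blobs" / "sha256-deadbeef"      # stand-in for the cached blob
blob.parent.mkdir()
blob.write_bytes(b"fake checkpoint")

# In a real cache the snapshot file is itself a symlink into blobs/.
consolidated_pth = model_dir / "original" / "consolidated.pth"
os.symlink(blob, consolidated_pth)

# The PR's pattern: resolve to the real blob, then link instead of renaming,
# so the cache entry stays where it is and nothing is duplicated on disk.
target = os.path.realpath(consolidated_pth)
os.symlink(target, model_dir / "model.pth")

print((model_dir / "model.pth").read_bytes())  # b'fake checkpoint'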