Download huggingface models to huggingface cache instead of ~/.torchchat
vmpuri committed Oct 9, 2024
1 parent 6a2a2e8 commit f045cfe
Showing 5 changed files with 89 additions and 84 deletions.
10 changes: 5 additions & 5 deletions torchchat/cli/builder.py
@@ -30,6 +30,7 @@

from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

from torchchat.cli.download import get_model_dir
from torchchat.model_config.model_config import resolve_model_config
from torchchat.utils.build_utils import (
device_sync,
@@ -73,7 +74,7 @@ def __post_init__(self):
or (self.pte_path and Path(self.pte_path).is_file())
):
raise RuntimeError(
"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path"
f"need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path {self.checkpoint_path}"
)

if self.dso_path and self.pte_path:
@@ -109,10 +110,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
model_config = resolve_model_config(args.model)

checkpoint_path = (
Path(args.model_directory)
/ model_config.name
get_model_dir(model_config, args.model_directory)
/ model_config.checkpoint_file
)
print(f"Using checkpoint path: {checkpoint_path}")
# The transformers config is keyed on the last section
# of the name/path.
params_table = (
@@ -264,8 +265,7 @@ def from_args(cls, args: argparse.Namespace) -> "TokenizerArgs":
elif args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)
tokenizer_path = (
Path(args.model_directory)
/ model_config.name
get_model_dir(model_config, args.model_directory)
/ model_config.tokenizer_file
)
elif args.checkpoint_path:
2 changes: 1 addition & 1 deletion torchchat/cli/cli.py
@@ -244,7 +244,7 @@ def _add_jit_downloading_args(parser) -> None:
"--model-directory",
type=Path,
default=default_model_dir,
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}",
help=f"The directory to store downloaded model artifacts. Default: {default_model_dir}. This is overriden by the huggingface cache directory if the model is downloaded from HuggingFace.",
)
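
As a quick illustration of the help text above (the model alias is illustrative and assumes the standard torchchat.py entry point):

  python3 torchchat.py download llama3.1   # Hugging Face models now land in the hub cache, not ~/.torchchat
  python3 torchchat.py where llama3.1      # prints the resolved snapshot directory inside that cache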


74 changes: 18 additions & 56 deletions torchchat/cli/convert_hf_checkpoint.py
@@ -3,7 +3,6 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
import json
import os
import re
@@ -42,12 +41,7 @@ def convert_hf_checkpoint(
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
if len(model_map_json_matches):
model_map_json = model_map_json_matches[0]
else:
model_map_json = model_dir / "pytorch_model.bin.index.json"
model_map_json = model_dir / "pytorch_model.bin.index.json"

# If there is no weight mapping, check for a consolidated model and
# tokenizer we can move. Llama 2 and Mistral have weight mappings, while
@@ -62,9 +56,10 @@ def convert_hf_checkpoint(
str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
)
del loaded_result # No longer needed
print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
os.rename(consolidated_pth, model_dir / "model.pth")
os.rename(tokenizer_pth, model_dir / "tokenizer.model")
print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.")
consolidated_pth = os.path.realpath(consolidated_pth)
os.symlink(consolidated_pth, model_dir / "model.pth")
os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
print("Done.")
return
else:
@@ -81,17 +76,10 @@ def convert_hf_checkpoint(
"model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
"model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
"model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
"model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
"model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
"model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
"model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
"model.layers.{}.self_attn.rotary_emb.inv_freq": None,
"model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
"model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
"model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
"model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
"model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
"model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
"model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
"model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
"model.norm.weight": "norm.weight",
@@ -100,43 +88,19 @@ def convert_hf_checkpoint(
bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}

def permute(w, n_heads):
dim = config.dim
return (
w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
w.view(n_heads, 2, config.head_dim // 2, dim)
.transpose(1, 2)
.reshape(w.shape)
.reshape(config.head_dim * n_heads, dim)
)

merged_result = {}
for file in sorted(bin_files):

# The state_dict can be loaded from either a torch zip file or
# safetensors. We take our best guess from the name and try all
# possibilities
load_pt_mmap = lambda: torch.load(
state_dict = torch.load(
str(file), map_location="cpu", mmap=True, weights_only=True
)
load_pt_no_mmap = lambda: torch.load(
str(file), map_location="cpu", mmap=False, weights_only=True
)
def load_safetensors():
import safetensors.torch
with open(file, "rb") as handle:
return safetensors.torch.load(handle.read())
if "safetensors" in str(file):
loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
else:
loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]

state_dict = None
for loader in loaders:
try:
state_dict = loader()
break
except Exception:
continue
assert state_dict is not None, f"Unable to load tensors from {file}"
merged_result.update(state_dict)

final_result = {}
for key, value in merged_result.items():
if "layers" in key:
@@ -152,18 +116,16 @@ def load_safetensors():
final_result[new_key] = value

for key in tuple(final_result.keys()):
if "wq.weight" in key or "wq.bias" in key:
wk_key = key.replace("wq", "wk")
wv_key = key.replace("wq", "wv")
if "wq" in key:
q = final_result[key]
k = final_result[wk_key]
v = final_result[wv_key]
k = final_result[key.replace("wq", "wk")]
v = final_result[key.replace("wq", "wv")]
q = permute(q, config.n_heads)
k = permute(k, config.n_local_heads)
final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
del final_result[key]
del final_result[wk_key]
del final_result[wv_key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
torch.save(final_result, model_dir / "model.pth")
print("Done.")
@@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune(
consolidated_pth = model_dir / "original" / "consolidated.pth"
tokenizer_pth = model_dir / "original" / "tokenizer.model"
if consolidated_pth.is_file() and tokenizer_pth.is_file():
print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
os.rename(consolidated_pth, model_dir / "model.pth")
print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
os.rename(tokenizer_pth, model_dir / "tokenizer.model")
print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.")
os.symlink(consolidated_pth, model_dir / "model.pth")
print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.")
os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
print("Done.")
else:
raise RuntimeError(f"Could not find {consolidated_pth}")
81 changes: 62 additions & 19 deletions torchchat/cli/download.py
@@ -18,15 +18,19 @@
resolve_model_config,
)

# By default, download models from Hugging Face to the Hugging Face hub cache directory.
# Respect $HF_HOME (or $HUGGINGFACE_HUB_CACHE) if set; otherwise fall back to ~/.cache/huggingface/hub.
HUGGINGFACE_HOME_PATH = Path(os.environ.get("HF_HOME", os.environ.get("HUGGINGFACE_HUB_CACHE", os.path.expanduser("~/.cache/huggingface/hub"))))

def _download_hf_snapshot(
model_config: ModelConfig, artifact_dir: Path, hf_token: Optional[str]
model_config: ModelConfig, hf_token: Optional[str]
):
from huggingface_hub import model_info, snapshot_download
from requests.exceptions import HTTPError

# Download and store the HF model artifacts.
print(f"Downloading {model_config.name} from HuggingFace...", file=sys.stderr)
model_dir = get_model_dir(model_config, None)
print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True)
try:
# Fetch the info about the model's repo
model_info = model_info(model_config.distribution_path, token=hf_token)
@@ -56,8 +60,6 @@ def _download_hf_snapshot(

snapshot_download(
model_config.distribution_path,
local_dir=artifact_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns=ignore_patterns,
)
@@ -76,16 +78,20 @@ def _download_hf_snapshot(
else:
raise e

# Update the model dir to include the snapshot we just downloaded.
model_dir = get_model_dir(model_config, None)
print("Model downloaded to", model_dir)

# Convert the Multimodal Llama model to the torchtune format.
if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
convert_hf_checkpoint_to_tune( model_dir=artifact_dir, model_name=model_config.name)
convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name)

else:
# Convert the model to the torchchat format.
print(f"Converting {model_config.name} to torchchat format...", file=sys.stderr)
convert_hf_checkpoint(
model_dir=artifact_dir, model_name=model_config.name, remove_bin_files=True
model_dir=model_dir, model_name=model_config.name, remove_bin_files=True
)


@@ -99,12 +105,51 @@ def _download_direct(
print(f"Downloading {url}...", file=sys.stderr)
urllib.request.urlretrieve(url, str(local_path.absolute()))

def _get_hf_artifact_dir(model_config: ModelConfig) -> Path:
"""
Returns the directory where the model artifacts are stored.
This is the root folder with blobs, refs and snapshots
"""
assert(model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot)
return HUGGINGFACE_HOME_PATH / f"models--{model_config.distribution_path.replace('/', '--')}"


def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path:
"""
Returns the directory where the model artifacts are stored.
For HuggingFace snapshots, this is the HuggingFace cache directory.
For all other distribution channels, we use the models_dir.
For CLI usage, pass in args.model_directory.
"""
if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
artifact_dir = _get_hf_artifact_dir(model_config)

# If these paths don't exist, the model hasn't been downloaded yet.
if not os.path.isdir(artifact_dir) or not os.path.isdir(artifact_dir / "snapshots"):
return artifact_dir
snapshot = open(artifact_dir / "refs" / "main", "r").read().strip()
return artifact_dir / "snapshots" / snapshot
else:
return models_dir / model_config.name


def download_and_convert(
model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
model_config = resolve_model_config(model)
model_dir = models_dir / model_config.name
model_dir = get_model_dir(model_config, models_dir)

# HuggingFace download
if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, hf_token)
return

# Direct download

# Download into a temporary directory. We'll move to the final
# location once the download and conversion is complete. This
@@ -117,11 +162,6 @@ def download_and_convert(

try:
if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
_download_hf_snapshot(model_config, temp_dir, hf_token)
elif (
model_config.distribution_channel == ModelDistributionChannel.DirectDownload
):
_download_direct(model_config, temp_dir)
@@ -144,9 +184,9 @@

def is_model_downloaded(model: str, models_dir: Path) -> bool:
model_config = resolve_model_config(model)

# Check if the model directory exists and is not empty.
model_dir = models_dir / model_config.name
model_dir = get_model_dir(model_config, models_dir)
return os.path.isdir(model_dir) and os.listdir(model_dir)


@@ -194,13 +234,16 @@ def remove_main(args) -> None:
return

model_config = resolve_model_config(args.model)
model_dir = args.model_directory / model_config.name
model_dir = get_model_dir(model_config, args.model_directory)

if not os.path.isdir(model_dir):
print(f"Model {args.model} has no downloaded artifacts.")
print(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
return
if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
# For HuggingFace models, we need to remove the entire root directory.
model_dir = _get_hf_artifact_dir(model_config)

print(f"Removing downloaded model artifacts for {args.model}...")
print(f"Removing downloaded model artifacts for {args.model} at {model_dir}...")
shutil.rmtree(model_dir)
print("Done.")

@@ -216,10 +259,10 @@ def where_main(args) -> None:
return

model_config = resolve_model_config(args.model)
model_dir = args.model_directory / model_config.name
model_dir = get_model_dir(model_config, args.model_directory)

if not os.path.isdir(model_dir):
raise RuntimeError(f"Model {args.model} has no downloaded artifacts.")
raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.")

print(str(os.path.abspath(model_dir)))
exit(0)
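
For reference, the hub cache layout that _get_hf_artifact_dir and get_model_dir walk looks roughly like the sketch below (the model name and snapshot hash are illustrative). refs/main holds the snapshot hash, get_model_dir returns the matching snapshots/<hash> directory, and remove_main deletes the whole models--* root.

~/.cache/huggingface/hub/
└── models--meta-llama--Meta-Llama-3.1-8B-Instruct/   # _get_hf_artifact_dir
    ├── blobs/                                        # deduplicated file contents
    ├── refs/
    │   └── main                                      # text file containing the snapshot hash
    └── snapshots/
        └── 0e9e39f.../                               # what get_model_dir resolves to
            ├── config.json
            ├── tokenizer.model
            └── model.pth                             # written or symlinked by convert_hf_checkpoint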
6 changes: 3 additions & 3 deletions torchchat/usages/openai_api.py
@@ -23,7 +23,7 @@

from torchtune.models.llama3_2_vision._model_builders import llama3_2_vision_transform

from torchchat.cli.download import is_model_downloaded, load_model_configs
from torchchat.cli.download import is_model_downloaded, load_model_configs, get_model_dir
from torchchat.generate import Generator, GeneratorArgs
from torchchat.model import FlamingoModel

@@ -522,7 +522,7 @@ def retrieve_model_info(args, model_id: str) -> Union[ModelInfo, None]:
"""
if model_config := load_model_configs().get(model_id):
if is_model_downloaded(model_id, args.model_directory):
path = args.model_directory / model_config.name
path = get_model_dir(model_config, args.model_directory)
created = int(os.path.getctime(path))
owned_by = getpwuid(os.stat(path).st_uid).pw_name

@@ -545,7 +545,7 @@ def get_model_info_list(args) -> ModelInfo:
data = []
for model_id, model_config in load_model_configs().items():
if is_model_downloaded(model_id, args.model_directory):
path = args.model_directory / model_config.name
path = get_model_dir(model_config, args.model_directory)
created = int(os.path.getctime(path))
owned_by = getpwuid(os.stat(path).st_uid).pw_name

