Skip to content

Commit

Permalink
Delete models from old location for huggingface download
Browse files Browse the repository at this point in the history
  • Loading branch information
vmpuri committed Oct 11, 2024
1 parent 654dbec commit 84602c8
Showing 1 changed file with 76 additions and 28 deletions.
104 changes: 76 additions & 28 deletions torchchat/cli/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from pathlib import Path
from typing import Optional

from torchchat.cli.convert_hf_checkpoint import convert_hf_checkpoint, convert_hf_checkpoint_to_tune
from torchchat.cli.convert_hf_checkpoint import (
convert_hf_checkpoint,
convert_hf_checkpoint_to_tune,
)
from torchchat.model_config.model_config import (
load_model_configs,
ModelConfig,
Expand All @@ -20,20 +23,46 @@

# By default, download Hugging Face models into the Hugging Face hub cache.
# huggingface_hub resolves that directory as follows:
#   1. $HF_HUB_CACHE (or its legacy alias $HUGGINGFACE_HUB_CACHE) names the hub
#      cache directory itself.
#   2. Otherwise the hub cache is the "hub" subdirectory of $HF_HOME, which
#      defaults to ~/.cache/huggingface.
# $HF_HOME is therefore the *parent* of the hub cache, not the same directory,
# so "hub" must be appended whenever only $HF_HOME is provided.
_hf_hub_cache_override = os.environ.get("HF_HUB_CACHE") or os.environ.get(
    "HUGGINGFACE_HUB_CACHE"
)
if _hf_hub_cache_override:
    HUGGINGFACE_HOME_PATH = Path(_hf_hub_cache_override).expanduser()
else:
    HUGGINGFACE_HOME_PATH = (
        Path(os.environ.get("HF_HOME", "~/.cache/huggingface")).expanduser() / "hub"
    )

# Turn on HF_HUB_ENABLE_HF_TRANSFER by default, but never override a value the
# user has set explicitly (including an explicit "0").
if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") is None:
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

def _download_hf_snapshot(
model_config: ModelConfig, hf_token: Optional[str]
):

def _cleanup_hf_models_from_torchchat_dir(models_dir: Path):
    """
    Delete Hugging Face model artifacts left in the old torchchat models directory.

    Previously, all models were stored in the torchchat models directory
    (by default ~/.torchchat/model-cache). Hugging Face models are now stored
    in the Hugging Face cache directory instead, so for every known model
    config distributed as a Hugging Face snapshot, the matching folder under
    `models_dir` is removed if it exists.

    Args:
        models_dir: The torchchat models directory to clean up.
    """
    for config in load_model_configs().values():
        # Only Hugging Face snapshot models have moved to the new location.
        if config.distribution_channel != ModelDistributionChannel.HuggingFaceSnapshot:
            continue

        stale_dir = models_dir / config.name
        if not os.path.exists(stale_dir):
            continue

        print(
            f"Cleaning up old model artifacts in {stale_dir}. New artifacts will be downloaded to {HUGGINGFACE_HOME_PATH}"
        )
        shutil.rmtree(stale_dir)


def _download_hf_snapshot(model_config: ModelConfig, hf_token: Optional[str]):
from huggingface_hub import model_info, snapshot_download
from requests.exceptions import HTTPError

# Download and store the HF model artifacts.
model_dir = get_model_dir(model_config, None)
print(f"Downloading {model_config.name} from Hugging Face to {model_dir}", file=sys.stderr, flush=True)
print(
f"Downloading {model_config.name} from Hugging Face to {model_dir}",
file=sys.stderr,
flush=True,
)
try:
# Fetch the info about the model's repo
model_info = model_info(model_config.distribution_path, token=hf_token)
Expand Down Expand Up @@ -81,14 +110,17 @@ def _download_hf_snapshot(
else:
raise e

# Update the model dir to include the snapshot we just downloaded.
# Update the model dir to include the snapshot we just downloaded.
model_dir = get_model_dir(model_config, None)
print("Model downloaded to", model_dir)

# Convert the Multimodal Llama model to the torchtune format.
if model_config.name in {"meta-llama/Llama-3.2-11B-Vision-Instruct", "meta-llama/Llama-3.2-11B-Vision"}:
if model_config.name in {
"meta-llama/Llama-3.2-11B-Vision-Instruct",
"meta-llama/Llama-3.2-11B-Vision",
}:
print(f"Converting {model_config.name} to torchtune format...", file=sys.stderr)
convert_hf_checkpoint_to_tune( model_dir=model_dir, model_name=model_config.name)
convert_hf_checkpoint_to_tune(model_dir=model_dir, model_name=model_config.name)

else:
# Convert the model to the torchchat format.
Expand All @@ -108,32 +140,44 @@ def _download_direct(
print(f"Downloading {url}...", file=sys.stderr)
urllib.request.urlretrieve(url, str(local_path.absolute()))


def _get_hf_artifact_dir(model_config: ModelConfig) -> Path:
    """
    Returns the directory where the model artifacts are stored.
    This is the root folder with blobs, refs and snapshots.

    The folder name is "models--" followed by the model's distribution path
    with every "/" replaced by "--", located under HUGGINGFACE_HOME_PATH.

    Only valid for models distributed as a Hugging Face snapshot; any other
    distribution channel trips the assertion below.
    """
    assert (
        model_config.distribution_channel
        == ModelDistributionChannel.HuggingFaceSnapshot
    )
    repo_folder = f"models--{model_config.distribution_path.replace('/', '--')}"
    return HUGGINGFACE_HOME_PATH / repo_folder


def get_model_dir(model_config: ModelConfig, models_dir: Optional[Path]) -> Path:
    """
    Returns the directory where the model artifacts are expected to be stored.

    For Hugging Face artifacts, this will be the location of the "main"
    snapshot if it exists, or the expected model directory otherwise.
    For all other distribution channels, we use the models_dir.
    For CLI usage, pass in args.model_directory.

    Args:
        model_config: Config of the model whose directory is requested.
        models_dir: Root directory for models that are not Hugging Face
            snapshots. It is only used for those models, so callers that know
            they are handling a Hugging Face snapshot may pass None.

    Returns:
        The path of the model's artifact directory. The path is not
        guaranteed to exist; callers must check before reading from it.
    """
    if (
        model_config.distribution_channel
        != ModelDistributionChannel.HuggingFaceSnapshot
    ):
        return models_dir / model_config.name

    artifact_dir = _get_hf_artifact_dir(model_config)
    snapshots_dir = artifact_dir / "snapshots"
    main_ref = artifact_dir / "refs" / "main"

    # If either of these paths doesn't exist, the model hasn't been (fully)
    # downloaded yet, so there is no snapshot to resolve. snapshots_dir is
    # nested inside artifact_dir, so this also covers a missing artifact_dir.
    if not os.path.isdir(snapshots_dir) or not os.path.isfile(main_ref):
        return artifact_dir

    # refs/main holds the name of the snapshot folder that "main" points to.
    with open(main_ref, "r") as ref_file:
        snapshot = ref_file.read().strip()
    return snapshots_dir / snapshot

Expand Down Expand Up @@ -164,9 +208,7 @@ def download_and_convert(
os.makedirs(temp_dir, exist_ok=True)

try:
if (
model_config.distribution_channel == ModelDistributionChannel.DirectDownload
):
if model_config.distribution_channel == ModelDistributionChannel.DirectDownload:
_download_direct(model_config, temp_dir)
else:
raise RuntimeError(
Expand All @@ -187,7 +229,7 @@ def download_and_convert(

def is_model_downloaded(model: str, models_dir: Path) -> bool:
    """
    Report whether artifacts for `model` are present on disk.

    Args:
        model: Model name or alias, resolved through resolve_model_config.
        models_dir: Root directory passed on to get_model_dir.

    Returns:
        True if the model directory exists and is not empty, False otherwise.
    """
    model_config = resolve_model_config(model)

    # Check if the model directory exists and is not empty.
    model_dir = get_model_dir(model_config, models_dir)
    # bool() keeps the declared return type: a bare `and` would hand back the
    # directory listing itself instead of True/False.
    return os.path.isdir(model_dir) and bool(os.listdir(model_dir))
Expand Down Expand Up @@ -242,7 +284,10 @@ def remove_main(args) -> None:
if not os.path.isdir(model_dir):
print(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
return
if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot:
if (
model_config.distribution_channel
== ModelDistributionChannel.HuggingFaceSnapshot
):
# For HuggingFace models, we need to remove the entire root directory.
model_dir = _get_hf_artifact_dir(model_config)

Expand All @@ -265,12 +310,15 @@ def where_main(args) -> None:
model_dir = get_model_dir(model_config, args.model_directory)

if not os.path.isdir(model_dir):
raise RuntimeError(f"Model {args.model} has no downloaded artifacts in {model_dir}.")
raise RuntimeError(
f"Model {args.model} has no downloaded artifacts in {model_dir}."
)

print(str(os.path.abspath(model_dir)))
exit(0)


# Subcommand to download model artifacts.
def download_main(args) -> None:
    """
    Entry point for the download subcommand.

    First removes Hugging Face model artifacts from the old torchchat models
    directory, then downloads and converts the requested model.

    Args:
        args: Parsed CLI arguments; `model`, `model_directory` and `hf_token`
            are read from it.
    """
    models_dir = args.model_directory
    _cleanup_hf_models_from_torchchat_dir(models_dir)
    download_and_convert(args.model, models_dir, args.hf_token)

0 comments on commit 84602c8

Please sign in to comment.