diff --git a/torchchat/cli/download.py b/torchchat/cli/download.py index 14c19f943..fee85bdcc 100644 --- a/torchchat/cli/download.py +++ b/torchchat/cli/download.py @@ -25,6 +25,14 @@ if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) is None: os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +# For each model with huggingface distribution path, clean up the old location. +def _delete_old_hf_models(models_dir: Path): + for model_config in load_model_configs().values(): + if model_config.distribution_channel == ModelDistributionChannel.HuggingFaceSnapshot: + if os.path.exists(models_dir / model_config.name): + print(f"Cleaning up old model artifacts in {models_dir / model_config.name}. New artifacts will be downloaded to {HUGGINGFACE_HOME_PATH}") + shutil.rmtree(models_dir / model_config.name) + def _download_hf_snapshot( model_config: ModelConfig, hf_token: Optional[str] ): @@ -273,4 +281,5 @@ def where_main(args) -> None: # Subcommand to download model artifacts. def download_main(args) -> None: + _delete_old_hf_models(args.model_directory) download_and_convert(args.model, args.model_directory, args.hf_token)