From 1eebd233fc30d03e753bbdb08609c3cf80c798a8 Mon Sep 17 00:00:00 2001 From: Harrison Date: Mon, 9 Dec 2024 23:28:46 +0000 Subject: [PATCH] Improve huggingface loading - Allow just `repo_id` --- src/anemoi/inference/checkpoint.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/anemoi/inference/checkpoint.py b/src/anemoi/inference/checkpoint.py index 9410aca..cdf34bf 100644 --- a/src/anemoi/inference/checkpoint.py +++ b/src/anemoi/inference/checkpoint.py @@ -11,6 +11,7 @@ import datetime import logging from collections import defaultdict +from pathlib import Path from functools import cached_property from anemoi.utils.checkpoints import load_metadata @@ -24,11 +25,19 @@ def _download_huggingfacehub(huggingface_config): """Download model from huggingface""" try: - from huggingface_hub import hf_hub_download + from huggingface_hub import hf_hub_download, snapshot_download except ImportError as e: raise ImportError("Could not import `huggingface_hub`, please run `pip install huggingface_hub`.") from e - config_path = hf_hub_download(**huggingface_config) + if 'filename' in huggingface_config: + config_path = hf_hub_download(**huggingface_config) + else: + repo_path = Path(snapshot_download(**huggingface_config)) + ckpt_files = list(repo_path.glob('*.ckpt')) + if len(ckpt_files) == 1: + return str(ckpt_files[0]) + else: + ValueError(f"Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`.") return config_path