Skip to content

Commit

Permalink
Tests huggingface download (#83)
Browse files Browse the repository at this point in the history
* Allow just str in huggingface

* Add tests to huggingface loading

* Add huggingface to test requirements

* Remove 3.9 from test suite
  • Loading branch information
HCookie authored Dec 18, 2024
1 parent 7e84e3b commit 1187d78
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-pull-request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
checks:
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.10", "3.11", "3.12"]
uses: ecmwf-actions/reusable-workflows/.github/workflows/qa-pytest-pyproject.yml@v2
with:
python-version: ${{ matrix.python-version }}
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ dependencies = [
"torch",
]

optional-dependencies.all = [ "anemoi-inference[plugin]", "anemoi-utils[all]>=0.4.9" ]
optional-dependencies.all = [ "anemoi-inference[plugin,huggingface]", "anemoi-utils[all]>=0.4.9" ]

optional-dependencies.dev = [ "anemoi-inference[all,docs,plugin,tests]" ]

Expand All @@ -71,7 +71,7 @@ optional-dependencies.docs = [
optional-dependencies.huggingface = [ "huggingface-hub" ]

optional-dependencies.plugin = [ "ai-models>=0.7", "tqdm" ]
optional-dependencies.tests = [ "anemoi-datasets[all]", "hypothesis", "pytest" ]
optional-dependencies.tests = [ "anemoi-datasets[all]", "anemoi-inference[all]", "hypothesis", "pytest" ]

urls.Documentation = "https://anemoi-inference.readthedocs.io/"
urls.Homepage = "https://github.com/ecmwf/anemoi-inference/"
Expand Down
26 changes: 13 additions & 13 deletions src/anemoi/inference/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
LOG = logging.getLogger(__name__)


def _download_huggingfacehub(huggingface_config):
def _download_huggingfacehub(huggingface_config) -> str:
"""Download model from huggingface"""
try:
from huggingface_hub import hf_hub_download
Expand All @@ -34,17 +34,17 @@ def _download_huggingfacehub(huggingface_config):
huggingface_config = {"repo_id": huggingface_config}

if "filename" in huggingface_config:
config_path = hf_hub_download(**huggingface_config)
return str(hf_hub_download(**huggingface_config))

repo_path = Path(snapshot_download(**huggingface_config))
ckpt_files = list(repo_path.glob("*.ckpt"))

if len(ckpt_files) == 1:
return str(ckpt_files[0])
else:
repo_path = Path(snapshot_download(**huggingface_config))
ckpt_files = list(repo_path.glob("*.ckpt"))
if len(ckpt_files) == 1:
return str(ckpt_files[0])
else:
ValueError(
f"Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`."
)
return config_path
raise ValueError(
f"None or Multiple ckpt files found in repo, {ckpt_files}.\nCannot pick one to load, please specify `filename`."
)


class Checkpoint:
Expand All @@ -58,7 +58,7 @@ def __repr__(self):
return f"Checkpoint({self.path})"

@cached_property
def path(self):
def path(self) -> str:
import json

try:
Expand All @@ -67,7 +67,7 @@ def path(self):
path = self._path

if isinstance(path, (Path, str)):
return path
return str(path)
elif isinstance(path, dict):
if "huggingface" in path:
return _download_huggingfacehub(path["huggingface"])
Expand Down
Empty file added tests/__init__.py
Empty file.
Empty file added tests/checkpoint/__init__.py
Empty file.
78 changes: 78 additions & 0 deletions tests/checkpoint/test_huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


from unittest.mock import patch

import pytest

import anemoi.inference.checkpoint
from anemoi.inference.runner import Runner

from ..metadata.fake_metadata import FakeMetadata


@pytest.fixture(scope="session")
def fake_huggingface_repo(tmp_path_factory):
    """Build a temporary directory that stands in for a downloaded repo.

    A single placeholder ``model.ckpt`` file is written into the directory,
    and the directory path (not the file path) is returned.
    """
    repo_dir = tmp_path_factory.mktemp("repo")
    (repo_dir / "model.ckpt").write_text("TESTING", encoding="utf-8")
    return repo_dir


@pytest.fixture(scope="session")
def fake_huggingface_ckpt(tmp_path_factory):
    """Build a placeholder ``model.ckpt`` file in a temporary directory.

    Unlike the repo-style fixture, the path of the file itself is returned.
    """
    ckpt_file = tmp_path_factory.mktemp("repo") / "model.ckpt"
    ckpt_file.write_text("TESTING", encoding="utf-8")
    return ckpt_file


@patch("huggingface_hub.snapshot_download")
@pytest.mark.parametrize("ckpt", ["organisation/test_repo"])
def test_huggingface_repo_download_str(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo):
    """A plain repo-id string resolves to the ckpt file inside the snapshot.

    ``snapshot_download`` is mocked to return the fake repo directory, and the
    test checks that it was called once with the string passed as ``repo_id``.
    """
    # Swap Checkpoint._metadata for a FakeMetadata instance for this test.
    checkpoint_cls = anemoi.inference.checkpoint.Checkpoint
    monkeypatch.setattr(checkpoint_cls, "_metadata", FakeMetadata())
    huggingface_mock.return_value = fake_huggingface_repo
    expected_path = str(fake_huggingface_repo / "model.ckpt")

    resolved_path = Runner({"huggingface": ckpt}).checkpoint.path
    assert resolved_path == expected_path

    assert huggingface_mock.called
    huggingface_mock.assert_called_once_with(repo_id=ckpt)


@patch("huggingface_hub.snapshot_download")
@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo"}])
def test_huggingface_repo_download_dict(huggingface_mock, monkeypatch, ckpt, fake_huggingface_repo):
    """A dict config without ``filename`` resolves to the ckpt file in the snapshot.

    ``snapshot_download`` is mocked to return the fake repo directory, and the
    test checks that it was called once with the dict expanded as keyword
    arguments.
    """
    # Swap Checkpoint._metadata for a FakeMetadata instance for this test.
    checkpoint_cls = anemoi.inference.checkpoint.Checkpoint
    monkeypatch.setattr(checkpoint_cls, "_metadata", FakeMetadata())
    huggingface_mock.return_value = fake_huggingface_repo
    expected_path = str(fake_huggingface_repo / "model.ckpt")

    resolved_path = Runner({"huggingface": ckpt}).checkpoint.path
    assert resolved_path == expected_path

    assert huggingface_mock.called
    huggingface_mock.assert_called_once_with(**ckpt)


@patch("huggingface_hub.hf_hub_download")
@pytest.mark.parametrize("ckpt", [{"repo_id": "organisation/test_repo", "filename": "model.ckpt"}])
def test_huggingface_file_download(huggingface_mock, monkeypatch, ckpt, fake_huggingface_ckpt):
    """A dict config with ``filename`` resolves to the downloaded file's path.

    ``hf_hub_download`` is mocked to return the fake ckpt file, and the test
    checks that it was called once with the dict expanded as keyword arguments.
    """
    # Swap Checkpoint._metadata for a FakeMetadata instance for this test.
    checkpoint_cls = anemoi.inference.checkpoint.Checkpoint
    monkeypatch.setattr(checkpoint_cls, "_metadata", FakeMetadata())
    huggingface_mock.return_value = fake_huggingface_ckpt
    expected_path = str(fake_huggingface_ckpt)

    resolved_path = Runner({"huggingface": ckpt}).checkpoint.path
    assert resolved_path == expected_path

    assert huggingface_mock.called
    huggingface_mock.assert_called_once_with(**ckpt)
3 changes: 3 additions & 0 deletions tests/metadata/fake_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
class FakeMetadata:
    """Stand-in metadata object whose ordinary attributes all read as ``None``.

    Any attribute that is not otherwise defined returns ``None``. Dunder names
    are the exception: they raise ``AttributeError`` as usual, so protocol
    probes such as ``hasattr(obj, "__array__")`` or
    ``getattr(obj, "__deepcopy__", None)`` see the hook as absent instead of
    finding a non-callable ``None``.
    """

    def __getattr__(self, name):
        """Return ``None`` for regular names; raise for missing dunder names."""
        # __getattr__ only runs after normal lookup has failed, so a dunder
        # reaching this point is genuinely undefined on the object.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        return None

0 comments on commit 1187d78

Please sign in to comment.