Improve hf migration (#398)
* Update the migration script to be more robust against failure

* Update the HF script to use bulk fetches and updates instead

* Add a timeout so pre-commit passes

* Remove hard-coded token
PGijsbers authored Nov 22, 2024
1 parent adf4e0c commit 1bc917a
Showing 1 changed file with 71 additions and 27 deletions: scripts/migrate_hf.py
@@ -8,8 +8,11 @@
 To be run once (around sometime Nov 2024), likely not needed after that. See also #385, 392.
 """
 import logging
+import os
 import string
+from http import HTTPStatus
 import time
+from pathlib import Path
 
 from sqlalchemy import select
 from database.session import DbSession, EngineSingleton
@@ -22,44 +25,85 @@
 import database.setup
 
 import requests
+import json
 
+import re
-from http import HTTPStatus


+def fetch_huggingface_metadata() -> list[dict]:
+    next_url = "https://huggingface.co/api/datasets"
+    datasets = []
+    while next_url:
+        logging.info(f"Counted {len(datasets)} so far.")
+        if token := os.environ.get("HUGGINGFACE_TOKEN"):
+            headers = {"Authorization": f"Bearer {token}"}
+        else:
+            headers = {}
+        response = requests.get(
+            next_url,
+            params={"limit": 1000, "full": "False"},
+            headers=headers,
+            timeout=20,
+        )
+        if response.status_code != HTTPStatus.OK:
+            logging.info("Stopping iteration: %s %s", response.status_code, response.json())
+            break
+
+        datasets.extend(response.json())
+
+        next_info = response.headers.get("Link", "")
+        if next_url_match := re.search(r"<([^>]+)>", next_info):
+            next_url = next_url_match.group()[1:-1]
+        else:
+            next_url = None
+    return datasets
+
+
+def load_id_map():
+    HF_DATA_FILE = Path(__file__).parent / "hf_metadata.json"
+    if HF_DATA_FILE.exists():
+        logging.info(f"Loading HF data from {HF_DATA_FILE}.")
+        with open(HF_DATA_FILE, "r") as fh:
+            hf_data = json.load(fh)
+    else:
+        logging.info("Fetching HF data from Hugging Face.")
+        hf_data = fetch_huggingface_metadata()
+        with open(HF_DATA_FILE, "w") as fh:
+            json.dump(hf_data, fh)
+    id_map = {data["id"]: data["_id"] for data in hf_data}
+    return id_map


 def main():
     logging.basicConfig(level=logging.INFO)
     AIoDConcept.metadata.create_all(EngineSingleton().engine, checkfirst=True)
+    id_map = load_id_map()
 
     with DbSession() as session:
         datasets_query = select(Dataset).where(Dataset.platform == PlatformName.huggingface)
         datasets = session.scalars(datasets_query).all()
 
+        logging.info(f"Found {len(datasets)} huggingface datasets.")
+        is_old_style_identifier = lambda identifier: any(
+            char not in string.hexdigits for char in identifier
+        )
+        datasets = [
+            dataset
+            for dataset in datasets
+            if is_old_style_identifier(dataset.platform_resource_identifier)
+        ]
+        logging.info(f"Found {len(datasets)} huggingface datasets that need an update.")

-    with DbSession() as session:
         for dataset in datasets:
-            if all(c in string.hexdigits for c in dataset.platform_resource_identifier):
-                continue  # entry already updated to use new-style id
-
-            response = requests.get(
-                f"https://huggingface.co/api/datasets/{dataset.name}",
-                params={"full": "False"},
-                headers={},
-                timeout=10,
-            )
-            if response.status_code != HTTPStatus.OK:
-                logging.warning(f"Dataset {dataset.name} could not be retrieved.")
-                continue
-
-            dataset_json = response.json()
-            if dataset.platform_resource_identifier != dataset_json["id"]:
-                logging.info(
-                    f"Dataset {dataset.platform_resource_identifier} moved to {dataset_json['id']}"
-                    "Deleting the old entry. The new entry either already exists or"
-                    "will be added on a later synchronization invocation."
-                )
+            if new_id := id_map.get(dataset.platform_resource_identifier):
+                dataset.platform_resource_identifier = new_id
+                session.add(dataset)
+            else:
                 session.delete(dataset)
-                continue
-
-            persistent_id = dataset_json["_id"]
-            logging.info(
-                f"Setting platform id of {dataset.platform_resource_identifier} to {persistent_id}"
-            )
-            dataset.platform_resource_identifier = persistent_id
         session.commit()
+    logging.info("Done updating entries.")


 if __name__ == "__main__":
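
For context on the pagination that fetch_huggingface_metadata implements: the Hugging Face API advertises the next page in the Link response header, with the URL wrapped in angle brackets, and the regex in the loop extracts it. A minimal sketch of that step, using a hypothetical header value (the real value comes from the API response):

    import re

    # Hypothetical Link header value, shaped like the ones the API returns.
    link_header = '<https://huggingface.co/api/datasets?cursor=abc123&limit=1000>; rel="next"'

    if next_url_match := re.search(r"<([^>]+)>", link_header):
        next_url = next_url_match.group()[1:-1]  # strip the surrounding angle brackets
    else:
        next_url = None  # no Link header: the last page was reached
    print(next_url)  # https://huggingface.co/api/datasets?cursor=abc123&limit=1000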

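The migration hinges on the two identifier styles the filter distinguishes: old entries store the human-readable repo id, which contains characters outside 0-9a-f (for example "/"), while new entries store the persistent hexadecimal _id assigned by Hugging Face. A sketch of that distinction and the id_map lookup, with made-up identifiers:

    import string

    # Made-up records in the shape the bulk fetch returns: "id" is the
    # human-readable repo name, "_id" the persistent hexadecimal identifier.
    hf_data = [
        {"id": "someuser/some-dataset", "_id": "652d8e1c9f0a4b3c2d1e0f9a"},
        {"id": "glue", "_id": "621ffdd236468d709f182a80"},
    ]
    id_map = {data["id"]: data["_id"] for data in hf_data}

    def is_old_style_identifier(identifier: str) -> bool:
        # Any character outside string.hexdigits marks an old-style id.
        return any(char not in string.hexdigits for char in identifier)

    assert is_old_style_identifier("someuser/some-dataset")
    assert not is_old_style_identifier("652d8e1c9f0a4b3c2d1e0f9a")
    print(id_map["someuser/some-dataset"])  # 652d8e1c9f0a4b3c2d1e0f9a
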
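A robustness detail worth noting: load_id_map caches the bulk fetch result in hf_metadata.json next to the script, so a re-run after a failure skips the network round trips entirely. A minimal sketch of that cache-or-fetch pattern (the file name and fetch callable below are placeholders for the ones in the diff):

    import json
    from pathlib import Path

    def load_cached(cache_file: Path, fetch):
        """Return the cached JSON if present, otherwise fetch and cache it."""
        if cache_file.exists():
            return json.loads(cache_file.read_text())
        data = fetch()
        cache_file.write_text(json.dumps(data))
        return data

    # Usage sketch: hf_data = load_cached(Path("hf_metadata.json"), fetch_huggingface_metadata)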