From 212c2935b4edbad77a806f4190f22439410cff3f Mon Sep 17 00:00:00 2001
From: Vincent Maladiere
Date: Mon, 20 Jan 2025 19:15:51 +0100
Subject: [PATCH] wip

---
 skrub/datasets/__init__.py      |   7 ++-
 skrub/datasets/_fetching.py     |   3 +-
 skrub/datasets/_utils.py        |  33 +++++------
 skrub/datasets/_zip_datasets.py | 100 ++++++++++++++++++++++++++++++++
 4 files changed, 120 insertions(+), 23 deletions(-)
 create mode 100644 skrub/datasets/_zip_datasets.py

diff --git a/skrub/datasets/__init__.py b/skrub/datasets/__init__.py
index 89641aa56..83ab218b9 100644
--- a/skrub/datasets/__init__.py
+++ b/skrub/datasets/__init__.py
@@ -9,8 +9,9 @@
     fetch_toxicity,
     fetch_traffic_violations,
 )
-from ._utils import get_data_dir
 from ._generating import make_deduplication_data
+from ._utils import get_data_dir
+
 # from ._ken_embeddings import (
 #     fetch_ken_embeddings,
 #     fetch_ken_table_aliases,
@@ -25,10 +26,10 @@
     "fetch_open_payments",
     "fetch_road_safety",
     "fetch_traffic_violations",
-    #"fetch_world_bank_indicator",
+    # "fetch_world_bank_indicator",
     "fetch_credit_fraud",
     "fetch_toxicity",
-    #"fetch_movielens",
+    # "fetch_movielens",
     "get_data_dir",
     "make_deduplication_data",
     # "fetch_ken_embeddings",
diff --git a/skrub/datasets/_fetching.py b/skrub/datasets/_fetching.py
index 22cedd59d..1d6b12290 100644
--- a/skrub/datasets/_fetching.py
+++ b/skrub/datasets/_fetching.py
@@ -37,7 +37,6 @@
 }
 
 
-
 def fetch_employee_salaries(data_home=None):
     """Fetches the employee salaries dataset (regression), available at \
     https://openml.org/d/42125
@@ -142,7 +141,7 @@ def fetch_open_payments(
         data_id=OPEN_PAYMENTS_ID,
         data_home=data_directory,
         target_column="status",
-        return_X_y=return_X_y
+        return_X_y=return_X_y,
     )
 
 
diff --git a/skrub/datasets/_utils.py b/skrub/datasets/_utils.py
index 830dae656..745c42ebe 100644
--- a/skrub/datasets/_utils.py
+++ b/skrub/datasets/_utils.py
@@ -1,11 +1,12 @@
-import json
 import hashlib
-import pandas as pd
+import json
 import shutil
 import time
 import warnings
-import requests
 from pathlib import Path
+
+import pandas as pd
+import requests
 from sklearn.utils import Bunch
 
 DATASET_INFO = {
@@ -94,23 +95,22 @@ def load_dataset(dataset_name, data_home=None):
     data_home = get_data_home(data_home)
     dataset_dir = data_home / dataset_name
     datafiles_dir = dataset_dir / dataset_name
-    
+
     if not datafiles_dir.exists() or not any(datafiles_dir.iterdir()):
         extract_archive(dataset_dir)
-    
+
     bunch = Bunch()
-    for file_path in dataset_dir.iterdir():    
+    for file_path in dataset_dir.iterdir():
         if file_path.suffix == ".csv":
             bunch[file_path.stem] = pd.read_csv(file_path)
         elif file_path.suffix == ".json":
             metadata_key = f"{file_path.stem}_metadata"
             bunch[metadata_key] = json.loads(file_path.read_text(), "utf-8")
-    
+
     return bunch
-    
 
 
-def extract_archive(dataset_dir): 
+def extract_archive(dataset_dir):
     dataset_name = dataset_dir.name
     archive_path = dataset_dir / f"{dataset_name}.zip"
     if not archive_path.exists():
@@ -121,11 +121,10 @@ def extract_archive(dataset_dir):
 
 
 def download_archive(dataset_name, archive_path, retry=3, delay=1, timeout=30):
-
     metadata = DATASET_INFO[dataset_name]
     error_flag = False
 
-    while True:
+    for _ in range(retry):
         for target_url in metadata["urls"]:
             r = requests.get(target_url, timeout=timeout)
             try:
                 r.raise_for_status()
             except requests.RequestException:
                 error_flag = True
                 continue
             checksum = hashlib.sha256(r.content).hexdigest()
             if metadata["sha256"] != checksum:
                 warnings.warn(
                     "The file has been updated, please update your skrub version."
                 )
             break
-
-        if not retry:
-            raise OSError(
-                f"Can't download the file {dataset_name} from urls {metadata['urls']}."
-            )
-        time.sleep(delay)
-        retry -= 1
         timeout *= 2
+    else:
+        raise OSError(
+            f"Can't download the file {dataset_name} from urls {metadata['urls']}."
+        )
+
     archive_path.write_bytes(r.content)
diff --git a/skrub/datasets/_zip_datasets.py b/skrub/datasets/_zip_datasets.py
new file mode 100644
index 000000000..a5dbfd7ba
--- /dev/null
+++ b/skrub/datasets/_zip_datasets.py
@@ -0,0 +1,100 @@
+import argparse
+import datetime
+import hashlib
+import json
+import shutil
+from pathlib import Path
+
+from skrub import datasets
+
+
+def create_archive(
+    all_datasets_dir, all_archives_dir, dataset_name, dataframes, metadata
+):
+    print(dataset_name)
+    dataset_dir = all_datasets_dir / dataset_name
+    dataset_dir.mkdir(parents=True)
+    (dataset_dir / "metadata.json").write_text(json.dumps(metadata), "utf-8")
+    for stem, df in dataframes.items():
+        csv_path = dataset_dir / f"{stem}.csv"
+        df.to_csv(csv_path, index=False)
+    archive_path = all_archives_dir / dataset_name
+    result = shutil.make_archive(
+        archive_path,
+        "zip",
+        root_dir=all_datasets_dir,
+        base_dir=dataset_name,
+    )
+    result = Path(result)
+    checksum = hashlib.sha256(result.read_bytes()).hexdigest()
+    return checksum
+
+
+def load_simple_dataset(fetcher):
+    dataset = fetcher()
+    df = dataset.X
+    df[dataset.target] = dataset.y
+    name = fetcher.__name__.removeprefix("fetch_")
+    return (
+        name,
+        {name: df},
+        {
+            "name": dataset.name,
+            "description": dataset.description,
+            "source": dataset.source,
+            "target": dataset.target,
+        },
+    )
+
+
+def iter_datasets():
+    simple_fetchers = {f for f in datasets.__all__ if f.startswith("fetch_")} - {
+        "fetch_world_bank_indicator",
+        "fetch_figshare",
+        "fetch_credit_fraud",
+        "fetch_ken_embeddings",
+        "fetch_ken_table_aliases",
+        "fetch_ken_types",
+    }
+    for fetcher in sorted(simple_fetchers):
+        yield load_simple_dataset(getattr(datasets, fetcher))
+
+
+def make_skrub_datasets():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-o",
+        "--output_dir",
+        default=None,
+        help="where to store the output. a subdirectory containing all the archives will be created",
+    )
+    args = parser.parse_args()
+
+    if args.output_dir is None:
+        output_dir = Path.cwd()
+    else:
+        output_dir = Path(args.output_dir).resolve()
+
+    root_dir = (
+        output_dir / f"skrub_datasets_{datetime.datetime.now():%Y-%m-%dT%H-%M%S}"
+    )
+    root_dir.mkdir(parents=True)
+    all_datasets_dir = root_dir / "datasets"
+    all_datasets_dir.mkdir()
+    all_archives_dir = root_dir / "archives"
+    all_archives_dir.mkdir()
+
+    print(f"saving output in {root_dir}")
+
+    checksums = {}
+    for dataset_name, dataframes, metadata in iter_datasets():
+        checksums[dataset_name] = create_archive(
+            all_datasets_dir, all_archives_dir, dataset_name, dataframes, metadata
+        )
+
+    (all_archives_dir / "checksums.json").write_text(json.dumps(checksums), "utf-8")
+    print(f"archive files saved in {all_archives_dir}")
+
+
+if __name__ == "__main__":
+    make_skrub_datasets()
\ No newline at end of file