Commit

wip

Vincent-Maladiere committed Jan 20, 2025
1 parent 7d21d88 commit 212c293
Showing 4 changed files with 120 additions and 23 deletions.
7 changes: 4 additions & 3 deletions skrub/datasets/__init__.py
@@ -9,8 +9,9 @@
fetch_toxicity,
fetch_traffic_violations,
)
from ._utils import get_data_dir
from ._generating import make_deduplication_data
from ._utils import get_data_dir

# from ._ken_embeddings import (
# fetch_ken_embeddings,
# fetch_ken_table_aliases,
@@ -25,10 +26,10 @@
"fetch_open_payments",
"fetch_road_safety",
"fetch_traffic_violations",
#"fetch_world_bank_indicator",
# "fetch_world_bank_indicator",
"fetch_credit_fraud",
"fetch_toxicity",
#"fetch_movielens",
# "fetch_movielens",
"get_data_dir",
"make_deduplication_data",
# "fetch_ken_embeddings",
3 changes: 1 addition & 2 deletions skrub/datasets/_fetching.py
@@ -37,7 +37,6 @@
}



def fetch_employee_salaries(data_home=None):
"""Fetches the employee salaries dataset (regression), available at \
https://openml.org/d/42125
@@ -142,7 +141,7 @@ def fetch_open_payments(
data_id=OPEN_PAYMENTS_ID,
data_home=data_directory,
target_column="status",
return_X_y=return_X_y
return_X_y=return_X_y,
)


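Aside (not part of the commit): a hypothetical usage sketch of one of the fetchers touched above; the .X, .y and .target attributes are the ones load_simple_dataset in _zip_datasets.py below relies on.

    # Hypothetical sketch: fetch a bundled dataset and inspect it.
    from skrub.datasets import fetch_employee_salaries

    dataset = fetch_employee_salaries()   # Bunch-like object with data and metadata
    X, y = dataset.X, dataset.y           # feature table and target column
    print(X.shape, dataset.target)        # quick sanity check of the download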
33 changes: 15 additions & 18 deletions skrub/datasets/_utils.py
@@ -1,11 +1,12 @@
import json
import hashlib
import pandas as pd
import json
import shutil
import time
import warnings
import requests
from pathlib import Path

import pandas as pd
import requests
from sklearn.utils import Bunch

DATASET_INFO = {
@@ -94,23 +95,22 @@ def load_dataset(dataset_name, data_home=None):
data_home = get_data_home(data_home)
dataset_dir = data_home / dataset_name
datafiles_dir = dataset_dir / dataset_name

if not datafiles_dir.exists() or not any(datafiles_dir.iterdir()):
extract_archive(dataset_dir)

bunch = Bunch()
for file_path in dataset_dir.iterdir():
for file_path in dataset_dir.iterdir():
if file_path.suffix == ".csv":
bunch[file_path.stem] = pd.read_csv(file_path)
elif file_path.suffix == ".json":
metadata_key = f"{file_path.stem}_metadata"
bunch[metadata_key] = json.loads(file_path.read_text("utf-8"))

return bunch


def extract_archive(dataset_dir):

def extract_archive(dataset_dir):
dataset_name = dataset_dir.name
archive_path = dataset_dir / f"{dataset_name}.zip"
if not archive_path.exists():
@@ -121,11 +121,10 @@ def extract_archive(dataset_dir):


def download_archive(dataset_name, archive_path, retry=3, delay=1, timeout=30):

metadata = DATASET_INFO[dataset_name]
error_flag = False

while True:
for _ in range(retry):
for target_url in metadata["urls"]:
r = requests.get(target_url, timeout=timeout)
try:
@@ -142,14 +141,12 @@ def download_archive(dataset_name, archive_path, retry=3, delay=1, timeout=30):
"The file has been updated, please update your skrub version."
)
break

if not retry:
raise OSError(
f"Can't download the file {dataset_name} from urls {metadata['urls']}."
)

time.sleep(delay)
retry -= 1
timeout *= 2

else:
raise OSError(
f"Can't download the file {dataset_name} from urls {metadata['urls']}."
)

archive_path.write_bytes(r.content)
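
Aside (not part of the commit): the _utils.py hunks above swap a while True loop with a manual retry counter for a bounded for loop with an else clause. A minimal sketch of that pattern, with urls and expected_sha256 standing in for the real DATASET_INFO metadata:

    import hashlib
    import time
    import warnings
    from pathlib import Path

    import requests


    def download_archive(dataset_name, archive_path, urls, expected_sha256,
                         retry=3, delay=1, timeout=30):
        # Sketch only: a simplified stand-in for the skrub helper above.
        content = None
        for _ in range(retry):
            for target_url in urls:
                try:
                    r = requests.get(target_url, timeout=timeout)
                    r.raise_for_status()
                except requests.RequestException:
                    continue  # this mirror failed, try the next one
                if hashlib.sha256(r.content).hexdigest() != expected_sha256:
                    warnings.warn(
                        "The file has been updated, please update your skrub version."
                    )
                content = r.content
                break
            if content is not None:
                break
            time.sleep(delay)  # back off, then retry with a longer timeout
            timeout *= 2
        else:
            raise OSError(
                f"Can't download the file {dataset_name} from urls {urls}."
            )
        Path(archive_path).write_bytes(content)

The for/else keeps the failure path in one place instead of tracking a separate counter and error flag.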
100 changes: 100 additions & 0 deletions skrub/datasets/_zip_datasets.py
@@ -0,0 +1,100 @@
import argparse
import datetime
import hashlib
import json
import shutil
from pathlib import Path

from skrub import datasets


def create_archive(
all_datasets_dir, all_archives_dir, dataset_name, dataframes, metadata
):
print(dataset_name)
dataset_dir = all_datasets_dir / dataset_name
dataset_dir.mkdir(parents=True)
(dataset_dir / "metadata.json").write_text(json.dumps(metadata), "utf-8")
for stem, df in dataframes.items():
csv_path = dataset_dir / f"{stem}.csv"
df.to_csv(csv_path, index=False)
archive_path = all_archives_dir / dataset_name
result = shutil.make_archive(
archive_path,
"zip",
root_dir=all_datasets_dir,
base_dir=dataset_name,
)
result = Path(result)
checksum = hashlib.sha256(result.read_bytes()).hexdigest()
return checksum


def load_simple_dataset(fetcher):
dataset = fetcher()
df = dataset.X
df[dataset.target] = dataset.y
name = fetcher.__name__.removeprefix("fetch_")
return (
name,
{name: df},
{
"name": dataset.name,
"description": dataset.description,
"source": dataset.source,
"target": dataset.target,
},
)


def iter_datasets():
simple_fetchers = {f for f in datasets.__all__ if f.startswith("fetch_")} - {
"fetch_world_bank_indicator",
"fetch_figshare",
"fetch_credit_fraud",
"fetch_ken_embeddings",
"fetch_ken_table_aliases",
"fetch_ken_types",
}
for fetcher in sorted(simple_fetchers):
yield load_simple_dataset(getattr(datasets, fetcher))


def make_skrub_datasets():
parser = argparse.ArgumentParser()
parser.add_argument(
"-o",
"--output_dir",
default=None,
help="where to store the output. a subdirectory containing all the archives will be created",
)
args = parser.parse_args()

if args.output_dir is None:
output_dir = Path.cwd()
else:
output_dir = Path(args.output_dir).resolve()

root_dir = (
output_dir / f"skrub_datasets_{datetime.datetime.now():%Y-%m-%dT%H-%M%S}"
)
root_dir.mkdir(parents=True)
all_datasets_dir = root_dir / "datasets"
all_datasets_dir.mkdir()
all_archives_dir = root_dir / "archives"
all_archives_dir.mkdir()

print(f"saving output in {root_dir}")

checksums = {}
for dataset_name, dataframes, metadata in iter_datasets():
checksums[dataset_name] = create_archive(
all_datasets_dir, all_archives_dir, dataset_name, dataframes, metadata
)

(all_archives_dir / "checksums.json").write_text(json.dumps(checksums), "utf-8")
print(f"archive files saved in {all_archives_dir}")


if __name__ == "__main__":
make_skrub_datasets()

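Aside (not part of the commit): create_archive pairs each archive with a SHA-256 digest in checksums.json, presumably the value the checksum comparison in _utils.py is checked against. A small sketch of verifying a local archives directory against that file; the directory name is hypothetical:

    import hashlib
    import json
    from pathlib import Path

    # Hypothetical output directory produced by make_skrub_datasets().
    archives_dir = Path("skrub_datasets_2025-01-20T12-0000/archives")
    checksums = json.loads((archives_dir / "checksums.json").read_text("utf-8"))

    for name, expected in checksums.items():
        archive = archives_dir / f"{name}.zip"
        actual = hashlib.sha256(archive.read_bytes()).hexdigest()
        status = "OK" if actual == expected else "MISMATCH"
        print(f"{name}: {status}")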