From 7047a0ea32f0768e5bdf982f8242697dd366bbad Mon Sep 17 00:00:00 2001
From: yiweny
Date: Thu, 19 Oct 2023 21:33:12 +0000
Subject: [PATCH] add new dataset to pytorch frame

---
Review note: the torch_frame/datasets/mercari.py hunk was edited during
review (docstring, unique row index after concatenation, URL built without
os.path). The post-image blob id on its "index" line was not recomputed.

 examples/test.py                 |  7 +++++
 torch_frame/datasets/__init__.py |  2 ++
 torch_frame/datasets/mercari.py  | 50 ++++++++++++++++++++++++++++++++
 3 files changed, 59 insertions(+)
 create mode 100644 examples/test.py
 create mode 100644 torch_frame/datasets/mercari.py

diff --git a/examples/test.py b/examples/test.py
new file mode 100644
index 00000000..382297ff
--- /dev/null
+++ b/examples/test.py
@@ -0,0 +1,7 @@
+import os.path as osp
+
+from torch_frame.datasets import Mercari
+
+path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
+dataset = Mercari(root=path)
+print(dataset.df)
diff --git a/torch_frame/datasets/__init__.py b/torch_frame/datasets/__init__.py
index 6ac0fe36..672c8441 100644
--- a/torch_frame/datasets/__init__.py
+++ b/torch_frame/datasets/__init__.py
@@ -13,6 +13,7 @@
 from .kdd_census_income import KDDCensusIncome
 from .multimodal_text_benchmark import MultimodalTextBenchmark
 from .data_frame_benchmark import DataFrameBenchmark
+from .mercari import Mercari
 
 real_world_datasets = [
     'Titanic',
@@ -27,6 +28,7 @@
     'KDDCensusIncome',
     'MultimodalTextBenchmark',
     'DataFrameBenchmark',
+    'Mercari',
 ]
 
 synthetic_datasets = [
diff --git a/torch_frame/datasets/mercari.py b/torch_frame/datasets/mercari.py
new file mode 100644
index 00000000..aeedd4bd
--- /dev/null
+++ b/torch_frame/datasets/mercari.py
@@ -0,0 +1,50 @@
+import pandas as pd
+
+import torch_frame
+from torch_frame.utils.split import SPLIT_TO_NUM
+
+
+class Mercari(torch_frame.data.Dataset):
+    r"""The Mercari price suggestion tables, with :obj:`price` as the
+    target column.
+
+    Three CSV files are downloaded from :attr:`base_url` and concatenated
+    into a single data frame: :obj:`train.csv` forms the training split,
+    :obj:`test.csv` the validation split and :obj:`test_stg2.csv` the test
+    split.
+
+    Args:
+        root (str): Root directory handed to :meth:`download_url` when
+            fetching the CSV files.
+    """
+    base_url = 'https://data.pyg.org/datasets/tables/mercari_price_suggestion/'
+    files = ['train', 'test', 'test_stg2']
+
+    def __init__(self, root: str):
+        col_to_stype = {
+            'name': torch_frame.text_embedded,
+            'item_condition_id': torch_frame.categorical,
+            'category_name': torch_frame.categorical,
+            'brand_name': torch_frame.categorical,
+            'price': torch_frame.numerical,
+            'shipping': torch_frame.categorical,
+            'item_description': torch_frame.text_embedded,
+        }
+        # Files without an entry here are treated as training data.
+        file_to_split = {'test': 'val', 'test_stg2': 'test'}
+        self.dfs = dict()
+        for file in self.files:
+            split = file_to_split.get(file, 'train')
+            path = self.download_url(f'{self.base_url}{file}.csv', root)
+            self.dfs[split] = pd.read_csv(path)
+        # Stack the splits, turning the concat key into a 'split' column.
+        # The second reset drops the per-file row labels, which would
+        # otherwise repeat across splits.
+        df = pd.concat(self.dfs.values(), keys=self.dfs.keys(),
+                       names=['split'])
+        df = df.reset_index(level=0).reset_index(drop=True)
+        df['split'] = df['split'].map(SPLIT_TO_NUM)
+        # The id columns are not listed in col_to_stype, so drop them.
+        df.drop(['train_id', 'test_id'], axis=1, inplace=True)
+        super().__init__(df, col_to_stype, target_col='price',
+                         split_col='split')