Skip to content

Commit

Permalink
add new dataset to pytorch frame
Browse files Browse the repository at this point in the history
  • Loading branch information
yiweny committed Oct 19, 2023
1 parent 4500e3c commit 7047a0e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import os.path as osp

from torch_frame.datasets import Mercari

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')
dataset = Mercari(root=path)
print(dataset.df)
2 changes: 2 additions & 0 deletions torch_frame/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .kdd_census_income import KDDCensusIncome
from .multimodal_text_benchmark import MultimodalTextBenchmark
from .data_frame_benchmark import DataFrameBenchmark
from .mercari import Mercari

real_world_datasets = [
'Titanic',
Expand All @@ -27,6 +28,7 @@
'KDDCensusIncome',
'MultimodalTextBenchmark',
'DataFrameBenchmark',
'Mercari',
]

synthetic_datasets = [
Expand Down
39 changes: 39 additions & 0 deletions torch_frame/datasets/mercari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os.path as osp

import pandas as pd

import torch_frame
from torch_frame.utils.split import SPLIT_TO_NUM


class Mercari(torch_frame.data.Dataset):
base_url = 'https://data.pyg.org/datasets/tables/mercari_price_suggestion/'
files = ['train', 'test', 'test_stg2']

def __init__(self, root: str):
self.dfs = dict()
col_to_stype = {
'name': torch_frame.text_embedded,
'item_condition_id': torch_frame.categorical,
'category_name': torch_frame.categorical,
'brand_name': torch_frame.categorical,
'price': torch_frame.numerical,
'shipping': torch_frame.categorical,
'item_description': torch_frame.text_embedded
}
for file in self.files:
if file == 'test':
split = 'val'
elif file == 'test_stg2':
split = 'test'
else:
split = 'train'
self.dfs[split] = pd.read_csv(
self.download_url(osp.join(self.base_url, file + '.csv'),
root))
df = pd.concat(self.dfs.values(), keys=self.dfs.keys(),
names=['split']).reset_index(level=0)
df['split'] = df['split'].map(SPLIT_TO_NUM)
df.drop(['train_id', 'test_id'], axis=1, inplace=True)
super().__init__(df, col_to_stype, target_col='price',
split_col='split')

0 comments on commit 7047a0e

Please sign in to comment.