Skip to content

Commit

Permalink
More misc cleanup (#225)
Browse files Browse the repository at this point in the history
  • Loading branch information
rishabh-ranjan authored Jul 5, 2024
1 parent 9c82a15 commit 0a25d06
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 25 deletions.
3 changes: 2 additions & 1 deletion examples/gnn_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@

eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {}
for split in ["val", "test"]:
seed_time = task.val_seed_time if split == "val" else task.test_seed_time
timestamp = dataset.val_timestamp if split == "val" else dataset.test_timestamp
seed_time = int(timestamp.timestamp())
target_table = task.get_table(split)
src_node_indices = torch.from_numpy(target_table.df[task.src_entity_col].values)
src_loader = NeighborLoader(
Expand Down
3 changes: 0 additions & 3 deletions relbench/base/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@ class Database:
r"""A database is a collection of named tables linked by foreign key -
primary key connections."""

# TODO: maybe add a function to visualize schema in jupyter

def __init__(self, table_dict: Dict[str, Table]) -> None:
r"""Creates a database from a dictionary of tables."""

self.table_dict = table_dict

def __repr__(self) -> str:
# TODO: add more info
return f"{self.__class__.__name__}()"

def save(self, path: Union[str, os.PathLike]) -> None:
Expand Down
2 changes: 0 additions & 2 deletions relbench/base/task_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def make_table(
) -> Table:
r"""To be implemented by subclass."""

# TODO: ensure that tasks follow the right-closed convention

raise NotImplementedError

def _get_table(self, split: str) -> Table:
Expand Down
15 changes: 2 additions & 13 deletions relbench/base/task_link.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
import pandas as pd
from numpy.typing import NDArray

# TODO: remove!
from ..modeling.utils import to_unix_time
from .dataset import Dataset
from .table import Table
from .task_base import BaseTask, TaskType
Expand Down Expand Up @@ -89,7 +87,6 @@ def evaluate(

return {fn.__name__: fn(pred_isin, dst_count) for fn in metrics}

# TODO: should these be here? seed_time is confusing terminology?
@property
def num_src_nodes(self) -> int:
return len(self.dataset.get_db().table_dict[self.src_entity_table])
Expand All @@ -98,15 +95,7 @@ def num_src_nodes(self) -> int:
def num_dst_nodes(self) -> int:
return len(self.dataset.get_db().table_dict[self.dst_entity_table])

@property
def val_seed_time(self) -> int:
return to_unix_time(pd.Series([self.dataset.val_timestamp]))[0]

@property
def test_seed_time(self) -> int:
return to_unix_time(pd.Series([self.dataset.test_timestamp]))[0]

def stats(self) -> dict[str, dict[str, int]]:
def stats(self) -> Dict[str, Dict[str, int]]:
r"""Get train / val / test table statistics for each timestamp
and the whole table, including number of unique source entities,
number of unique destination entities, number of destination
Expand Down Expand Up @@ -177,7 +166,7 @@ def stats(self) -> dict[str, dict[str, int]]:
] = ratio_train_test_entity_overlap
return res

def _get_stats(self, df: pd.DataFrame) -> list[int]:
def _get_stats(self, df: pd.DataFrame) -> List[int]:
num_unique_src_entities = df[self.src_entity_col].nunique()
num_unique_dst_entities = len(
set(value for row in df[self.dst_entity_col] for value in row)
Expand Down
10 changes: 5 additions & 5 deletions relbench/datasets/avito.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,15 @@


class AvitoDataset(Dataset):
url = "https://www.kaggle.com/competitions/avito-context-ad-clicks"
err_msg = (
"{data} not found. Please download avito data from "
"'{url}' and move it to '{path}'."
)
"""Original data source:
https://www.kaggle.com/competitions/avito-context-ad-clicks"""

# search stream ranges from 2015-04-25 to 2015-05-20
val_timestamp = pd.Timestamp("2015-05-08")
test_timestamp = pd.Timestamp("2015-05-14")

def make_db(self) -> Database:
# subsampled version of the original dataset
# Customize path as necessary
r"""Process the raw files into a database."""
url = "https://relbench.stanford.edu/data/rel-avito-raw-100k.zip"
Expand Down Expand Up @@ -69,6 +67,8 @@ def make_db(self) -> Database:
)
visit_stream_df = clean_datetime(visit_stream_df, "ViewDate")

category_df.drop(columns=["__index_level_0__"], inplace=True)

tables = {}
tables["AdsInfo"] = Table(
df=ads_info_df,
Expand Down
3 changes: 2 additions & 1 deletion test/modeling/test_link_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def test_link_train_fake_product_dataset(tmp_path, share_same_time):

eval_loaders_dict: Dict[str, Tuple[NeighborLoader, NeighborLoader]] = {}
for split in ["val", "test"]:
seed_time = task.val_seed_time if split == "val" else task.test_seed_time
timestamp = dataset.val_timestamp if split == "val" else dataset.test_timestamp
seed_time = int(timestamp.timestamp())
target_table = task.get_table(split)
src_node_indices = torch.from_numpy(target_table.df[task.src_entity_col].values)
src_loader = NeighborLoader(
Expand Down

0 comments on commit 0a25d06

Please sign in to comment.