diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index cc633ef..01119df 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -35,25 +35,25 @@
 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
 LINK_PREDICTION_METRIC = "link_prediction_map"
-VAL_LOSS_DELTA = 0.001
+VAL_LOSS_DELTA = 0.0005
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="rel-amazon")
-parser.add_argument("--task", type=str, default="user-item-rate")
+parser.add_argument("--task", type=str, default="user-item-purchase")
 parser.add_argument(
     "--model",
     type=str,
-    default="hybridgnn",
+    default="idgnn",
     choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
 )
-parser.add_argument("--epochs", type=int, default=20)
-parser.add_argument("--num_trials", type=int, default=50,
+parser.add_argument("--epochs", type=int, default=10)
+parser.add_argument("--num_trials", type=int, default=10,
                     help="Number of Optuna-based hyper-parameter tuning.")
 parser.add_argument(
     "--num_repeats", type=int, default=5,
     help="Number of repeated training and eval on the best config.")
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
-parser.add_argument("--num_layers", type=int, default=2)
+parser.add_argument("--num_layers", type=int, default=6)
 parser.add_argument("--num_neighbors", type=int, default=128)
 parser.add_argument("--temporal_strategy", type=str, default="last",
                     choices=["last", "uniform"])
@@ -107,9 +107,9 @@
 if args.model == "idgnn":
     model_search_space = {
-        "encoder_channels": [64, 128, 256],
+        "encoder_channels": [64, 128],
         "encoder_layers": [2, 4, 8],
-        "channels": [64, 128, 256],
+        "channels": [64, 128],
         "norm": ["layer_norm", "batch_norm"]
     }
     train_search_space = {
@@ -120,10 +120,10 @@
     model_cls = IDGNN
 elif args.model in ["hybridgnn", "shallowrhsgnn"]:
     model_search_space = {
-        "encoder_channels": [32, 64, 128, 256, 512],
-        "encoder_layers": [2, 4, 8],
-        "channels": [32, 64, 128, 256, 512],
-        "embedding_dim": [32, 64, 128, 256, 512],
+        "encoder_channels": [32, 64],
+        "encoder_layers": [2, 4],
+        "channels": [32, 64, 128],
+        "embedding_dim": [32, 64],
         "norm": ["layer_norm", "batch_norm"],
         "rhs_emb_mode": [
             RHSEmbeddingMode.FUSION, RHSEmbeddingMode.FEATURE,
@@ -131,7 +131,7 @@
         ]
     }
     train_search_space = {
-        "batch_size": [256, 512, 1024],
+        "batch_size": [32, 64, 128],
         "base_lr": [0.001, 0.01],
         "gamma_rate": [0.8, 1.],
     }
diff --git a/examples/wmf_example.py b/examples/wmf_example.py
new file mode 100644
index 0000000..3229986
--- /dev/null
+++ b/examples/wmf_example.py
@@ -0,0 +1,200 @@
+"""Example script to run weighted matrix factorization (WMF) on a RelBench
+link prediction task.
+
+python wmf_example.py --dataset rel-trial --task site-sponsor-run --epochs 10
+"""
+
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+from relbench.base import Dataset, RecommendationTask, TaskType
+from relbench.datasets import get_dataset
+from relbench.modeling.graph import (
+    get_link_train_table_input,
+    make_pkey_fkey_graph,
+)
+from relbench.modeling.loader import SparseTensor
+from relbench.modeling.utils import get_stype_proposal
+from relbench.tasks import get_task
+from torch import Tensor
+from torch_frame import stype
+from torch_frame.config.text_embedder import TextEmbedderConfig
+from torch_geometric.loader import NeighborLoader
+from torch_geometric.seed import seed_everything
+from torch_geometric.typing import NodeType
+from tqdm import tqdm
+
+from hybridgnn.nn.models import WeightedMatrixFactorization
+from hybridgnn.utils import GloveTextEmbedding
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset", type=str, default="rel-trial")
+parser.add_argument("--task", type=str, default="site-sponsor-run")
+parser.add_argument("--lr", type=float, default=0.001)
+parser.add_argument("--epochs", type=int, default=3)
+parser.add_argument("--eval_epochs_interval", type=int, default=1)
+parser.add_argument("--batch_size", type=int, default=512)
+parser.add_argument("--num_layers", type=int, default=4)
+parser.add_argument("--num_neighbors", type=int, default=16)
+parser.add_argument("--embedding_dim", type=int, default=128)
+parser.add_argument("--temporal_strategy", type=str, default="last")
+parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
+parser.add_argument("--num_workers", type=int, default=0)
+parser.add_argument("--seed", type=int, default=42)
+parser.add_argument("--cache_dir", type=str,
+                    default=os.path.expanduser("~/.cache/relbench_examples"))
+args = parser.parse_args()
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    torch.set_num_threads(1)
+seed_everything(args.seed)
+
+
+dataset: Dataset = get_dataset(args.dataset, download=True)
+task: RecommendationTask = get_task(args.dataset, args.task, download=True)
+tune_metric = "link_prediction_map"
+assert task.task_type == TaskType.LINK_PREDICTION
+
+stypes_cache_path = Path(f"{args.cache_dir}/{args.dataset}/stypes.json")
+try:
+    with open(stypes_cache_path, "r") as f:
+        col_to_stype_dict = json.load(f)
+    for table, col_to_stype in col_to_stype_dict.items():
+        for col, stype_str in col_to_stype.items():
+            col_to_stype[col] = stype(stype_str)
+except FileNotFoundError:
+    col_to_stype_dict = get_stype_proposal(dataset.get_db())
+    Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True)
+    with open(stypes_cache_path, "w") as f:
+        json.dump(col_to_stype_dict, f, indent=2, default=str)
+
+data, col_stats_dict = make_pkey_fkey_graph(
+    dataset.get_db(),
+    col_to_stype_dict=col_to_stype_dict,
+    text_embedder_cfg=TextEmbedderConfig(
+        text_embedder=GloveTextEmbedding(device=device), batch_size=256),
+    cache_dir=f"{args.cache_dir}/{args.dataset}/materialized",
+)
+
+# Halve the neighbor fanout at each hop, e.g. 16 -> [16, 8, 4, 2].
+num_neighbors = [
+    int(args.num_neighbors // 2**i) for i in range(args.num_layers)
+]
+
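+# One temporal NeighborLoader per split keeps the RelBench plumbing uniform.
+# WMF itself only consumes the seed `input_id`s of each batch; the sampled
+# subgraph features are unused by the model.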
+loader_dict: Dict[str, NeighborLoader] = {}
+dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
+src_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
+num_dst_nodes_dict: Dict[str, int] = {}
+for split in ["train", "val", "test"]:
+    table = task.get_table(split)
+    table_input = get_link_train_table_input(table, task)
+    dst_nodes_dict[split] = table_input.dst_nodes
+    src_nodes_dict[split] = table_input.src_nodes
+    num_dst_nodes_dict[split] = table_input.num_dst_nodes
+    loader_dict[split] = NeighborLoader(
+        data,
+        num_neighbors=num_neighbors,
+        time_attr="time",
+        input_nodes=table_input.src_nodes,
+        input_time=table_input.src_time,
+        subgraph_type="bidirectional",
+        batch_size=args.batch_size,
+        temporal_strategy=args.temporal_strategy,
+        shuffle=split == "train",
+        num_workers=args.num_workers,
+        persistent_workers=args.num_workers > 0,
+    )
+
+# The embedding tables are indexed by global node id, so the lhs table must
+# cover the largest src node id seen in any split. The dst node count is
+# shared across splits.
+num_src_nodes = max(
+    int(src_nodes_dict[split][1].max()) + 1
+    for split in ["train", "val", "test"])
+num_dst_nodes = num_dst_nodes_dict["train"]
+
+model = WeightedMatrixFactorization(num_src_nodes, num_dst_nodes,
+                                    args.embedding_dim).to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+
+def train() -> float:
+    model.train()
+
+    loss_accum = count_accum = 0
+    steps = 0
+    total_steps = min(len(loader_dict["train"]), args.max_steps_per_epoch)
+    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)
+    src_node_ids = src_nodes_dict["train"][1].to(device)
+    for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
+        batch = batch.to(device)
+
+        # Ground-truth (src, dst) pairs for the seed rows of this batch.
+        input_id = batch[task.src_entity_table].input_id
+        src_batch, dst_index = sparse_tensor[input_id]
+
+        # Optimization
+        optimizer.zero_grad()
+
+        loss = model(src_node_ids[input_id][src_batch], dst_index)
+        loss /= len(src_batch)
+
+        loss.backward()
+        optimizer.step()
+
+        numel = len(src_batch)
+        loss_accum += float(loss) * numel
+        count_accum += numel
+
+        steps += 1
+        if steps >= args.max_steps_per_epoch:
+            break
+    return loss_accum / count_accum if count_accum > 0 else float("nan")
+
+
+@torch.no_grad()
+def test(loader: NeighborLoader, split: str, desc: str) -> np.ndarray:
+    model.eval()
+
+    src_node_ids = src_nodes_dict[split][1].to(device)
+    pred_list: List[Tensor] = []
+    for batch in tqdm(loader, desc=desc):
+        batch = batch.to(device)
+        input_id = batch[task.src_entity_table].input_id
+        # Score every dst node for each seed src node and keep the top-k.
+        scores = model.lhs(src_node_ids[input_id]) @ model.rhs(
+            model.full_rhs).t()
+        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
+        pred_list.append(pred_mini)
+    pred = torch.cat(pred_list, dim=0).cpu().numpy()
+    return pred
+
+
+state_dict = None
+best_val_metric = 0.0
+for epoch in range(1, args.epochs + 1):
+    train_loss = train()
+    if epoch % args.eval_epochs_interval == 0:
+        val_pred = test(loader_dict["val"], split="val", desc="Val")
+        val_metrics = task.evaluate(val_pred, task.get_table("val"))
+        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
+              f"Val metrics: {val_metrics}")
+
+        if val_metrics[tune_metric] > best_val_metric:
+            best_val_metric = val_metrics[tune_metric]
+            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+
+assert state_dict is not None
+model.load_state_dict(state_dict)
+val_pred = test(loader_dict["val"], split="val", desc="Best val")
+val_metrics = task.evaluate(val_pred, task.get_table("val"))
+print(f"Best val metrics: {val_metrics}")
+
+test_pred = test(loader_dict["test"], split="test", desc="Test")
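+# With no target table passed explicitly, task.evaluate() scores the
+# predictions against RelBench's held-out test table.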
+test_metrics = task.evaluate(test_pred)
+print(f"Best test metrics: {test_metrics}")
diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py
index d0b758b..5a14e58 100644
--- a/hybridgnn/nn/models/__init__.py
+++ b/hybridgnn/nn/models/__init__.py
@@ -3,5 +3,11 @@
 from .hybridgnn import HybridGNN
 from .shallowrhsgnn import ShallowRHSGNN
 from .rhsembeddinggnn import RHSEmbeddingGNN
+from .wmf import WeightedMatrixFactorization
 
-__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', 'RHSEmbeddingGNN']
+__all__ = classes = ['HeteroGraphSAGE',
+                     'IDGNN',
+                     'HybridGNN',
+                     'ShallowRHSGNN',
+                     'RHSEmbeddingGNN',
+                     'WeightedMatrixFactorization']
diff --git a/hybridgnn/nn/models/wmf.py b/hybridgnn/nn/models/wmf.py
new file mode 100644
index 0000000..042912b
--- /dev/null
+++ b/hybridgnn/nn/models/wmf.py
@@ -0,0 +1,50 @@
+import torch
+from torch import Tensor
+
+
+class WeightedMatrixFactorization(torch.nn.Module):
+    r"""Weighted matrix factorization for implicit feedback.
+
+    Observed (src, dst) pairs are pulled towards a score of one, while every
+    dst node absent from the batch is treated as an unobserved negative and
+    down-weighted by :obj:`w0`.
+    """
+    def __init__(
+        self,
+        num_src_nodes: int,
+        num_dst_nodes: int,
+        embedding_dim: int,
+        w0: float = 0.5,
+    ) -> None:
+        super().__init__()
+        self.lhs = torch.nn.Embedding(num_src_nodes, embedding_dim)
+        self.rhs = torch.nn.Embedding(num_dst_nodes, embedding_dim)
+        self.w0 = w0
+        self.num_src_nodes = num_src_nodes
+        self.num_dst_nodes = num_dst_nodes
+        self.register_buffer("full_rhs", torch.arange(self.num_dst_nodes))
+
+    def reset_parameters(self) -> None:
+        self.lhs.reset_parameters()
+        self.rhs.reset_parameters()
+
+    def forward(
+        self,
+        src_tensor: Tensor,
+        dst_tensor: Tensor,
+    ) -> Tensor:
+        lhs_embedding = self.lhs(src_tensor)
+        rhs_embedding = self.rhs(dst_tensor)
+
+        # Positive term: one score per observed (src, dst) pair.
+        pos_score = (lhs_embedding * rhs_embedding).sum(dim=-1)
+        pos_loss = ((1.0 - pos_score)**2).sum()
+
+        # Negative term: every dst node that does not appear in the batch
+        # forms an unobserved pair with each src node in the batch.
+        mask = ~torch.isin(self.full_rhs, dst_tensor)
+        neg_rhs = self.full_rhs[mask]
+        mat_neg = lhs_embedding @ self.rhs(neg_rhs).t()
+        neg_loss = self.w0 * (mat_neg**2).sum()
+
+        return pos_loss + neg_loss
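
Usage note: below is a minimal, self-contained sketch of how the new
`WeightedMatrixFactorization` module can be exercised on its own, outside the
RelBench plumbing in `examples/wmf_example.py`. All sizes are illustrative toy
values, not defaults from this repository.

    import torch

    from hybridgnn.nn.models import WeightedMatrixFactorization

    # Toy setup: 10 src entities, 20 dst entities, 8-dim embeddings.
    model = WeightedMatrixFactorization(num_src_nodes=10, num_dst_nodes=20,
                                        embedding_dim=8)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    # Three observed (src, dst) pairs; every dst id absent from the batch
    # contributes a w0-weighted negative term to the loss.
    src = torch.tensor([0, 1, 2])
    dst = torch.tensor([5, 7, 5])

    optimizer.zero_grad()
    loss = model(src, dst) / len(src)  # normalize by number of pairs
    loss.backward()
    optimizer.step()

    # Ranking at inference time: score all dst nodes for the seed srcs.
    scores = model.lhs(src) @ model.rhs(model.full_rhs).t()  # shape [3, 20]
    top5 = scores.topk(k=5, dim=1).indices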