From c6ba7a630c41e04e4057108830cae5262b28d713 Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Mon, 30 Sep 2024 07:26:40 +0000
Subject: [PATCH 1/7] add weighted matrix factorization model

---
 examples/wmf_example.py    | 119 +++++++++++++++++++++++++++++++++++++
 hybridgnn/nn/models/wmf.py |  40 +++++++++++++
 2 files changed, 159 insertions(+)
 create mode 100644 examples/wmf_example.py
 create mode 100644 hybridgnn/nn/models/wmf.py

diff --git a/examples/wmf_example.py b/examples/wmf_example.py
new file mode 100644
index 0000000..967dc6a
--- /dev/null
+++ b/examples/wmf_example.py
@@ -0,0 +1,119 @@
+"""Example script to train the weighted matrix factorization model.
+
+python wmf_example.py --dataset rel-trial --task site-sponsor-run
+    --epochs 10
+"""
+
+import argparse
+import json
+import os
+import warnings
+from pathlib import Path
+from typing import Dict, List, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from relbench.base import Dataset, RecommendationTask, TaskType
+from relbench.datasets import get_dataset
+from relbench.modeling.graph import (
+    get_link_train_table_input,
+    make_pkey_fkey_graph,
+)
+from relbench.modeling.loader import SparseTensor
+from relbench.modeling.utils import get_stype_proposal
+from relbench.tasks import get_task
+from torch import Tensor
+from torch_frame import stype
+from torch_frame.config.text_embedder import TextEmbedderConfig
+from torch_geometric.loader import NeighborLoader
+from torch_geometric.seed import seed_everything
+from torch_geometric.typing import NodeType
+from torch_geometric.utils.cross_entropy import sparse_cross_entropy
+from tqdm import tqdm
+
+from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
+from hybridgnn.utils import GloveTextEmbedding, RHSEmbeddingMode
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--dataset", type=str, default="rel-trial")
+parser.add_argument("--task", type=str, default="site-sponsor-run")
+parser.add_argument(
+    "--model",
+    type=str,
+    default="hybridgnn",
+    choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
+)
+parser.add_argument("--lr", type=float, default=0.001)
+parser.add_argument("--epochs", type=int, default=20)
+parser.add_argument("--eval_epochs_interval", type=int, default=1)
+parser.add_argument("--batch_size", type=int, default=512)
+parser.add_argument("--channels", type=int, default=128)
+parser.add_argument("--aggr", type=str, default="sum")
+parser.add_argument("--num_layers", type=int, default=4)
+parser.add_argument("--num_neighbors", type=int, default=128)
+parser.add_argument("--temporal_strategy", type=str, default="last")
+parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
+parser.add_argument("--num_workers", type=int, default=0)
+parser.add_argument("--seed", type=int, default=42)
+parser.add_argument("--cache_dir", type=str,
+                    default=os.path.expanduser("~/.cache/relbench_examples"))
+args = parser.parse_args()
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+if torch.cuda.is_available():
+    torch.set_num_threads(1)
+seed_everything(args.seed)
+
+
+dataset: Dataset = get_dataset(args.dataset, download=True)
+task: RecommendationTask = get_task(args.dataset, args.task, download=True)
+tune_metric = "link_prediction_map"
+assert task.task_type == TaskType.LINK_PREDICTION
+
+stypes_cache_path = Path(f"{args.cache_dir}/{args.dataset}/stypes.json")
+try:
+    with open(stypes_cache_path, "r") as f:
+        col_to_stype_dict = json.load(f)
+    for table, col_to_stype in col_to_stype_dict.items():
+        for col, stype_str in col_to_stype.items():
+            col_to_stype[col] = stype(stype_str)
+except FileNotFoundError:
+    col_to_stype_dict = get_stype_proposal(dataset.get_db())
+    Path(stypes_cache_path).parent.mkdir(parents=True, exist_ok=True)
+    with open(stypes_cache_path, "w") as f:
+        json.dump(col_to_stype_dict, f, indent=2, default=str)
+
+data, col_stats_dict = make_pkey_fkey_graph(
+    dataset.get_db(),
+    col_to_stype_dict=col_to_stype_dict,
+    text_embedder_cfg=TextEmbedderConfig(
+        text_embedder=GloveTextEmbedding(device=device), batch_size=256),
+    cache_dir=f"{args.cache_dir}/{args.dataset}/materialized",
+)
+
+num_neighbors = [
+    int(args.num_neighbors // 2**i) for i in range(args.num_layers)
+]
+
+loader_dict: Dict[str, NeighborLoader] = {}
+dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
+num_dst_nodes_dict: Dict[str, int] = {}
+for split in ["train", "val", "test"]:
+    table = task.get_table(split)
+    table_input = get_link_train_table_input(table, task)
+    dst_nodes_dict[split] = table_input.dst_nodes
+    num_dst_nodes_dict[split] = table_input.num_dst_nodes
+    loader_dict[split] = NeighborLoader(
+        data,
+        num_neighbors=num_neighbors,
+        time_attr="time",
+        input_nodes=table_input.src_nodes,
+        input_time=table_input.src_time,
+        subgraph_type="bidirectional",
+        batch_size=args.batch_size,
+        temporal_strategy=args.temporal_strategy,
+        shuffle=split == "train",
+        num_workers=args.num_workers,
+        persistent_workers=args.num_workers > 0,
+    )
diff --git a/hybridgnn/nn/models/wmf.py b/hybridgnn/nn/models/wmf.py
new file mode 100644
index 0000000..55bb137
--- /dev/null
+++ b/hybridgnn/nn/models/wmf.py
@@ -0,0 +1,40 @@
+import torch
+from typing import Any, Dict, Optional, Type
+
+import torch
+from torch import Tensor
+from torch_frame.data.stats import StatType
+from torch_frame.nn.models import ResNet
+from torch_geometric.data import HeteroData
+from torch_geometric.nn import MLP
+from torch_geometric.typing import NodeType
+
+class WeightedMatrixFactorization(torch.nn.Module):
+    def __init__(
+        self,
+        num_src_nodes: int,
+        num_dst_nodes: int,
+        embedding_dim: int,
+    ) -> None:
+        super().__init__()
+        self.rhs = torch.nn.Embedding(num_src_nodes, embedding_dim)
+        self.lhs = torch.nn.Embedding(num_dst_nodes, embedding_dim)
+        self.w0 = torch.nn.Parameter(torch.tensor(1.0))
+
+    def reset_parameters(self) -> None:
+        super().reset_parameters()
+        self.rhs.reset_parameters()
+        self.lhs.reset_parameters()
+        self.w0.reset_parameters()
+
+    def forward(
+        self,
+        batch: HeteroData,
+        entity_table: NodeType,
+        ground_truth: Tensor,
+    ) -> Dict[NodeType, Tensor]:
+        batch_size = batch[entity_table].seed_time.size(0)
+        lhs_idx = batch[entity_table].n_id[:batch_size]
+        lhs_embedding = self.lhs(lhs_idx)
+        mat = lhs_embedding @ self.rhs.t()
+        return torch.sum((1 - mat[ground_truth]) **2) + self.w0(mat[~ground_truth]**2)
\ No newline at end of file

From f991ba5886df44c5eba4eeb4d311564b4ddd2bd0 Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Tue, 1 Oct 2024 03:54:19 +0000
Subject: [PATCH 2/7] wip

---
 examples/.wmf_example.py.swp    | Bin 0 -> 16384 bytes
 examples/wmf_example.py         | 22 +++++++++-------------
 hybridgnn/nn/models/__init__.py |  8 +++++++-
 hybridgnn/nn/models/wmf.py      | 20 +++++++++++---------
 4 files changed, 27 insertions(+), 23 deletions(-)
 create mode 100644 examples/.wmf_example.py.swp

diff --git a/examples/.wmf_example.py.swp b/examples/.wmf_example.py.swp
new file mode 100644
index 0000000000000000000000000000000000000000..568a2a6e340186e2c6c57293a63a77c16f3a197d
GIT binary patch
literal 16384
[16384 bytes of base85-encoded Vim swap-file data omitted; the stray
examples/.wmf_example.py.swp is deleted again in PATCH 5/7]

literal 0
HcmV?d00001

diff --git a/examples/wmf_example.py b/examples/wmf_example.py
index 967dc6a..6c975eb 100644
--- a/examples/wmf_example.py
+++ b/examples/wmf_example.py
@@ -98,22 +98,18 @@
 loader_dict: Dict[str, NeighborLoader] = {}
 dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
+src_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
 num_dst_nodes_dict: Dict[str, int] = {}
 for split in ["train", "val", "test"]:
     table = task.get_table(split)
     table_input = get_link_train_table_input(table, task)
     dst_nodes_dict[split] = table_input.dst_nodes
+    src_nodes_dict[split] = table_input.src_nodes
     num_dst_nodes_dict[split] = table_input.num_dst_nodes
-    loader_dict[split] = NeighborLoader(
-        data,
-        num_neighbors=num_neighbors,
-        time_attr="time",
-        input_nodes=table_input.src_nodes,
-        input_time=table_input.src_time,
-        subgraph_type="bidirectional",
-        batch_size=args.batch_size,
-        temporal_strategy=args.temporal_strategy,
-        shuffle=split == "train",
-        num_workers=args.num_workers,
-        persistent_workers=args.num_workers > 0,
-    )
+
+dst_nodes = torch.cat(dst_nodes_dict["train"][1], dst_nodes_dict["val"][1])
+src_nodes = torch.cat(src_nodes_dict["train"][1], src_nodes_dict["val"][1])
+total_src_nodes = len(torch.unique(src_nodes))
+total_dst_nodes = len(torch.unique(dst_nodes))
+
+train_table = task.get_table("train").df
diff --git a/hybridgnn/nn/models/__init__.py b/hybridgnn/nn/models/__init__.py
index d0b758b..5a14e58 100644
--- a/hybridgnn/nn/models/__init__.py
+++ b/hybridgnn/nn/models/__init__.py
@@ -3,5 +3,11 @@
 from .hybridgnn import HybridGNN
 from .shallowrhsgnn import ShallowRHSGNN
 from .rhsembeddinggnn import RHSEmbeddingGNN
+from .wmf import WeightedMatrixFactorization
 
-__all__ = classes = ['HeteroGraphSAGE', 'IDGNN', 'HybridGNN', 'ShallowRHSGNN', 'RHSEmbeddingGNN']
+__all__ = classes = ['HeteroGraphSAGE',
+                     'IDGNN',
+                     'HybridGNN',
+                     'ShallowRHSGNN',
+                     'RHSEmbeddingGNN',
+                     'WeightedMatrixFactorization']
diff --git a/hybridgnn/nn/models/wmf.py b/hybridgnn/nn/models/wmf.py
index 55bb137..468f320 100644
--- a/hybridgnn/nn/models/wmf.py
+++ b/hybridgnn/nn/models/wmf.py
@@ -20,6 +20,8 @@ def __init__(
         self.rhs = torch.nn.Embedding(num_src_nodes, embedding_dim)
         self.lhs = torch.nn.Embedding(num_dst_nodes, embedding_dim)
         self.w0 = torch.nn.Parameter(torch.tensor(1.0))
+        self.num_src_nodes = num_src_nodes
+        self.num_dst_nodes = num_dst_nodes
 
     def reset_parameters(self) -> None:
         super().reset_parameters()
@@ -29,12 +31,12 @@ def reset_parameters(self) -> None:
 
     def forward(
         self,
-        batch: HeteroData,
-        entity_table: NodeType,
-        ground_truth: Tensor,
-    ) -> Dict[NodeType, Tensor]:
-        batch_size = batch[entity_table].seed_time.size(0)
-        lhs_idx = batch[entity_table].n_id[:batch_size]
-        lhs_embedding = self.lhs(lhs_idx)
-        mat = lhs_embedding @ self.rhs.t()
-        return torch.sum((1 - mat[ground_truth]) **2) + self.w0(mat[~ground_truth]**2)
\ No newline at end of file
+        src_tensor: Tensor,
+        dst_tensor: Tensor,
+    ) -> Tensor:
+        lhs_embedding = self.lhs()
+        rhs_embedding = self.rhs()
+        mask = torch.zeros(self.num_src_nodes, self.num_dst_nodes).to(src_tensor.device)
+        mask[src_tensor][dst_tensor] = 1
+        mat = lhs_embedding @ rhs_embedding.t()
+        return ((1 - mat[mask]) **2).sum() + self.w0*(mat[~mask]**2).sum()
\ No newline at end of file

From 8107b2c3db9bb6264c21dfd582e9d79f38366be0 Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Tue, 1 Oct 2024 21:37:09 +0000
Subject: [PATCH 3/7] fix

---
 examples/wmf_example.py    | 94 +++++++++++++++++++++++++++++++++++---
 hybridgnn/nn/models/wmf.py | 27 +++++++----
 2 files changed, 106 insertions(+), 15 deletions(-)

diff --git a/examples/wmf_example.py b/examples/wmf_example.py
index 6c975eb..bcfd587 100644
--- a/examples/wmf_example.py
+++ b/examples/wmf_example.py
@@ -32,7 +32,7 @@
 from torch_geometric.utils.cross_entropy import sparse_cross_entropy
 from tqdm import tqdm
 
-from hybridgnn.nn.models import IDGNN, HybridGNN, ShallowRHSGNN
+from hybridgnn.nn.models import WeightedMatrixFactorization
 from hybridgnn.utils import GloveTextEmbedding, RHSEmbeddingMode
 
 parser = argparse.ArgumentParser()
@@ -51,7 +51,8 @@
 parser.add_argument("--channels", type=int, default=128)
 parser.add_argument("--aggr", type=str, default="sum")
 parser.add_argument("--num_layers", type=int, default=4)
-parser.add_argument("--num_neighbors", type=int, default=128)
+parser.add_argument("--num_neighbors", type=int, default=16)
+parser.add_argument("--embedding_dim", type=int, default=128)
 parser.add_argument("--temporal_strategy", type=str, default="last")
 parser.add_argument("--max_steps_per_epoch", type=int, default=2000)
 parser.add_argument("--num_workers", type=int, default=0)
@@ -100,16 +101,95 @@
 dst_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
 src_nodes_dict: Dict[str, Tuple[NodeType, Tensor]] = {}
 num_dst_nodes_dict: Dict[str, int] = {}
+num_src_nodes_dict: Dict[str, int] = {}
 for split in ["train", "val", "test"]:
     table = task.get_table(split)
     table_input = get_link_train_table_input(table, task)
     dst_nodes_dict[split] = table_input.dst_nodes
     src_nodes_dict[split] = table_input.src_nodes
     num_dst_nodes_dict[split] = table_input.num_dst_nodes
+    num_src_nodes_dict[split] = len(table_input.src_nodes[1])
+    loader_dict[split] = NeighborLoader(
+        data,
+        num_neighbors=num_neighbors,
+        time_attr="time",
+        input_nodes=table_input.src_nodes,
+        input_time=table_input.src_time,
+        subgraph_type="bidirectional",
+        batch_size=args.batch_size,
+        temporal_strategy=args.temporal_strategy,
+        shuffle=split == "train",
+        num_workers=args.num_workers,
+        persistent_workers=args.num_workers > 0,
+    )
 
-dst_nodes = torch.cat(dst_nodes_dict["train"][1], dst_nodes_dict["val"][1])
-src_nodes = torch.cat(src_nodes_dict["train"][1], src_nodes_dict["val"][1])
-total_src_nodes = len(torch.unique(src_nodes))
-total_dst_nodes = len(torch.unique(dst_nodes))
 
-train_table = task.get_table("train").df
+num_src_nodes = num_src_nodes_dict["train"]
+num_dst_nodes = num_dst_nodes_dict["train"]
+print(num_src_nodes, num_dst_nodes)
+
+model = WeightedMatrixFactorization(num_src_nodes, num_dst_nodes, args.embedding_dim)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+def train() -> float:
+    model.train()
+
+    loss_accum = count_accum = 0
+    steps = 0
+    total_steps = min(len(loader_dict["train"]), args.max_steps_per_epoch)
+    sparse_tensor = SparseTensor(dst_nodes_dict["train"][1], device=device)
+    for batch in tqdm(loader_dict["train"], total=total_steps, desc="Train"):
+        batch = batch.to(device)
+
+        # Get ground-truth
+        input_id = batch[task.src_entity_table].input_id
+        src_batch, dst_index = sparse_tensor[input_id]
+
+        src_tensor = batch[task.src_entity_table].input_id
+
+        # Optimization
+        optimizer.zero_grad()
+
+        loss = model(input_id[src_batch], dst_index)
+        loss /= len(src_batch)
+
+        loss.backward()
+
+        optimizer.step()
+
+        numel = len(src_tensor)
+        loss_accum += float(loss) * numel
+        count_accum += numel
+
+        steps += 1
+        if steps > args.max_steps_per_epoch:
+            break
+
+@torch.no_grad()
+def test(loader: NeighborLoader, desc: str) -> np.ndarray:
+    model.eval()
+    pass
+
+state_dict = None
+best_val_metric = 0
+for epoch in range(1, args.epochs + 1):
+    train_loss = train()
+    if epoch % args.eval_epochs_interval == 0:
+        val_pred = test(loader_dict["val"], desc="Val")
+        val_metrics = task.evaluate(val_pred, task.get_table("val"))
+        print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
+              f"Val metrics: {val_metrics}")
+
+        if val_metrics[tune_metric] > best_val_metric:
+            best_val_metric = val_metrics[tune_metric]
+            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+
+assert state_dict is not None
+model.load_state_dict(state_dict)
+val_pred = test(loader_dict["val"], desc="Best val")
+val_metrics = task.evaluate(val_pred, task.get_table("val"))
+print(f"Best val metrics: {val_metrics}")
+
+test_pred = test(loader_dict["test"], desc="Test")
+test_metrics = task.evaluate(test_pred)
+print(f"Best test metrics: {test_metrics}")
diff --git a/hybridgnn/nn/models/wmf.py b/hybridgnn/nn/models/wmf.py
index 468f320..dd95411 100644
--- a/hybridgnn/nn/models/wmf.py
+++ b/hybridgnn/nn/models/wmf.py
@@ -17,11 +17,13 @@ def __init__(
         embedding_dim: int,
     ) -> None:
         super().__init__()
-        self.rhs = torch.nn.Embedding(num_src_nodes, embedding_dim)
-        self.lhs = torch.nn.Embedding(num_dst_nodes, embedding_dim)
+        self.rhs = torch.nn.Embedding(num_dst_nodes, embedding_dim)
+        self.lhs = torch.nn.Embedding(num_src_nodes, embedding_dim)
         self.w0 = torch.nn.Parameter(torch.tensor(1.0))
         self.num_src_nodes = num_src_nodes
         self.num_dst_nodes = num_dst_nodes
+        self.register_buffer("full_lhs", torch.arange(0, self.num_src_nodes))
+        self.register_buffer("full_rhs", torch.arange(0, self.num_dst_nodes))
 
     def reset_parameters(self) -> None:
         super().reset_parameters()
@@ -34,9 +36,18 @@ def forward(
         self,
         src_tensor: Tensor,
         dst_tensor: Tensor,
     ) -> Tensor:
-        lhs_embedding = self.lhs()
-        rhs_embedding = self.rhs()
-        mask = torch.zeros(self.num_src_nodes, self.num_dst_nodes).to(src_tensor.device)
-        mask[src_tensor][dst_tensor] = 1
-        mat = lhs_embedding @ rhs_embedding.t()
-        return ((1 - mat[mask]) **2).sum() + self.w0*(mat[~mask]**2).sum()
\ No newline at end of file
+        lhs_embedding = self.lhs(src_tensor)
+        rhs_embedding = self.rhs(dst_tensor)
+        mat_pos = lhs_embedding @ rhs_embedding.t()
+
+        mask = ~torch.isin(self.full_lhs, src_tensor)
+
+        # Filter out the values present in the first tensor
+        neg_lhs = self.full_lhs[mask]
+        mask = ~torch.isin(self.full_rhs, dst_tensor)
+        neg_rhs = self.full_rhs[mask]
+        from torch.cuda.amp import autocast
+
+        with autocast():
+            mat_neg = torch.mm(self.lhs(neg_lhs).half(), self.rhs(neg_rhs).half().t())
+        return ((1.0 - mat_pos) **2).sum() + self.w0*((mat_neg**2).sum())
\ No newline at end of file

From 03160d82219c057e15160abb84be93f278bdf772 Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Tue, 1 Oct 2024 21:40:04 +0000
Subject: [PATCH 4/7] fix

---
 hybridgnn/nn/models/wmf.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/hybridgnn/nn/models/wmf.py b/hybridgnn/nn/models/wmf.py
index dd95411..a8f3adb 100644
--- a/hybridgnn/nn/models/wmf.py
+++ b/hybridgnn/nn/models/wmf.py
@@ -46,8 +46,5 @@ def forward(
         neg_lhs = self.full_lhs[mask]
         mask = ~torch.isin(self.full_rhs, dst_tensor)
         neg_rhs = self.full_rhs[mask]
-        from torch.cuda.amp import autocast
-
-        with autocast():
-            mat_neg = torch.mm(self.lhs(neg_lhs).half(), self.rhs(neg_rhs).half().t())
+        mat_neg = torch.mm(self.lhs(neg_lhs), self.rhs(neg_rhs).t())
         return ((1.0 - mat_pos) **2).sum() + self.w0*((mat_neg**2).sum())
\ No newline at end of file

From d01c3710ad4df434d918b01ff7cb56b691eac7bf Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Tue, 1 Oct 2024 23:44:28 +0000
Subject: [PATCH 5/7] fix

---
 .../relbench_link_prediction_benchmark.py |  26 +++++++++---------
 examples/.wmf_example.py.swp              | Bin 16384 -> 0 bytes
 examples/wmf_example.py                   |   2 +-
 3 files changed, 14 insertions(+), 14 deletions(-)
 delete mode 100644 examples/.wmf_example.py.swp

diff --git a/benchmark/relbench_link_prediction_benchmark.py b/benchmark/relbench_link_prediction_benchmark.py
index cc633ef..01119df 100644
--- a/benchmark/relbench_link_prediction_benchmark.py
+++ b/benchmark/relbench_link_prediction_benchmark.py
@@ -35,25 +35,25 @@
 TRAIN_CONFIG_KEYS = ["batch_size", "gamma_rate", "base_lr"]
 LINK_PREDICTION_METRIC = "link_prediction_map"
-VAL_LOSS_DELTA = 0.001
+VAL_LOSS_DELTA = 0.0005
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="rel-amazon")
-parser.add_argument("--task", type=str, default="user-item-rate")
+parser.add_argument("--task", type=str, default="user-item-purchase")
 parser.add_argument(
     "--model",
     type=str,
-    default="hybridgnn",
+    default="idgnn",
     choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
 )
-parser.add_argument("--epochs", type=int, default=20)
-parser.add_argument("--num_trials", type=int, default=50,
+parser.add_argument("--epochs", type=int, default=10)
+parser.add_argument("--num_trials", type=int, default=10,
                     help="Number of Optuna-based hyper-parameter tuning.")
 parser.add_argument(
     "--num_repeats", type=int, default=5,
     help="Number of repeated training and eval on the best config.")
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
-parser.add_argument("--num_layers", type=int, default=2)
+parser.add_argument("--num_layers", type=int, default=6)
 parser.add_argument("--num_neighbors", type=int, default=128)
parser.add_argument("--temporal_strategy", type=str, default="last", choices=["last", "uniform"]) @@ -107,9 +107,9 @@ if args.model == "idgnn": model_search_space = { - "encoder_channels": [64, 128, 256], + "encoder_channels": [64, 128], "encoder_layers": [2, 4, 8], - "channels": [64, 128, 256], + "channels": [64, 128], "norm": ["layer_norm", "batch_norm"] } train_search_space = { @@ -120,10 +120,10 @@ model_cls = IDGNN elif args.model in ["hybridgnn", "shallowrhsgnn"]: model_search_space = { - "encoder_channels": [32, 64, 128, 256, 512], - "encoder_layers": [2, 4, 8], - "channels": [32, 64, 128, 256, 512], - "embedding_dim": [32, 64, 128, 256, 512], + "encoder_channels": [32, 64], + "encoder_layers": [2, 4], + "channels": [32, 64, 128], + "embedding_dim": [32, 64], "norm": ["layer_norm", "batch_norm"], "rhs_emb_mode": [ RHSEmbeddingMode.FUSION, RHSEmbeddingMode.FEATURE, @@ -131,7 +131,7 @@ ] } train_search_space = { - "batch_size": [256, 512, 1024], + "batch_size": [32, 64, 128], "base_lr": [0.001, 0.01], "gamma_rate": [0.8, 1.], } diff --git a/examples/.wmf_example.py.swp b/examples/.wmf_example.py.swp deleted file mode 100644 index 568a2a6e340186e2c6c57293a63a77c16f3a197d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 16384 zcmeHNTZm*w87@~%G;xz)9(*lNKu*W*(>=3Uvpdj1&Fp4}Y{q4GR!D~Nr23rd?vpvU zRMnZ8PTc65FA9SABqBzi#0Q^*h=Om*iVq?vUXtMD(IEIDUIKnqbxu!T*8~w1s^RPI zy8M4teSg(ob^bbo-Y0jSq8EY-Fa2xmm@KxXdI0u{s z-Z)Fh&w-x-UjxR#Ht=EKEb!C!5%Lw_Ip8tiZ}$@N2jKU>%K!tO1^U37_Ym?5@J-;$ zz~_J*xClH75a9NEkq=x29t3{-9zwnWaDW28xtox$0~zp9;4gO(@=M?sz!w1l+yu6O zi@*iopYJB*JHY3GXMm@H2Y`El*WX3RcY#j>SAY)yZz5pwL*NBq7q}m|5BMDdEk6NX z22Owg_%pbF1Nal5?eHt$KmX)a7Unz|jR%F0v`_a5X%UJ~#G)||<3#j_-iu6(WH8OS z)wLNhQ%rJotD7;!MZ(hL1s;2CBd%2_exgF24R{=L(SIT>j`$w3E}5JpANesqN+RC3 zKdm+$FcpnM2{rxA^B-whO>>6|vz)um!e8fmbRu|pn~hkOza_18$Sa_nPz)Su^Mkq2>=mHu)=bl9dz z9`juFH`_GkLsq4#FO{(L9w%x{i<0NQ>8=bdyCqZOHua9Z7G;tSw@g#lkhxr~6_iXs zeUKf(O1@3V{+_6KoAQ%Hs<1dT$=aZc>|lJF083XXin5R_^%uEm5q=W!Qqjvv%6AKO zxv28^k`RTkool$5WndqP1E@EuZO_q6J#sBDEQBVzZENT=z`8XVP!kMs$6} z1}R21+s>kpG26VPZf#biv8*v&TCmt3z7>OrsX>&OY>_3KVX($Z|9-7l(LlHFwcxd+ z?BRcKb!Yd<@Y?lD+dGf%?R;XFs4C}S#+67SwCbfvei)X5 z>-nA(d6==%BU;lIz0RW|%XkixqC_jKO|FZ5v!-inx2Z|#e9YD2IM>Y7bEmGdMYrv8 zuXWVf3u+c*36Y$~AwS|`qQ-j3n|5@jWnC2_8V7J9p>{-SEI5m$ztJMeklM1Sidm4z zkR7ok)l)#v3F|S?6I;h7eNWf*=$%9K zG>f>9sH?^SKPg!rSCI9+m%7$Jb_GudJck==-dOWmEmF6@;#ljiU=!tgozB#wdTpxh zuH8d>U1d#LQ!mNYsZV)ib8Q2f)Z;=Nav@iW_SQ;f>;%h_mohB52zgmVC0Oik{zCZ)JW2ij9e6i1xTZdZiMqt2HEB>6Jp8>+PFP z85+W2%uLB%(CeYGdgB=v*yUw^-t6g7Hb74bSEugnAjx&WNaJN3+UNyP15<|f<9z^C0 z8O$9t8`G`oY@>!-%i$1pu57wQ-bf~5KBuzY4)CpOQ* z4P%@TgYvs7tNSL7tvXR4({3)7mBbB0ZInNN}oG1%;QR zYRqZlFmx)BP=^PC1guU1qM!ed@C<$%(9i#PTsNQZXr3^)cH1C9a5fMdWh;23ZWI0hU8j)DIc13H=-<7y~Jz(vQtQ`|t= zC+8|pp0Bt+EkxLQ`hOZ}kC-~GVeV4)Jrg28+>TDG*9AHZd4S-Se!iV+*OtpNO_aVS znu#t1bM=hmg4UH%(>g|CMjy1V4`ulEm6Co~uCF6>oWt}8*D*FKqi-Bcn=PVmq0C=> z@6?PN0Zbci^%O@K`(TUA4Z#c*cPEG^?3+j<0E^=^CoOXym*`OmmXG*pRU0{Pk$(Zn10+cR diff --git a/examples/wmf_example.py b/examples/wmf_example.py index bcfd587..70fd626 100644 --- a/examples/wmf_example.py +++ b/examples/wmf_example.py @@ -128,7 +128,7 @@ num_dst_nodes = num_dst_nodes_dict["train"] print(num_src_nodes, num_dst_nodes) -model = WeightedMatrixFactorization(num_src_nodes, num_dst_nodes, args.embedding_dim) +model = WeightedMatrixFactorization(num_src_nodes, num_dst_nodes, args.embedding_dim).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) def train() -> float: From baa6f9225fb35d3a0ef33cf661ad7eb1cff9a895 Mon Sep 
From baa6f9225fb35d3a0ef33cf661ad7eb1cff9a895 Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Tue, 1 Oct 2024 23:56:40 +0000
Subject: [PATCH 6/7] fix

---
 examples/wmf_example.py    | 35 ++++++++++++++++++++---------------
 hybridgnn/nn/models/wmf.py | 10 +++++-----
 2 files changed, 25 insertions(+), 20 deletions(-)

diff --git a/examples/wmf_example.py b/examples/wmf_example.py
index bcfd587..a977cf4 100644
--- a/examples/wmf_example.py
+++ b/examples/wmf_example.py
@@ -38,12 +38,6 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--dataset", type=str, default="rel-trial")
 parser.add_argument("--task", type=str, default="site-sponsor-run")
-parser.add_argument(
-    "--model",
-    type=str,
-    default="hybridgnn",
-    choices=["hybridgnn", "idgnn", "shallowrhsgnn"],
-)
 parser.add_argument("--lr", type=float, default=0.001)
 parser.add_argument("--epochs", type=int, default=20)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
@@ -126,7 +120,6 @@
 
 num_src_nodes = num_src_nodes_dict["train"]
 num_dst_nodes = num_dst_nodes_dict["train"]
-print(num_src_nodes, num_dst_nodes)
 
 model = WeightedMatrixFactorization(num_src_nodes, num_dst_nodes, args.embedding_dim)
 optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -145,8 +138,6 @@ def train() -> float:
         input_id = batch[task.src_entity_table].input_id
         src_batch, dst_index = sparse_tensor[input_id]
 
-        src_tensor = batch[task.src_entity_table].input_id
-
         # Optimization
         optimizer.zero_grad()
 
@@ -157,25 +148,39 @@ def train() -> float:
 
         optimizer.step()
 
-        numel = len(src_tensor)
+        numel = len(src_batch)
         loss_accum += float(loss) * numel
         count_accum += numel
 
         steps += 1
         if steps > args.max_steps_per_epoch:
             break
+    return loss_accum / count_accum if count_accum > 0 else float("nan")
 
 @torch.no_grad()
-def test(loader: NeighborLoader, desc: str) -> np.ndarray:
+def test(loader: NeighborLoader, desc: str, sparse_tensor) -> np.ndarray:
     model.eval()
-    pass
+
+    pred_list: List[Tensor] = []
+    for batch in tqdm(loader, desc=desc):
+        batch = batch.to(device)
+        input_id = batch[task.src_entity_table].input_id
+        scores = model.lhs(input_id) @ model.rhs(model.full_rhs).t()
+
+        _, pred_mini = torch.topk(scores, k=task.eval_k, dim=1)
+        pred_list.append(pred_mini)
+    pred = torch.cat(pred_list, dim=0).cpu().numpy()
+    return pred
+
 
 state_dict = None
 best_val_metric = 0
+val_sparse_tensor = SparseTensor(dst_nodes_dict["val"][1], device=device)
+test_sparse_tensor = SparseTensor(dst_nodes_dict["test"][1], device=device)
 for epoch in range(1, args.epochs + 1):
     train_loss = train()
     if epoch % args.eval_epochs_interval == 0:
-        val_pred = test(loader_dict["val"], desc="Val")
+        val_pred = test(loader_dict["val"], desc="Val", sparse_tensor=val_sparse_tensor)
         val_metrics = task.evaluate(val_pred, task.get_table("val"))
         print(f"Epoch: {epoch:02d}, Train loss: {train_loss}, "
               f"Val metrics: {val_metrics}")
@@ -186,10 +191,10 @@ def test(loader: NeighborLoader, desc: str) -> np.ndarray:
 
 assert state_dict is not None
 model.load_state_dict(state_dict)
-val_pred = test(loader_dict["val"], desc="Best val")
+val_pred = test(loader_dict["val"], desc="Best val", sparse_tensor=val_sparse_tensor)
 val_metrics = task.evaluate(val_pred, task.get_table("val"))
 print(f"Best val metrics: {val_metrics}")
 
-test_pred = test(loader_dict["test"], desc="Test")
+test_pred = test(loader_dict["test"], desc="Test", sparse_tensor=test_sparse_tensor)
 test_metrics = task.evaluate(test_pred)
 print(f"Best test metrics: {test_metrics}")
diff --git a/hybridgnn/nn/models/wmf.py b/hybridgnn/nn/models/wmf.py
index a8f3adb..042912b 100644
--- a/hybridgnn/nn/models/wmf.py
+++ b/hybridgnn/nn/models/wmf.py
@@ -15,11 +15,12 @@ def __init__(
         num_src_nodes: int,
         num_dst_nodes: int,
         embedding_dim: int,
+        w0: float = 0.5,
     ) -> None:
         super().__init__()
         self.rhs = torch.nn.Embedding(num_dst_nodes, embedding_dim)
         self.lhs = torch.nn.Embedding(num_src_nodes, embedding_dim)
-        self.w0 = torch.nn.Parameter(torch.tensor(1.0))
+        self.w0 = w0
         self.num_src_nodes = num_src_nodes
         self.num_dst_nodes = num_dst_nodes
         self.register_buffer("full_lhs", torch.arange(0, self.num_src_nodes))
@@ -29,7 +30,6 @@ def reset_parameters(self) -> None:
         super().reset_parameters()
         self.rhs.reset_parameters()
         self.lhs.reset_parameters()
-        self.w0.reset_parameters()
 
     def forward(
         self,
@@ -40,11 +40,11 @@ def forward(
         rhs_embedding = self.rhs(dst_tensor)
         mat_pos = lhs_embedding @ rhs_embedding.t()
 
-        mask = ~torch.isin(self.full_lhs, src_tensor)
+        #mask = ~torch.isin(self.full_lhs, src_tensor)
 
         # Filter out the values present in the first tensor
-        neg_lhs = self.full_lhs[mask]
+        #neg_lhs = self.full_lhs[mask]
         mask = ~torch.isin(self.full_rhs, dst_tensor)
         neg_rhs = self.full_rhs[mask]
-        mat_neg = torch.mm(self.lhs(neg_lhs), self.rhs(neg_rhs).t())
+        mat_neg = torch.mm(lhs_embedding, self.rhs(neg_rhs).t())
         return ((1.0 - mat_pos) **2).sum() + self.w0*((mat_neg**2).sum())
\ No newline at end of file

From e5209ae2e3489369e58b5e264be13417c27d0d21 Mon Sep 17 00:00:00 2001
From: Yiwen Yuan
Date: Tue, 1 Oct 2024 23:56:46 +0000
Subject: [PATCH 7/7] fix

---
 examples/wmf_example.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/wmf_example.py b/examples/wmf_example.py
index a977cf4..f715b73 100644
--- a/examples/wmf_example.py
+++ b/examples/wmf_example.py
@@ -39,7 +39,7 @@
 parser.add_argument("--dataset", type=str, default="rel-trial")
 parser.add_argument("--task", type=str, default="site-sponsor-run")
 parser.add_argument("--lr", type=float, default=0.001)
-parser.add_argument("--epochs", type=int, default=20)
+parser.add_argument("--epochs", type=int, default=3)
 parser.add_argument("--eval_epochs_interval", type=int, default=1)
 parser.add_argument("--batch_size", type=int, default=512)
 parser.add_argument("--channels", type=int, default=128)
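
Supplementary note on the loss (not part of any patch above): the final
version of hybridgnn/nn/models/wmf.py minimizes a weighted matrix
factorization objective of the form

    L = sum over (i, j) in B_src x B_dst of (1 - u_i . v_j)^2
        + w0 * sum over (i, j) in B_src x (D \ B_dst) of (u_i . v_j)^2

where u_i and v_j are rows of the lhs and rhs embedding tables, B_src and
B_dst are the index tensors passed to forward(), and D is the full
destination index range. Note that every (src, dst) combination inside the
batch is scored as observed, not only the matched pairs.

Below is a minimal standalone sketch of driving the final model; the toy
sizes and the random index tensors are assumptions for illustration, not
values taken from the patches:

    import torch

    from hybridgnn.nn.models import WeightedMatrixFactorization

    # Toy problem sizes, assumed for illustration only.
    num_src_nodes, num_dst_nodes, embedding_dim = 100, 50, 16

    model = WeightedMatrixFactorization(num_src_nodes, num_dst_nodes,
                                        embedding_dim, w0=0.5)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # A synthetic batch of observed (src, dst) index pairs.
    src_tensor = torch.randint(0, num_src_nodes, (8,))
    dst_tensor = torch.randint(0, num_dst_nodes, (8,))

    optimizer.zero_grad()
    # forward() returns the scalar loss directly: the squared error of the
    # batch score block against 1, plus w0 times the squared scores of the
    # batch sources against every destination outside dst_tensor.
    loss = model(src_tensor, dst_tensor)
    loss.backward()
    optimizer.step()

Keeping w0 well below 1 (the series moves from a learnable weight to a
fixed 0.5) stops the many unobserved entries from dominating the gradient.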