From f610d00009088d349db800f06e0d4b6de2460213 Mon Sep 17 00:00:00 2001 From: yukew1998 Date: Thu, 19 Dec 2024 17:05:29 -0500 Subject: [PATCH] add TransF implementation --- CHANGELOG.md | 1 + README.md | 2 +- examples/kge_fb15k_237.py | 4 +- test/nn/kge/test_transf.py | 25 +++++++ torch_geometric/nn/kge/__init__.py | 2 + torch_geometric/nn/kge/transf.py | 105 +++++++++++++++++++++++++++++ 6 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 test/nn/kge/test_transf.py create mode 100644 torch_geometric/nn/kge/transf.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..4ac770c9710f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722)) - Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737)) - Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467)) +- Added the `TransF` KGE model ([#9858](https://github.com/pyg-team/pytorch_geometric/pull/9858)) ### Changed diff --git a/README.md b/README.md index f96d720837d1..7fa81c5916c1 100644 --- a/README.md +++ b/README.md @@ -279,7 +279,7 @@ Unlike simple stacking of GNN layers, these models could involve pre-processing, - **[ComplEx](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.ComplEx.html)** from Trouillon *et al.*: [Complex Embeddings for Simple Link Prediction](https://arxiv.org/abs/1606.06357) (ICML 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] - **[DistMult](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.DistMult.html)** from Yang *et al.*: [Embedding Entities and Relations 
for Learning and Inference in Knowledge Bases](https://arxiv.org/abs/1412.6575) (ICLR 2015) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] - **[RotatE](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.RotatE.html)** from Sun *et al.*: [RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space](https://arxiv.org/abs/1902.10197) (ICLR 2019) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] - +- **[TransF](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.TransF.html)** from Feng *et al.*: [Knowledge Graph Embedding by Flexible Translation](https://cdn.aaai.org/ocs/12887/12887-57589-1-PB.pdf) (KR 2016) \[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\] **GNN operators and utilities:** diff --git a/examples/kge_fb15k_237.py b/examples/kge_fb15k_237.py index 144074c58df1..e6e9115aaa17 100644 --- a/examples/kge_fb15k_237.py +++ b/examples/kge_fb15k_237.py @@ -5,13 +5,14 @@ import torch.optim as optim from torch_geometric.datasets import FB15k_237 -from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE +from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE, TransF model_map = { 'transe': TransE, 'complex': ComplEx, 'distmult': DistMult, 'rotate': RotatE, + 'transf': TransF } parser = argparse.ArgumentParser() @@ -47,6 +48,7 @@ 'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6), 'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6), 'rotate': optim.Adam(model.parameters(), lr=1e-3), + 'transf': optim.Adam(model.parameters(), lr=0.01), } optimizer = optimizer_map[args.model] diff --git a/test/nn/kge/test_transf.py b/test/nn/kge/test_transf.py new file mode 100644 index 000000000000..ce977682d8dd --- /dev/null +++ b/test/nn/kge/test_transf.py @@ -0,0 +1,25 @@ +import torch 
+
+from torch_geometric.nn import TransF
+
+
+def test_transf():
+    model = TransF(num_nodes=10, num_relations=5, hidden_channels=32)
+    assert str(model) == 'TransF(10, num_relations=5, hidden_channels=32)'
+
+    head_index = torch.tensor([0, 2, 4, 6, 8])
+    rel_type = torch.tensor([0, 1, 2, 3, 4])
+    tail_index = torch.tensor([1, 3, 5, 7, 9])
+
+    loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
+    for h, r, t in loader:
+        out = model(h, r, t)
+        assert out.size() == (5, )
+
+        loss = model.loss(h, r, t)
+        assert loss >= 0.
+
+        mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
+        assert 0 <= mean_rank <= 10
+        assert 0 < mrr <= 1
+        assert hits == 1.0
diff --git a/torch_geometric/nn/kge/__init__.py b/torch_geometric/nn/kge/__init__.py
index 1f7fe6bc0e95..9d365ff80f86 100644
--- a/torch_geometric/nn/kge/__init__.py
+++ b/torch_geometric/nn/kge/__init__.py
@@ -5,6 +5,7 @@
 from .complex import ComplEx
 from .distmult import DistMult
 from .rotate import RotatE
+from .transf import TransF
 
 __all__ = classes = [
     'KGEModel',
@@ -12,4 +13,5 @@
     'ComplEx',
     'DistMult',
     'RotatE',
+    'TransF',
 ]
diff --git a/torch_geometric/nn/kge/transf.py b/torch_geometric/nn/kge/transf.py
new file mode 100644
index 000000000000..7daf67814c67
--- /dev/null
+++ b/torch_geometric/nn/kge/transf.py
@@ -0,0 +1,130 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from torch_geometric.nn.kge import KGEModel
+
+
+class TransF(KGEModel):
+    r"""The TransF model from the `"Knowledge Graph Embedding by Flexible
+    Translation" <https://cdn.aaai.org/ocs/12887/12887-57589-1-PB.pdf>`_
+    paper.
+
+    :class:`TransF` introduces a flexible translation mechanism by dynamically
+    scaling the relation vector based on the head, relation and tail
+    embeddings, resulting in:
+
+    .. math::
+        \mathbf{e}_h + f(\mathbf{e}_h, \mathbf{e}_t, \mathbf{e}_r) \cdot
+        \mathbf{e}_r \approx \mathbf{e}_t
+
+    where :math:`f` is a dynamic scaling function:
+
+    .. math::
+        f(\mathbf{e}_h, \mathbf{e}_t, \mathbf{e}_r) = \sigma((\mathbf{e}_h
+        \odot \mathbf{e}_t) \cdot \mathbf{e}_r)
+
+    This results in the scoring function:
+
+    .. math::
+        d(h, r, t) = - \| \mathbf{e}_h + f(\mathbf{e}_h, \mathbf{e}_t,
+        \mathbf{e}_r) \cdot \mathbf{e}_r - \mathbf{e}_t \|_p
+
+    .. note::
+        For an example of using the :class:`TransF` model, see
+        `examples/kge_fb15k_237.py
+        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
+        kge_fb15k_237.py>`_.
+
+    Args:
+        num_nodes (int): The number of nodes/entities in the graph.
+        num_relations (int): The number of relations in the graph.
+        hidden_channels (int): The hidden embedding size.
+        margin (float, optional): The margin of the ranking loss.
+            (default: :obj:`1.0`)
+        p_norm (float, optional): The order embedding and distance
+            normalization. (default: :obj:`1.0`)
+        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to
+            the embedding matrices will be sparse. (default: :obj:`False`)
+    """
+    def __init__(
+        self,
+        num_nodes: int,
+        num_relations: int,
+        hidden_channels: int,
+        margin: float = 1.0,
+        p_norm: float = 1.0,
+        sparse: bool = False,
+    ):
+        super().__init__(num_nodes, num_relations, hidden_channels, sparse)
+
+        self.p_norm = p_norm
+        self.margin = margin
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        r"""Resets all learnable parameters of the module."""
+        bound = 6. / math.sqrt(self.hidden_channels)
+        torch.nn.init.uniform_(self.node_emb.weight, -bound, bound)
+        torch.nn.init.uniform_(self.rel_emb.weight, -bound, bound)
+        # Relation embeddings are normalized in-place once here; entity
+        # embeddings are normalized on every call to `forward` instead:
+        F.normalize(self.rel_emb.weight.data, p=self.p_norm, dim=-1,
+                    out=self.rel_emb.weight.data)
+
+    def forward(
+        self,
+        head_index: Tensor,
+        rel_type: Tensor,
+        tail_index: Tensor,
+    ) -> Tensor:
+        r"""Returns the score of each given triplet (higher is better).
+
+        Args:
+            head_index (torch.Tensor): The head indices.
+            rel_type (torch.Tensor): The relation type.
+            tail_index (torch.Tensor): The tail indices.
+        """
+        head = self.node_emb(head_index)
+        rel = self.rel_emb(rel_type)
+        tail = self.node_emb(tail_index)
+
+        head = F.normalize(head, p=self.p_norm, dim=-1)
+        tail = F.normalize(tail, p=self.p_norm, dim=-1)
+
+        # Scaling factor sigma((e_h * e_t) . e_r) from the class docstring.
+        # `keepdim=True` keeps a trailing dimension of size one so that it
+        # broadcasts against `rel`:
+        scaling_factor = torch.sigmoid(
+            (head * tail * rel).sum(dim=-1, keepdim=True))
+        adjusted_rel = scaling_factor * rel
+
+        # Calculate *negative* TransF norm:
+        return -((head + adjusted_rel) - tail).norm(p=self.p_norm, dim=-1)
+
+    def loss(
+        self,
+        head_index: Tensor,
+        rel_type: Tensor,
+        tail_index: Tensor,
+    ) -> Tensor:
+        r"""Returns the margin ranking loss of the given triplets against
+        the triplets produced by :meth:`random_sample`.
+
+        Args:
+            head_index (torch.Tensor): The head indices.
+            rel_type (torch.Tensor): The relation type.
+            tail_index (torch.Tensor): The tail indices.
+        """
+        pos_score = self(head_index, rel_type, tail_index)
+        neg_score = self(*self.random_sample(head_index, rel_type, tail_index))
+
+        return F.margin_ranking_loss(
+            pos_score,
+            neg_score,
+            target=torch.ones_like(pos_score),
+            margin=self.margin,
+        )