From ee4ea1da2a428fb96800aa44683a408e3da78b71 Mon Sep 17 00:00:00 2001 From: jesstoh Date: Sat, 14 Dec 2024 17:15:14 -0800 Subject: [PATCH] Add option for relation to have different hidden channels --- test/nn/kge/test_transd.py | 6 ++++-- torch_geometric/nn/kge/transd.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/test/nn/kge/test_transd.py b/test/nn/kge/test_transd.py index 56d2716211d9..607aeb96f695 100644 --- a/test/nn/kge/test_transd.py +++ b/test/nn/kge/test_transd.py @@ -4,8 +4,10 @@ def test_transe(): - model = TransD(num_nodes=10, num_relations=5, hidden_channels=32) - assert str(model) == 'TransD(10, num_relations=5, hidden_channels=32)' + model = TransD(num_nodes=10, num_relations=5, hidden_channels=32, + rel_hidden_channels=64) + assert str(model) == ('TransD(10, num_relations=5, hidden_channels=32, ' + 'rel_hidden_channels=64)') head_index = torch.tensor([0, 2, 4, 6, 8]) rel_type = torch.tensor([0, 1, 2, 3, 4]) diff --git a/torch_geometric/nn/kge/transd.py b/torch_geometric/nn/kge/transd.py index 1867cd4bd554..e49b543b922b 100644 --- a/torch_geometric/nn/kge/transd.py +++ b/torch_geometric/nn/kge/transd.py @@ -50,6 +50,9 @@ class TransD(KGEModel): (default: :obj:`1.0`) sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the embedding matrices will be sparse. (default: :obj:`False`) + rel_hidden_channels (int, optional): The hidden embedding size of + relation if provided. Otherwise, relation embedding size has the + same size as :obj:`hidden_channels`. (default: :obj:`None`) """ def __init__( self, @@ -59,15 +62,22 @@ def __init__( margin: float = 1.0, p_norm: float = 1.0, sparse: bool = False, + rel_hidden_channels: int = None, ): super().__init__(num_nodes, num_relations, hidden_channels, sparse) self.p_norm = p_norm self.margin = margin + self.rel_hidden_channels = hidden_channels + + if rel_hidden_channels is not None: + self.rel_hidden_channels = rel_hidden_channels + self.rel_emb = Embedding(num_relations, rel_hidden_channels, + sparse=sparse) self.node_proj_emb = Embedding(num_nodes, hidden_channels, sparse=sparse) - self.rel_proj_emb = Embedding(num_relations, hidden_channels, + self.rel_proj_emb = Embedding(num_relations, self.rel_hidden_channels, sparse=sparse) self.reset_parameters() @@ -120,6 +130,12 @@ def loss( margin=self.margin, ) + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.num_nodes}, ' + f'num_relations={self.num_relations}, ' + f'hidden_channels={self.hidden_channels}, ' + f'rel_hidden_channels={self.rel_hidden_channels})') + def project_vector_ops( node_emb: Tensor, @@ -141,5 +157,13 @@ def project_vector_ops( Tensor: The projected vector of node to relation embedding space. """ out = (node_emb * node_proj_emb).sum(dim=-1).reshape(-1, 1) - out = out * rel_proj_emb + node_emb + out = out * rel_proj_emb + + # Efficient identity matrix and node_emb multiplication + node_dim, rel_dim = node_emb.size(1), rel_proj_emb.size(1) + dim = min(node_dim, rel_dim) + node_identity = torch.zeros_like(rel_proj_emb) + node_identity[:, :dim] = node_emb[:, :dim] + + out += node_identity return out