Skip to content

Commit

Permalink
Add option for relation to have different hidden channels
Browse files Browse the repository at this point in the history
  • Loading branch information
jesstoh committed Dec 15, 2024
1 parent 36d5d54 commit ee4ea1d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
6 changes: 4 additions & 2 deletions test/nn/kge/test_transd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@


def test_transe():
model = TransD(num_nodes=10, num_relations=5, hidden_channels=32)
assert str(model) == 'TransD(10, num_relations=5, hidden_channels=32)'
model = TransD(num_nodes=10, num_relations=5, hidden_channels=32,
rel_hidden_channels=64)
assert str(model) == ('TransD(10, num_relations=5, hidden_channels=32, '
'rel_hidden_channels=64)')

head_index = torch.tensor([0, 2, 4, 6, 8])
rel_type = torch.tensor([0, 1, 2, 3, 4])
Expand Down
28 changes: 26 additions & 2 deletions torch_geometric/nn/kge/transd.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class TransD(KGEModel):
(default: :obj:`1.0`)
sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the
embedding matrices will be sparse. (default: :obj:`False`)
        rel_hidden_channels (int, optional): The hidden embedding size of
            relations. If not provided, relation embeddings use the same
            size as :obj:`hidden_channels`. (default: :obj:`None`)
"""
def __init__(
self,
Expand All @@ -59,15 +62,22 @@ def __init__(
margin: float = 1.0,
p_norm: float = 1.0,
sparse: bool = False,
rel_hidden_channels: int = None,
):
super().__init__(num_nodes, num_relations, hidden_channels, sparse)

self.p_norm = p_norm
self.margin = margin
self.rel_hidden_channels = hidden_channels

if rel_hidden_channels is not None:
self.rel_hidden_channels = rel_hidden_channels
self.rel_emb = Embedding(num_relations, rel_hidden_channels,
sparse=sparse)

self.node_proj_emb = Embedding(num_nodes, hidden_channels,
sparse=sparse)
self.rel_proj_emb = Embedding(num_relations, hidden_channels,
self.rel_proj_emb = Embedding(num_relations, self.rel_hidden_channels,
sparse=sparse)

self.reset_parameters()
Expand Down Expand Up @@ -120,6 +130,12 @@ def loss(
margin=self.margin,
)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.num_nodes}, '
f'num_relations={self.num_relations}, '
f'hidden_channels={self.hidden_channels}, '
f'rel_hidden_channels={self.rel_hidden_channels})')


def project_vector_ops(
node_emb: Tensor,
Expand All @@ -141,5 +157,13 @@ def project_vector_ops(
        Tensor: The node embedding projected into the relation embedding space.
"""
out = (node_emb * node_proj_emb).sum(dim=-1).reshape(-1, 1)
out = out * rel_proj_emb + node_emb
out = out * rel_proj_emb

# Efficient identity matrix and node_emb multiplication
node_dim, rel_dim = node_emb.size(1), rel_proj_emb.size(1)
dim = min(node_dim, rel_dim)
node_identity = torch.zeros_like(rel_proj_emb)
node_identity[:, :dim] = node_emb[:, :dim]

out += node_identity
return out

0 comments on commit ee4ea1d

Please sign in to comment.