Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CS224W: TransH model implementation for KGE #9868

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
- Added `TransH` implementation to kge models ([#9868](https://github.com/pyg-team/pytorch_geometric/pull/9868))

### Changed

Expand Down
4 changes: 3 additions & 1 deletion examples/kge_fb15k_237.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import torch.optim as optim

from torch_geometric.datasets import FB15k_237
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE
from torch_geometric.nn import ComplEx, DistMult, RotatE, TransE, TransH

model_map = {
'transe': TransE,
'complex': ComplEx,
'distmult': DistMult,
'rotate': RotatE,
'transh': TransH,
}

parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -47,6 +48,7 @@
'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6),
'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),
'rotate': optim.Adam(model.parameters(), lr=1e-3),
'transh': optim.Adam(model.parameters(), lr=0.01),
}
optimizer = optimizer_map[args.model]

Expand Down
25 changes: 25 additions & 0 deletions test/nn/kge/test_transh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch

from torch_geometric.nn import TransH


def test_transh():
model = TransH(num_nodes=10, num_relations=5, hidden_channels=32)
assert str(model) == 'TransH(10, num_relations=5, hidden_channels=32)'

head_index = torch.tensor([0, 2, 4, 6, 8])
rel_type = torch.tensor([0, 1, 2, 3, 4])
tail_index = torch.tensor([1, 3, 5, 7, 9])

loader = model.loader(head_index, rel_type, tail_index, batch_size=5)
for h, r, t in loader:
out = model(h, r, t)
assert out.size() == (5, )

loss = model.loss(h, r, t)
assert loss >= 0.

mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)
assert 0 <= mean_rank <= 10
assert 0 < mrr <= 1
assert hits == 1.0
Loading
Loading