-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
779 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import argparse | ||
import os.path as osp | ||
import time | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.datasets import Planetoid | ||
from torch_geometric.logging import init_wandb, log | ||
from torch_geometric.nn import HGCNConv | ||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--dataset', type=str, default='Cora') | ||
parser.add_argument('--hidden_channels', type=int, default=128) | ||
parser.add_argument('--heads', type=int, default=8) | ||
parser.add_argument('--lr', type=float, default=0.1e-3) | ||
parser.add_argument('--epochs', type=int, default=100) | ||
parser.add_argument('--wandb', action='store_true', help='Track experiment') | ||
args = parser.parse_args() | ||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
init_wandb(name=f'HGCN-{args.dataset}', heads=args.heads, epochs=args.epochs, | ||
hidden_channels=args.hidden_channels, lr=args.lr, device=device) | ||
|
||
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') | ||
dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures()) | ||
data = dataset[0].to(device) | ||
|
||
|
||
class LinearDecoder(torch.nn.Module): | ||
"""MLP Decoder for Hyperbolic/Euclidean node classification models. | ||
""" | ||
def __init__(self, manifold, in_dim, out_dim, c): | ||
super(LinearDecoder, self).__init__() | ||
self.manifold = manifold | ||
self.input_dim = in_dim | ||
self.output_dim = out_dim | ||
self.cls = torch.nn.Linear(self.input_dim, self.output_dim) | ||
self.c = torch.nn.Parameter(torch.FloatTensor([c])) | ||
|
||
def forward(self, x): | ||
x = self.manifold.proj_tan0(self.manifold.logmap0(x, c=self.c), | ||
c=self.c) | ||
x = self.cls(x) | ||
return x | ||
|
||
def extra_repr(self): | ||
return 'in_features={}, out_features={}, c={}'.format( | ||
self.input_dim, self.output_dim, self.c) | ||
|
||
|
||
class HGCN(torch.nn.Module): | ||
def __init__(self, in_channels, hidden_channels, out_channels): | ||
super().__init__() | ||
self.conv1 = HGCNConv(in_channels, hidden_channels, dropout=0.5) | ||
self.decoder = LinearDecoder(self.conv1.manifold, hidden_channels, | ||
out_channels, 1.0) | ||
|
||
def forward(self, x, edge_index): | ||
# print(x.shape, edge_index.shape) | ||
x = self.conv1(x, edge_index) | ||
x = F.dropout(x, p=0.6, training=self.training) | ||
x = self.decoder(x) | ||
x = F.log_softmax(x, dim=-1) | ||
# print(x.shape) | ||
# exit(0) | ||
return x | ||
|
||
|
||
model = HGCN(dataset.num_features, args.hidden_channels, | ||
dataset.num_classes).to(device) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3) | ||
print(model) | ||
|
||
|
||
def train(): | ||
model.train() | ||
optimizer.zero_grad() | ||
out = model(data.x, data.edge_index) | ||
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) | ||
loss.backward() | ||
optimizer.step() | ||
return float(loss) | ||
|
||
|
||
@torch.no_grad() | ||
def test(): | ||
model.eval() | ||
pred = model(data.x, data.edge_index).argmax(dim=-1) | ||
|
||
accs = [] | ||
for mask in [data.train_mask, data.val_mask, data.test_mask]: | ||
accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())) | ||
return accs | ||
|
||
|
||
times = [] | ||
best_val_acc = final_test_acc = 0 | ||
for epoch in range(1, args.epochs + 1): | ||
start = time.time() | ||
loss = train() | ||
train_acc, val_acc, tmp_test_acc = test() | ||
if val_acc > best_val_acc: | ||
best_val_acc = val_acc | ||
test_acc = tmp_test_acc | ||
log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc) | ||
times.append(time.time() - start) | ||
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.nn.conv.hgcn_conv import ( | ||
HGCNConv, | ||
HypAct, | ||
Hyperboloid, | ||
HypLinear, | ||
PoincareBall, | ||
) | ||
from torch_geometric.utils import to_torch_csc_tensor | ||
|
||
|
||
def test_hgcn_conv_hyperboloid_forward(): | ||
x = torch.randn(4, 16) | ||
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) | ||
value = torch.rand(edge_index.size(1)) | ||
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) | ||
adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) | ||
|
||
conv = HGCNConv(16, 32, manifold='hyperboloid') | ||
assert str(conv) == 'HGCNConv(16, 32)' | ||
|
||
out1 = conv(x, edge_index) | ||
assert out1.size() == (4, 32) | ||
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) | ||
|
||
out2 = conv(x, edge_index, value) | ||
assert out2.size() == (4, 32) | ||
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) | ||
return | ||
|
||
|
||
def test_hgcn_conv_poincareBall_forward(): | ||
x = torch.randn(4, 16) | ||
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) | ||
value = torch.rand(edge_index.size(1)) | ||
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) | ||
adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4)) | ||
|
||
conv = HGCNConv(16, 32, manifold='poincare') | ||
assert str(conv) == 'HGCNConv(16, 32)' | ||
|
||
out1 = conv(x, edge_index) | ||
assert out1.size() == (4, 32) | ||
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6) | ||
|
||
out2 = conv(x, edge_index, value) | ||
assert out2.size() == (4, 32) | ||
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6) | ||
return | ||
|
||
|
||
def test_hgcn_linear_hyperboloid_forward(): | ||
x = torch.randn(4, 16) | ||
manifold = Hyperboloid() | ||
|
||
linaer = HypLinear(manifold, 16, 32, 1.0, 0.0, True) | ||
|
||
out1 = linaer(x) | ||
assert out1.size() == (4, 32) | ||
return | ||
|
||
|
||
def test_hgcn_linear_poincareBall_forward(): | ||
x = torch.randn(4, 16) | ||
manifold = PoincareBall() | ||
|
||
linaer = HypLinear(manifold, 16, 32, 1.0, 0.0, True) | ||
|
||
out1 = linaer(x) | ||
assert out1.size() == (4, 32) | ||
return | ||
|
||
|
||
def test_hypact_hyperboloid_forward(): | ||
x = torch.randn(4, 16) | ||
manifold = Hyperboloid() | ||
|
||
hypact = HypAct(manifold, 1.0, 1.0, None) | ||
|
||
out1 = hypact(x) | ||
assert out1.size() == (4, 16) | ||
return | ||
|
||
|
||
def test_hypact_poincareBall_forward(): | ||
x = torch.randn(4, 16) | ||
manifold = PoincareBall() | ||
|
||
hypact = HypAct(manifold, 1.0, 1.0, None) | ||
|
||
out1 = hypact(x) | ||
assert out1.size() == (4, 16) | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.