CS224W - Bag of Tricks for Node Classification with GNN - GAT Normalization #9840

Open · wants to merge 16 commits into base: master
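This PR adds a `normalize` option to `GATConv` and `GATv2Conv`, implementing the adjacency-normalization trick from the "Bag of Tricks for Node Classification with Graph Neural Networks" paper referenced in the title: attention coefficients are rescaled with GCN-style symmetric normalization before aggregation. Below is a rough, illustrative sketch of the idea only — the PR's actual helper is `gat_norm` in `torch_geometric.nn.conv.gat_conv`, which additionally handles sparse layouts and self-loops (see the tests further down):

import torch
from torch import Tensor


def symmetric_attention_norm(edge_index: Tensor, alpha: Tensor,
                             num_nodes: int) -> Tensor:
    # Illustrative sketch (single-head case): rescale per-edge attention
    # scores `alpha` (shape [num_edges]) as deg_i^{-1/2} * alpha_ij *
    # deg_j^{-1/2}, where a node's degree is the attention mass
    # accumulated over its incoming edges.
    row, col = edge_index[0], edge_index[1]
    deg = torch.zeros(num_nodes, dtype=alpha.dtype)
    deg.scatter_add_(0, col, alpha)
    deg_inv_sqrt = deg.pow(-0.5)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.0
    return deg_inv_sqrt[row] * alpha * deg_inv_sqrt[col]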
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `normalize` parameter to `GATConv` and `GATv2Conv` ([#9840](https://github.com/pyg-team/pytorch_geometric/pull/9840))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
6 changes: 4 additions & 2 deletions benchmark/citation/gat.py
@@ -24,17 +24,19 @@
parser.add_argument('--profile', action='store_true')
parser.add_argument('--bf16', action='store_true')
parser.add_argument('--compile', action='store_true')
parser.add_argument('--normalize', action='store_true')
args = parser.parse_args()


class Net(torch.nn.Module):
def __init__(self, dataset):
super().__init__()
self.conv1 = GATConv(dataset.num_features, args.hidden,
heads=args.heads, dropout=args.dropout)
heads=args.heads, dropout=args.dropout,
normalize=args.normalize)
self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes,
heads=args.output_heads, concat=False,
dropout=args.dropout)
dropout=args.dropout, normalize=args.normalize)

def reset_parameters(self):
self.conv1.reset_parameters()
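With the new flag wired in, the citation benchmark can be run with normalization enabled, e.g. (hypothetical invocation, assuming the script's existing --dataset CLI flag):

python benchmark/citation/gat.py --dataset Cora --normalize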
15 changes: 10 additions & 5 deletions examples/gat.py
@@ -17,6 +17,8 @@
parser.add_argument('--heads', type=int, default=8)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--norm_adj', action='store_true',
help="GAT with symmetrically normalized adjacency")
parser.add_argument('--wandb', action='store_true', help='Track experiment')
args = parser.parse_args()

@@ -28,20 +28,23 @@
device = torch.device('cpu')

init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs,
hidden_channels=args.hidden_channels, lr=args.lr, device=device)
hidden_channels=args.hidden_channels, lr=args.lr, device=device,
norm_adj=args.norm_adj)

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures())
data = dataset[0].to(device)


class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads):
def __init__(self, in_channels, hidden_channels, out_channels, heads,
normalize):
super().__init__()
self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)
self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6,
normalize=normalize)
# On the Pubmed dataset, use `heads` output heads in `conv2`.
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1,
concat=False, dropout=0.6)
concat=False, dropout=0.6, normalize=normalize)

def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
@@ -52,7 +57,7 @@ def forward(self, x, edge_index):


model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes,
args.heads).to(device)
args.heads, args.norm_adj).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)


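The example can then be launched with the trick switched on, e.g. (hypothetical invocation):

python examples/gat.py --dataset Cora --norm_adj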
89 changes: 87 additions & 2 deletions test/nn/conv/test_gat_conv.py
@@ -6,9 +6,10 @@

import torch_geometric.typing
from torch_geometric.nn import GATConv
from torch_geometric.nn.conv.gat_conv import gat_norm
from torch_geometric.testing import is_full_test, withDevice
from torch_geometric.typing import Adj, Size, SparseTensor
from torch_geometric.utils import to_torch_csc_tensor
from torch_geometric.typing import Adj, Size, SparseTensor, torch_sparse
from torch_geometric.utils import to_torch_csc_tensor, to_torch_csr_tensor


@pytest.mark.parametrize('residual', [False, True])
@@ -155,6 +156,45 @@ def forward(
assert torch.allclose(jit((x1, x2), adj2.t()), out1, atol=1e-6)
assert torch.allclose(jit((x1, None), adj2.t()), out2, atol=1e-6)

# Test GAT normalization:
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

conv = GATConv(8, 32, heads=2, residual=residual, normalize=True)
assert str(conv) == 'GATConv(8, 32, heads=2)'
out = conv(x1, edge_index)
assert out.size() == (4, 64)
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out)
assert torch.allclose(conv(x1, adj1.t()), out)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

if is_full_test():

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = conv

def forward(
self,
x: Tensor,
edge_index: Adj,
size: Size = None,
) -> Tensor:
return self.conv(x, edge_index, size=size)

jit = torch.jit.script(MyModule())
assert torch.allclose(jit(x1, edge_index), out)
assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out)

if torch_geometric.typing.WITH_TORCH_SPARSE:
assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)


def test_gat_conv_with_edge_attr():
x = torch.randn(4, 8)
@@ -201,3 +241,48 @@ def test_gat_conv_empty_edge_index(device):
conv = GATConv(8, 32, heads=2).to(device)
out = conv(x, edge_index)
assert out.size() == (0, 64)


def test_gat_conv_csc_error():
x1 = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csr_tensor(edge_index, size=(4, 4))

with pytest.raises(ValueError, match="Unexpected sparse tensor layout"):
conv = GATConv(8, 32, heads=2, normalize=True)
assert str(conv) == 'GATConv(8, 32, heads=2)'
_ = conv(x1, adj1.t())


def test_gat_norm_csc_error():
edge_index = torch.tensor([[1, 2, 3], [0, 1, 1]])
edge_weight = torch.tensor([[1.0000, 1.0000], [1.2341, 0.9614],
[0.7659, 1.0386]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

with pytest.raises(NotImplementedError,
match="Sparse CSC matrices are not yet supported"):
gat_norm(adj1, edge_weight)
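Taken together, these tests pin down the helper's apparent contract: `gat_norm(adj, edge_weight)` takes a sparse adjacency plus per-edge attention weights (here 3 edges × 2 heads) and rejects CSC-layout inputs.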


def test_gat_conv_bipartite_error():
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

with pytest.raises(NotImplementedError,
match="not supported for bipartite message passing"):
conv = GATConv((8, 16), 32, heads=2, normalize=True)
conv((x1, x2), edge_index)


def test_remove_diag_sparse_tensor():
# Used in GAT Normalization
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_index2 = torch.tensor([[1, 2, 3], [0, 1, 1]])

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4))

row1, col1, _ = torch_sparse.remove_diag(adj1.t()).coo()
row2, col2, _ = adj2.t().coo()
assert torch.equal(row1, row2) and torch.equal(col1, col2)
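A standalone illustration of the helper this test exercises (assuming `torch_sparse` is installed):

import torch
import torch_sparse
from torch_sparse import SparseTensor

edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
row, col, _ = torch_sparse.remove_diag(adj).coo()
# The self-loop (0, 0) is dropped; (1, 0), (2, 1) and (3, 1) remain.
print(row.tolist(), col.tolist())  # [1, 2, 3] [0, 1, 1]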
49 changes: 49 additions & 0 deletions test/nn/conv/test_gatv2_conv.py
@@ -141,6 +141,44 @@ def forward(
if torch_geometric.typing.WITH_TORCH_SPARSE:
assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)

# Test GAT normalization:
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

conv = GATv2Conv(8, 32, heads=2, residual=residual, normalize=True)
assert str(conv) == 'GATv2Conv(8, 32, heads=2)'
out = conv(x1, edge_index)
assert out.size() == (4, 64)
assert torch.allclose(conv(x1, edge_index), out)
assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)

if is_full_test():

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = conv

def forward(
self,
x: Tensor,
edge_index: Adj,
) -> Tensor:
return self.conv(x, edge_index)

jit = torch.jit.script(MyModule())
assert torch.allclose(jit(x1, edge_index), out)

if torch_geometric.typing.WITH_TORCH_SPARSE:
assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)


def test_gatv2_conv_with_edge_attr():
x = torch.randn(4, 8)
@@ -163,3 +201,14 @@ def test_gatv2_conv_with_edge_attr():
conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value='mean')
out = conv(x, edge_index, edge_attr)
assert out.size() == (4, 64)


def test_gatv2_conv_bipartite_error():
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

with pytest.raises(NotImplementedError,
match="not supported for bipartite message passing"):
conv = GATv2Conv((8, 16), 32, heads=2, normalize=True)
conv((x1, x2), edge_index)
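Putting the pieces together, opting into the trick is a one-argument change for end users (sketch grounded in the tests above):

import torch
from torch_geometric.nn import GATConv, GATv2Conv

x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = GATConv(8, 32, heads=2, normalize=True)  # new parameter from this PR
out = conv(x, edge_index)  # shape: [4, 64]

conv2 = GATv2Conv(8, 32, heads=2, normalize=True)
out2 = conv2(x, edge_index)  # shape: [4, 64]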