From 4c9a38461ddccd7422704ee27f1bfc36296a5014 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Wed, 6 Nov 2024 23:43:10 +0100 Subject: [PATCH 01/16] Use unnormalized attention matrix for GAT --- test/nn/conv/test_gat_conv.py | 17 ++++ torch_geometric/nn/conv/gat_conv.py | 119 ++++++++++++++++++++++++++++ torch_geometric/typing.py | 6 ++ 3 files changed, 142 insertions(+) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 6549911ac0d4..bf861b305eea 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -155,6 +155,23 @@ def forward( assert torch.allclose(jit((x1, x2), adj2.t()), out1, atol=1e-6) assert torch.allclose(jit((x1, None), adj2.t()), out2, atol=1e-6) + # Test GAT normalization: + x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) + + conv = GATConv(8, 32, heads=2, residual=residual, normalize=True) + assert str(conv) == 'GATConv(8, 32, heads=2)' + out = conv(x1, edge_index) + assert out.size() == (4, 64) + assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) + # assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + def test_gat_conv_with_edge_attr(): x = torch.randn(4, 8) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 720dfb09811c..e7e68e92a93d 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -22,8 +22,10 @@ add_self_loops, is_torch_sparse_tensor, remove_self_loops, + scatter, softmax, ) +from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_geometric.utils.sparse import set_sparse_value if typing.TYPE_CHECKING: @@ -32,6 +34,91 @@ from torch.jit import _overload_method as overload +@torch.jit._overload +def gat_norm( # noqa: F811 + edge_index, edge_weight, num_nodes, flow, dtype): + # type: (Tensor, Tensor, Optional[int], str, Optional[int]) -> Tuple[Tensor, Tensor] # noqa + pass + + +@torch.jit._overload +def gat_norm( # noqa: F811 + edge_index, edge_weight, num_nodes, flow, dtype): + # type: (SparseTensor, Tensor, Optional[int], str, Optional[int]) -> Tuple[SparseTensor, Tensor] # noqa + pass + + +@torch.jit._overload +def gat_norm( # noqa: F811 + edge_index, edge_weight, num_nodes, flow, dtype): + # type: (Adj, Tensor, Optional[int], str, Optional[int]) -> Tuple[Adj, Tensor] # noqa + pass + + +def gat_norm( # noqa: F811 + edge_index: Adj, + edge_weight: Tensor, + num_nodes: Optional[int] = None, + flow: str = "source_to_target", + dtype: Optional[torch.dtype] = None, +): + fill_value = 1.0 + + if isinstance(edge_index, SparseTensor): + assert edge_index.size(0) == edge_index.size(1) + + adj_t = edge_index + + if not adj_t.has_value(): + adj_t = adj_t.fill_value(1., dtype=dtype) + + deg = torch_sparse.sum(adj_t, dim=1) + att_mat = edge_index.copy() + att_mat = att_mat.set_value(edge_weight) + att_mat = torch_sparse.mul(att_mat, + deg) # unnormalized attention matrix + + # Add self-loop, also called renormalization trick + att_mat = torch_sparse.fill_diag(att_mat, fill_value) + deg_tilde = deg + 1 + deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) + deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0.) 
+ + att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(-1, 1)) + att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(1, -1)) + + return adj_t, att_mat.to_dense() + + if is_torch_sparse_tensor(edge_index): + raise NotImplementedError("Sparse CSC and COO matrices are not yet " + "supported in 'gat_norm'") + + assert flow in ['source_to_target', 'target_to_source'] + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + adj_t = torch.ones((edge_index.size(1), ), dtype=dtype, + device=edge_index.device) + + row, col = edge_index[0], edge_index[1] + idx = col if flow == 'source_to_target' else row + deg = scatter(adj_t, idx, dim=0, dim_size=num_nodes, reduce='sum') + deg_expand = deg[col] + att_mat = (deg_expand.unsqueeze(1).mul(edge_weight) + ) # unnormalized attention matrix + + # Add self-loop, also called renormalization trick + adj_t, att_mat = add_self_loops(edge_index, att_mat, fill_value=fill_value, + num_nodes=num_nodes) + row, col = adj_t[0], adj_t[1] + deg_tilde = deg + 1 + deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) + deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0) + att_mat = deg_tilde_inv_sqrt[row].unsqueeze( + 1) * att_mat * deg_tilde_inv_sqrt[col].unsqueeze(1) + + return adj_t, att_mat + + class GATConv(MessagePassing): r"""The graph attentional operator from the `"Graph Attention Networks" `_ paper. @@ -76,6 +163,21 @@ class GATConv(MessagePassing): If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`. + Normalization will be computed as presented in `"Bag of Tricks for Node + Classification with Graph Neural Networks" + `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}_{att}} + \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}_{att}} = A_{att} + \mathbf{I}` denotes the + unnormalized attention matrix :math:`\mathbf{\hat{A}_{att}} + = \mathbf{D}\mathbf{alpha}` + with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` + its diagonal degree matrix based on the adjacency matrix with inserted + self-loops :math:`\hat{A} = A + I`. + Args: in_channels (int or tuple): Size of each input sample, or :obj:`-1` to derive the size from the first input(s) to the forward method. @@ -109,6 +211,8 @@ class GATConv(MessagePassing): an additive bias. (default: :obj:`True`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) + normalize (bool, optional): If set to :obj:`True`, will add + self-loops to the input graph. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
@@ -140,11 +244,15 @@ def __init__( fill_value: Union[float, Tensor, str] = 'mean', bias: bool = True, residual: bool = False, + normalize: bool = False, **kwargs, ): kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) + if normalize: + add_self_loops = False + self.in_channels = in_channels self.out_channels = out_channels self.heads = heads @@ -155,6 +263,7 @@ def __init__( self.edge_dim = edge_dim self.fill_value = fill_value self.residual = residual + self.normalize = normalize # In case we are operating in bipartite graphs, we apply separate # transformations 'lin_src' and 'lin_dst' to source and target nodes: @@ -358,10 +467,20 @@ def forward( # noqa: F811 "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") + if self.normalize: + if isinstance(edge_index, Tensor): + edge_index, edge_attr = remove_self_loops( + edge_index, edge_attr) + elif isinstance(edge_index, SparseTensor): + edge_index = torch_sparse.fill_diag(edge_index, 0.0) + # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr, size=size) + if self.normalize: + edge_index, alpha = gat_norm(edge_index, alpha) # yapf: disable + # propagate_type: (x: OptPairTensor, alpha: Tensor) out = self.propagate(edge_index, x=x, alpha=alpha, size=size) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index 468f37abfaed..e0fb657cb294 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -183,6 +183,9 @@ def from_dense(self, mat: Tensor, has_value: bool = True) -> 'SparseTensor': raise ImportError("'SparseTensor' requires 'torch-sparse'") + def copy(self) -> 'SparseTensor': + raise ImportError("'SparseTensor' requires 'torch-sparse'") + def size(self, dim: int) -> int: raise ImportError("'SparseTensor' requires 'torch-sparse'") @@ -218,6 +221,9 @@ def to_torch_sparse_csr_tensor( ) -> Tensor: raise ImportError("'SparseTensor' requires 'torch-sparse'") + def to_dense(self) -> torch.Tensor: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + class torch_sparse: # type: ignore @staticmethod def matmul(src: SparseTensor, other: Tensor, From 2cfaefc1e807f9149b8e9afb9e4071619acf7d47 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Sat, 30 Nov 2024 15:25:22 +0100 Subject: [PATCH 02/16] Update GAT norm with torch_csr_tensor --- examples/gat.py | 15 +++++--- test/nn/conv/test_gat_conv.py | 27 ++++++++++++-- torch_geometric/nn/conv/gat_conv.py | 56 +++++++++++++++++++++++------ 3 files changed, 80 insertions(+), 18 deletions(-) diff --git a/examples/gat.py b/examples/gat.py index 09f90011efdd..ed8d0b91d5e6 100644 --- a/examples/gat.py +++ b/examples/gat.py @@ -17,6 +17,8 @@ parser.add_argument('--heads', type=int, default=8) parser.add_argument('--lr', type=float, default=0.005) parser.add_argument('--epochs', type=int, default=200) +parser.add_argument('--norm_adj', type=bool, default=False, + help="GAT with symmetric normalized adjacency") parser.add_argument('--wandb', action='store_true', help='Track experiment') args = parser.parse_args() @@ -28,7 +30,8 @@ device = torch.device('cpu') init_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs, - hidden_channels=args.hidden_channels, lr=args.lr, device=device) + hidden_channels=args.hidden_channels, lr=args.lr, device=device, + norm_adj=args.norm_adj) path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') dataset = Planetoid(path, args.dataset, 
transform=T.NormalizeFeatures()) @@ -36,12 +39,14 @@ class GAT(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, heads): + def __init__(self, in_channels, hidden_channels, out_channels, heads, + normalize): super().__init__() - self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6) + self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6, + normalize=normalize) # On the Pubmed dataset, use `heads` output heads in `conv2`. self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, - concat=False, dropout=0.6) + concat=False, dropout=0.6, normalize=normalize) def forward(self, x, edge_index): x = F.dropout(x, p=0.6, training=self.training) @@ -52,7 +57,7 @@ def forward(self, x, edge_index): model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes, - args.heads).to(device) + args.heads, args.normalize).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index bf861b305eea..25ddfa90447f 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -6,9 +6,10 @@ import torch_geometric.typing from torch_geometric.nn import GATConv +from torch_geometric.nn.conv.gat_conv import gat_norm from torch_geometric.testing import is_full_test, withDevice from torch_geometric.typing import Adj, Size, SparseTensor -from torch_geometric.utils import to_torch_csc_tensor +from torch_geometric.utils import to_torch_csc_tensor, to_torch_csr_tensor @pytest.mark.parametrize('residual', [False, True]) @@ -166,7 +167,7 @@ def forward( out = conv(x1, edge_index) assert out.size() == (4, 64) assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out) - # assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + assert torch.allclose(conv(x1, adj1.t()), out) if torch_geometric.typing.WITH_TORCH_SPARSE: adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) @@ -218,3 +219,25 @@ def test_gat_conv_empty_edge_index(device): conv = GATConv(8, 32, heads=2).to(device) out = conv(x, edge_index) assert out.size() == (0, 64) + + +def test_gat_conv_csc_error(): + x1 = torch.randn(4, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csr_tensor(edge_index, size=(4, 4)) + + with pytest.raises(ValueError, match="Unexpected sparse tensor layout"): + conv = GATConv(8, 32, heads=2, normalize=True) + assert str(conv) == 'GATConv(8, 32, heads=2)' + _ = conv(x1, adj1.t()) + + +def test_gat_norm_csc_error(): + edge_index = torch.tensor([[1, 2, 3], [0, 1, 1]]) + edge_weight = torch.tensor([[1.0000, 1.0000], [1.2341, 0.9614], + [0.7659, 1.0386]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) + + with pytest.raises(NotImplementedError, + match="Sparse CSC matrices are not yet supported"): + gat_norm(adj1, edge_weight) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index e7e68e92a93d..3dfc49252bde 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -19,11 +19,14 @@ torch_sparse, ) from torch_geometric.utils import ( + add_remaining_self_loops, add_self_loops, is_torch_sparse_tensor, remove_self_loops, scatter, softmax, + to_edge_index, + to_torch_csr_tensor, ) from torch_geometric.utils.num_nodes import maybe_num_nodes from torch_geometric.utils.sparse import set_sparse_value @@ -90,8 +93,39 @@ def gat_norm( # noqa: F811 return adj_t, att_mat.to_dense() if 
is_torch_sparse_tensor(edge_index): - raise NotImplementedError("Sparse CSC and COO matrices are not yet " - "supported in 'gat_norm'") + assert edge_index.size(0) == edge_index.size(1) + + if edge_index.layout == torch.sparse_csc: + raise NotImplementedError("Sparse CSC matrices are not yet " + "supported in 'gat_norm'") + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + adj_t = edge_index + edge_index, value = to_edge_index(adj_t) + + col, row = edge_index[0], edge_index[1] + idx = col if flow == 'source_to_target' else row + deg = scatter(value, idx, 0, dim_size=num_nodes, reduce='sum') + att_mat = deg[col].view(-1, 1) * edge_weight + + # Add self-loop, also called renormalization trick + edge_index, att_mat = add_remaining_self_loops(edge_index, att_mat, + fill_value=fill_value, + num_nodes=num_nodes) + col, row = edge_index[0], edge_index[1] + deg_tilde = deg + 1 + deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) + deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0) + att_mat = deg_tilde_inv_sqrt[row].view( + -1, 1) * att_mat * deg_tilde_inv_sqrt[col].view(-1, 1) + + # Sort edge_index lexicographically + sorted_indices = torch.argsort(edge_index[0] * edge_index.size(1) + + edge_index[1]) + edge_index = edge_index[:, sorted_indices] + att_mat = att_mat[sorted_indices] + + return to_torch_csr_tensor(edge_index), att_mat assert flow in ['source_to_target', 'target_to_source'] num_nodes = maybe_num_nodes(edge_index, num_nodes) @@ -102,9 +136,8 @@ def gat_norm( # noqa: F811 row, col = edge_index[0], edge_index[1] idx = col if flow == 'source_to_target' else row deg = scatter(adj_t, idx, dim=0, dim_size=num_nodes, reduce='sum') - deg_expand = deg[col] - att_mat = (deg_expand.unsqueeze(1).mul(edge_weight) - ) # unnormalized attention matrix + att_mat = deg[col].view(-1, + 1) * edge_weight # unnormalized attention matrix # Add self-loop, also called renormalization trick adj_t, att_mat = add_self_loops(edge_index, att_mat, fill_value=fill_value, @@ -163,9 +196,9 @@ class GATConv(MessagePassing): If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`. - Normalization will be computed as presented in `"Bag of Tricks for Node - Classification with Graph Neural Networks" - `_ paper. + Normalization will be computed as presented in + `"Bag of Tricks for Node Classification with Graph Neural Networks" + `__ paper. .. math:: \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}_{att}} @@ -173,8 +206,8 @@ class GATConv(MessagePassing): where :math:`\mathbf{\hat{A}_{att}} = A_{att} + \mathbf{I}` denotes the unnormalized attention matrix :math:`\mathbf{\hat{A}_{att}} - = \mathbf{D}\mathbf{alpha}` - with inserted self-loops and :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` + = \mathbf{D}\mathbf{\alpha}` + with inserted self-loops, :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix based on the adjacency matrix with inserted self-loops :math:`\hat{A} = A + I`. @@ -212,7 +245,8 @@ class GATConv(MessagePassing): residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) normalize (bool, optional): If set to :obj:`True`, will add - self-loops to the input graph. (default: :obj:`False`) + self-loops to the input graph and compute symmetric normalization + coefficients on-the-fly. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. 
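[Reviewer note, not part of the patch] A minimal, self-contained sketch of the normalization that `gat_norm` applies when `edge_index` is a plain tensor, following the hunks above. The toy graph and the random `alpha` (a stand-in for the softmax-normalized attention coefficients returned by `edge_updater`) are assumptions for illustration only:

    import torch
    from torch_geometric.utils import add_self_loops, degree

    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])  # [2, E]
    alpha = torch.rand(edge_index.size(1), 2)  # [E, heads], stand-in for attention weights
    num_nodes = 4

    # Unnormalized attention matrix A_att = D * alpha
    # (scale each edge by the degree of its target node, flow='source_to_target'):
    deg = degree(edge_index[1], num_nodes=num_nodes)
    alpha = deg[edge_index[1]].view(-1, 1) * alpha

    # Renormalization trick: insert self-loops with weight 1,
    # then apply D_tilde^{-1/2} (.) D_tilde^{-1/2}:
    edge_index, alpha = add_self_loops(edge_index, alpha, fill_value=1.0,
                                       num_nodes=num_nodes)
    row, col = edge_index[0], edge_index[1]
    deg_tilde_inv_sqrt = (deg + 1.0).pow(-0.5)  # D_tilde_ii >= 1, so no division by zero
    alpha = (deg_tilde_inv_sqrt[row].view(-1, 1) * alpha *
             deg_tilde_inv_sqrt[col].view(-1, 1))

The result corresponds to the math in the updated docstring: the propagated weights realize D_tilde^{-1/2} (D * alpha + I) D_tilde^{-1/2}.
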
From d2a52a0d663d28c50c16c11d41e5f7ed7be9e561 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Sat, 30 Nov 2024 17:04:08 +0100 Subject: [PATCH 03/16] Add full test to GAT Conv with normalization --- test/nn/conv/test_gat_conv.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 25ddfa90447f..47ce54df50c0 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -173,6 +173,28 @@ def forward( adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + if is_full_test(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = conv + + def forward( + self, + x: Tensor, + edge_index: Adj, + size: Size = None, + ) -> Tensor: + return self.conv(x, edge_index, size=size) + + jit = torch.jit.script(MyModule()) + assert torch.allclose(jit(x1, edge_index), out) + assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) + def test_gat_conv_with_edge_attr(): x = torch.randn(4, 8) From 45f63fd538363dbf816049fe331ad5dffea67d19 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Sat, 30 Nov 2024 21:09:55 +0100 Subject: [PATCH 04/16] Add normalization to GATv2Conv --- test/nn/conv/test_gatv2_conv.py | 38 +++++++++++++++++++++++++++ torch_geometric/nn/conv/gatv2_conv.py | 19 ++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index 3bca6530eee9..a4c200b9012c 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -141,6 +141,44 @@ def forward( if torch_geometric.typing.WITH_TORCH_SPARSE: assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6) + # Test GAT normalization: + x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 4)) + + conv = GATv2Conv(8, 32, heads=2, residual=residual, normalize=True) + assert str(conv) == 'GATv2Conv(8, 32, heads=2)' + out = conv(x1, edge_index) + assert out.size() == (4, 64) + assert torch.allclose(conv(x1, edge_index), out) + assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6) + + if is_full_test(): + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = conv + + def forward( + self, + x: Tensor, + edge_index: Adj, + ) -> Tensor: + return self.conv(x, edge_index) + + jit = torch.jit.script(MyModule()) + assert torch.allclose(jit(x1, edge_index), out) + assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out) + + if torch_geometric.typing.WITH_TORCH_SPARSE: + assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6) + def test_gatv2_conv_with_edge_attr(): x = torch.randn(4, 8) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index f3b2f4937e52..e5ff0981ad16 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -7,6 +7,7 @@ from torch.nn import Parameter from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.conv.gat_conv import gat_norm from torch_geometric.nn.dense.linear import Linear from 
torch_geometric.nn.inits import glorot, zeros from torch_geometric.typing import ( @@ -112,6 +113,9 @@ class GATv2Conv(MessagePassing): (default: :obj:`False`) residual (bool, optional): If set to :obj:`True`, the layer will add a learnable skip-connection. (default: :obj:`False`) + normalize (bool, optional): If set to :obj:`True`, will add + self-loops to the input graph and compute symmetric normalization + coefficients on-the-fly. (default: :obj:`False`) **kwargs (optional): Additional arguments of :class:`torch_geometric.nn.conv.MessagePassing`. @@ -144,10 +148,14 @@ def __init__( bias: bool = True, share_weights: bool = False, residual: bool = False, + normalize: bool = False, **kwargs, ): super().__init__(node_dim=0, **kwargs) + if normalize: + add_self_loops = False + self.in_channels = in_channels self.out_channels = out_channels self.heads = heads @@ -159,6 +167,7 @@ def __init__( self.fill_value = fill_value self.residual = residual self.share_weights = share_weights + self.normalize = normalize if isinstance(in_channels, int): self.lin_l = Linear(in_channels, heads * out_channels, bias=bias, @@ -321,10 +330,20 @@ def forward( # noqa: F811 "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") + if self.normalize: + if isinstance(edge_index, Tensor): + edge_index, edge_attr = remove_self_loops( + edge_index, edge_attr) + elif isinstance(edge_index, SparseTensor): + edge_index = torch_sparse.fill_diag(edge_index, 0.0) + # edge_updater_type: (x: PairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, x=(x_l, x_r), edge_attr=edge_attr) + if self.normalize: + edge_index, alpha = gat_norm(edge_index, alpha) # yapf: disable + # propagate_type: (x: PairTensor, alpha: Tensor) out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha) From 8c056772034703e7f2cb747f8f9172abfddd2d25 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Sun, 1 Dec 2024 10:02:46 +0100 Subject: [PATCH 05/16] Add GAT normalization parameter in benchmark/citation/gat.py --- benchmark/citation/gat.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/benchmark/citation/gat.py b/benchmark/citation/gat.py index f9ed5d6071af..47594ed50bfc 100644 --- a/benchmark/citation/gat.py +++ b/benchmark/citation/gat.py @@ -24,6 +24,7 @@ parser.add_argument('--profile', action='store_true') parser.add_argument('--bf16', action='store_true') parser.add_argument('--compile', action='store_true') +parser.add_argument('--normalize', action='store_true') args = parser.parse_args() @@ -31,10 +32,11 @@ class Net(torch.nn.Module): def __init__(self, dataset): super().__init__() self.conv1 = GATConv(dataset.num_features, args.hidden, - heads=args.heads, dropout=args.dropout) + heads=args.heads, dropout=args.dropout, + normalize=args.normalize) self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes, heads=args.output_heads, concat=False, - dropout=args.dropout) + dropout=args.dropout, normalize=args.normalize) def reset_parameters(self): self.conv1.reset_parameters() From beb2e343415476465dd3210761f7973c5f7571f7 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Mon, 2 Dec 2024 13:24:03 +0100 Subject: [PATCH 06/16] Fix typo --- examples/gat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gat.py b/examples/gat.py index ed8d0b91d5e6..bf091241f695 100644 --- a/examples/gat.py +++ b/examples/gat.py @@ -57,7 +57,7 @@ def forward(self, x, edge_index): model = GAT(dataset.num_features, args.hidden_channels, dataset.num_classes, - 
args.heads, args.normalize).to(device) + args.heads, args.norm_adj).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4) From f346c8c3520af13ac4db1611106992523c781227 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Mon, 2 Dec 2024 13:28:34 +0100 Subject: [PATCH 07/16] Remove unused variable and passing more parameters to gat_norm --- torch_geometric/nn/conv/gat_conv.py | 8 +++++--- torch_geometric/nn/conv/gatv2_conv.py | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 3dfc49252bde..ac15f9826810 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -100,8 +100,7 @@ def gat_norm( # noqa: F811 "supported in 'gat_norm'") num_nodes = maybe_num_nodes(edge_index, num_nodes) - adj_t = edge_index - edge_index, value = to_edge_index(adj_t) + edge_index, value = to_edge_index(edge_index) col, row = edge_index[0], edge_index[1] idx = col if flow == 'source_to_target' else row @@ -513,7 +512,10 @@ def forward( # noqa: F811 size=size) if self.normalize: - edge_index, alpha = gat_norm(edge_index, alpha) # yapf: disable + edge_index, alpha = gat_norm(edge_index, + alpha, + flow=self.flow, + dtype=alpha.dtype) # yapf: disable # propagate_type: (x: OptPairTensor, alpha: Tensor) out = self.propagate(edge_index, x=x, alpha=alpha, size=size) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index e5ff0981ad16..d49e2a53da80 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -342,7 +342,10 @@ def forward( # noqa: F811 edge_attr=edge_attr) if self.normalize: - edge_index, alpha = gat_norm(edge_index, alpha) # yapf: disable + edge_index, alpha = gat_norm(edge_index, + alpha, + flow=self.flow, + dtype=alpha.dtype) # yapf: disable # propagate_type: (x: PairTensor, alpha: Tensor) out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha) From 8bbdbc4d3e6eeafde33f33b9170d41c4c38cb38b Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Mon, 9 Dec 2024 21:24:20 +0100 Subject: [PATCH 08/16] Outline the fact that GAT Normalization does not handle bipartite message passing yet and add num_nodes as parameters --- test/nn/conv/test_gat_conv.py | 11 +++++++++++ test/nn/conv/test_gatv2_conv.py | 11 +++++++++++ torch_geometric/nn/conv/gat_conv.py | 19 ++++++++++++++++++- torch_geometric/nn/conv/gatv2_conv.py | 14 ++++++++++++++ 4 files changed, 54 insertions(+), 1 deletion(-) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 47ce54df50c0..52054aa99f0e 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -263,3 +263,14 @@ def test_gat_norm_csc_error(): with pytest.raises(NotImplementedError, match="Sparse CSC matrices are not yet supported"): gat_norm(adj1, edge_weight) + + +def test_gat_conv_bipartite_error(): + x1 = torch.randn(4, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) + + with pytest.raises(NotImplementedError, + match="not supported for bipartite message passing"): + conv = GATConv(8, 32, heads=2, normalize=True) + _ = conv(x1, adj1.t()) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index a4c200b9012c..d8a94e72a9f7 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -201,3 +201,14 @@ def test_gatv2_conv_with_edge_attr(): conv = GATv2Conv(8, 32, heads=2, edge_dim=4, 
fill_value='mean') out = conv(x, edge_index, edge_attr) assert out.size() == (4, 64) + + +def test_gat_conv_bipartite_error(): + x1 = torch.randn(4, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) + + with pytest.raises(NotImplementedError, + match="not supported for bipartite message passing"): + conv = GATv2Conv(8, 32, heads=2, normalize=True) + _ = conv(x1, adj1.t()) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index ac15f9826810..afd8340092fb 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -67,6 +67,8 @@ def gat_norm( # noqa: F811 ): fill_value = 1.0 + assert flow in ['source_to_target', 'target_to_source'] + if isinstance(edge_index, SparseTensor): assert edge_index.size(0) == edge_index.size(1) @@ -126,7 +128,6 @@ def gat_norm( # noqa: F811 return to_torch_csr_tensor(edge_index), att_mat - assert flow in ['source_to_target', 'target_to_source'] num_nodes = maybe_num_nodes(edge_index, num_nodes) adj_t = torch.ones((edge_index.size(1), ), dtype=dtype, @@ -500,6 +501,14 @@ def forward( # noqa: F811 "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") + if self.normalize: + if isinstance(edge_index, + SparseTensor) or is_torch_sparse_tensor(edge_index): + if edge_index.size(0) != edge_index.size(1): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") + if self.normalize: if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( @@ -512,8 +521,16 @@ def forward( # noqa: F811 size=size) if self.normalize: + num_nodes = None + if isinstance(edge_index, Tensor): + num_nodes = x_src.size(0) + if x_dst is not None: + num_nodes = min(num_nodes, x_dst.size(0)) + num_nodes = min(size) if size is not None else num_nodes + edge_index, alpha = gat_norm(edge_index, alpha, + num_nodes=num_nodes, flow=self.flow, dtype=alpha.dtype) # yapf: disable diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index d49e2a53da80..590fe8b2c77c 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -330,6 +330,14 @@ def forward( # noqa: F811 "simultaneously is currently not yet supported for " "'edge_index' in a 'SparseTensor' form") + if self.normalize: + if isinstance(edge_index, + SparseTensor) or is_torch_sparse_tensor(edge_index): + if edge_index.size(0) != edge_index.size(1): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") + if self.normalize: if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( @@ -342,8 +350,14 @@ def forward( # noqa: F811 edge_attr=edge_attr) if self.normalize: + num_nodes = None + if isinstance(edge_index, Tensor): + num_nodes = x_l.size(0) + if x_r is not None: + num_nodes = min(num_nodes, x_r.size(0)) edge_index, alpha = gat_norm(edge_index, alpha, + num_nodes=num_nodes, flow=self.flow, dtype=alpha.dtype) # yapf: disable From d22db1d2220ec89b2c43099be6c45d153acd787b Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Tue, 10 Dec 2024 21:51:10 +0100 Subject: [PATCH 09/16] Checking self.normalize once --- torch_geometric/nn/conv/gat_conv.py | 1 - torch_geometric/nn/conv/gatv2_conv.py | 1 - 2 files changed, 2 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index afd8340092fb..f364533ddff8 100644 --- 
a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -509,7 +509,6 @@ def forward( # noqa: F811 "The usage of 'normalize' is not supported " "for bipartite message passing.") - if self.normalize: if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index 590fe8b2c77c..a368453adce3 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -338,7 +338,6 @@ def forward( # noqa: F811 "The usage of 'normalize' is not supported " "for bipartite message passing.") - if self.normalize: if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) From c19ada352bd277d960df77295f538cbbbe3d0895 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Tue, 10 Dec 2024 23:09:47 +0100 Subject: [PATCH 10/16] Update test condition to check bipartite graph in forward pass --- test/nn/conv/test_gat_conv.py | 6 +++--- test/nn/conv/test_gatv2_conv.py | 6 +++--- torch_geometric/nn/conv/gat_conv.py | 12 +++++------- torch_geometric/nn/conv/gatv2_conv.py | 10 ++++------ 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 52054aa99f0e..3c1e08210134 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -267,10 +267,10 @@ def test_gat_norm_csc_error(): def test_gat_conv_bipartite_error(): x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) - adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) with pytest.raises(NotImplementedError, match="not supported for bipartite message passing"): - conv = GATConv(8, 32, heads=2, normalize=True) - _ = conv(x1, adj1.t()) + conv = GATConv((8, 16), 32, heads=2, normalize=True) + conv((x1, x2), edge_index) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index d8a94e72a9f7..3a0477d8df37 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -205,10 +205,10 @@ def test_gatv2_conv_with_edge_attr(): def test_gat_conv_bipartite_error(): x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) - adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) with pytest.raises(NotImplementedError, match="not supported for bipartite message passing"): - conv = GATv2Conv(8, 32, heads=2, normalize=True) - _ = conv(x1, adj1.t()) + conv = GATv2Conv((8, 16), 32, heads=2, normalize=True) + conv((x1, x2), edge_index) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index f364533ddff8..998a4fe1ce24 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -502,12 +502,10 @@ def forward( # noqa: F811 "'edge_index' in a 'SparseTensor' form") if self.normalize: - if isinstance(edge_index, - SparseTensor) or is_torch_sparse_tensor(edge_index): - if edge_index.size(0) != edge_index.size(1): - raise NotImplementedError( - "The usage of 'normalize' is not supported " - "for bipartite message passing.") + if not isinstance(self.in_channels, int): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( @@ -520,7 +518,7 @@ def forward( # noqa: F811 size=size) if self.normalize: - num_nodes = None + num_nodes: 
Optional[int] = None if isinstance(edge_index, Tensor): num_nodes = x_src.size(0) if x_dst is not None: diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index a368453adce3..a95117f28feb 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -331,12 +331,10 @@ def forward( # noqa: F811 "'edge_index' in a 'SparseTensor' form") if self.normalize: - if isinstance(edge_index, - SparseTensor) or is_torch_sparse_tensor(edge_index): - if edge_index.size(0) != edge_index.size(1): - raise NotImplementedError( - "The usage of 'normalize' is not supported " - "for bipartite message passing.") + if not isinstance(self.in_channels, int): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( From 76cab5da73ca9d3dc57213e14eb594fde638d4c9 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Wed, 11 Dec 2024 20:55:38 +0100 Subject: [PATCH 11/16] Fix gat_norm when edge_index is SparseTensor --- torch_geometric/nn/conv/gat_conv.py | 41 +++++++++++++++------------ torch_geometric/nn/conv/gatv2_conv.py | 5 +--- torch_geometric/typing.py | 4 +++ 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 998a4fe1ce24..70e860081c6e 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -78,21 +78,31 @@ def gat_norm( # noqa: F811 adj_t = adj_t.fill_value(1., dtype=dtype) deg = torch_sparse.sum(adj_t, dim=1) - att_mat = edge_index.copy() - att_mat = att_mat.set_value(edge_weight) - att_mat = torch_sparse.mul(att_mat, - deg) # unnormalized attention matrix + att_mat = adj_t.set_value(edge_weight) + # repeat_interleave for any dtype tensor + num_heads = edge_weight.shape[1] if edge_weight.dim() > 1 else 1 + repeat_deg = deg.view(-1).unsqueeze(1).repeat(1, num_heads).flatten() + # perform mat multiplication with sparse tensor + # for multiple attention heads + att_mat = torch_sparse.mul(att_mat, repeat_deg.view(-1, 1, num_heads)) # Add self-loop, also called renormalization trick att_mat = torch_sparse.fill_diag(att_mat, fill_value) deg_tilde = deg + 1 deg_tilde_inv_sqrt = deg_tilde.pow_(-0.5) deg_tilde_inv_sqrt.masked_fill_(deg_tilde_inv_sqrt == float('inf'), 0.) 
+ # repeat_interleave for any dtype tensor + repeat_deg_tilde_inv = deg_tilde_inv_sqrt.view(-1).unsqueeze(1).repeat( + 1, num_heads).flatten() + repeat_deg_tilde_inv = repeat_deg_tilde_inv - att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(-1, 1)) - att_mat = torch_sparse.mul(att_mat, deg_tilde_inv_sqrt.view(1, -1)) - - return adj_t, att_mat.to_dense() + att_mat = torch_sparse.mul(att_mat, + repeat_deg_tilde_inv.view(-1, 1, num_heads)) + att_mat = torch_sparse.mul(att_mat, + repeat_deg_tilde_inv.view(1, -1, num_heads)) + # should never be None, only for typing purpose + alpha = att_mat.storage.value() + return adj_t, alpha if alpha is not None else edge_weight if is_torch_sparse_tensor(edge_index): assert edge_index.size(0) == edge_index.size(1) @@ -284,9 +294,6 @@ def __init__( kwargs.setdefault('aggr', 'add') super().__init__(node_dim=0, **kwargs) - if normalize: - add_self_loops = False - self.in_channels = in_channels self.out_channels = out_channels self.heads = heads @@ -479,7 +486,7 @@ def forward( # noqa: F811 alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) alpha = (alpha_src, alpha_dst) - if self.add_self_loops: + if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): # We only want to add self-loops for nodes that appear both as # source and target nodes: @@ -511,7 +518,7 @@ def forward( # noqa: F811 edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) elif isinstance(edge_index, SparseTensor): - edge_index = torch_sparse.fill_diag(edge_index, 0.0) + edge_index = torch_sparse.remove_diag(edge_index) # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr, @@ -525,11 +532,9 @@ def forward( # noqa: F811 num_nodes = min(num_nodes, x_dst.size(0)) num_nodes = min(size) if size is not None else num_nodes - edge_index, alpha = gat_norm(edge_index, - alpha, - num_nodes=num_nodes, - flow=self.flow, - dtype=alpha.dtype) # yapf: disable + edge_index, alpha = gat_norm(edge_index, alpha, + num_nodes=num_nodes, flow=self.flow, + dtype=alpha.dtype) # noqa: F811 # propagate_type: (x: OptPairTensor, alpha: Tensor) out = self.propagate(edge_index, x=x, alpha=alpha, size=size) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index a95117f28feb..982de16cf74e 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -153,9 +153,6 @@ def __init__( ): super().__init__(node_dim=0, **kwargs) - if normalize: - add_self_loops = False - self.in_channels = in_channels self.out_channels = out_channels self.heads = heads @@ -311,7 +308,7 @@ def forward( # noqa: F811 assert x_l is not None assert x_r is not None - if self.add_self_loops: + if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): num_nodes = x_l.size(0) if x_r is not None: diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index e0fb657cb294..c6e91f4f1db0 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -248,6 +248,10 @@ def fill_diag(src: SparseTensor, fill_value: float, k: int = 0) -> SparseTensor: raise ImportError("'fill_diag' requires 'torch-sparse'") + @staticmethod + def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor: + raise ImportError("'remove_diag' requires 'torch-sparse'") + @staticmethod def masked_select_nnz(src: SparseTensor, mask: Tensor, layout: Optional[str] = None) -> SparseTensor: From 
ce302c445e7f4d8afdbea3075ebf05959df260f1 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Wed, 11 Dec 2024 21:22:57 +0100 Subject: [PATCH 12/16] Add small fixes --- torch_geometric/nn/conv/gat_conv.py | 3 ++- torch_geometric/nn/conv/gatv2_conv.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index 70e860081c6e..e8c9f28d7dbc 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -102,7 +102,8 @@ def gat_norm( # noqa: F811 repeat_deg_tilde_inv.view(1, -1, num_heads)) # should never be None, only for typing purpose alpha = att_mat.storage.value() - return adj_t, alpha if alpha is not None else edge_weight + return att_mat.set_value( + None), alpha if alpha is not None else edge_weight if is_torch_sparse_tensor(edge_index): assert edge_index.size(0) == edge_index.size(1) diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index 982de16cf74e..f6dcde265bf0 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -337,7 +337,7 @@ def forward( # noqa: F811 edge_index, edge_attr = remove_self_loops( edge_index, edge_attr) elif isinstance(edge_index, SparseTensor): - edge_index = torch_sparse.fill_diag(edge_index, 0.0) + edge_index = torch_sparse.remove_diag(edge_index) # edge_updater_type: (x: PairTensor, edge_attr: OptTensor) alpha = self.edge_updater(edge_index, x=(x_l, x_r), From 00c73d8ce15ffa0443e6a6a0efb46f3e3ef7c409 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Wed, 11 Dec 2024 21:50:47 +0100 Subject: [PATCH 13/16] Clean typing and add typing test --- test/nn/conv/test_gat_conv.py | 13 ++++++++++++- torch_geometric/typing.py | 6 ------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 3c1e08210134..b5b59385166e 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -8,7 +8,7 @@ from torch_geometric.nn import GATConv from torch_geometric.nn.conv.gat_conv import gat_norm from torch_geometric.testing import is_full_test, withDevice -from torch_geometric.typing import Adj, Size, SparseTensor +from torch_geometric.typing import Adj, Size, SparseTensor, torch_sparse from torch_geometric.utils import to_torch_csc_tensor, to_torch_csr_tensor @@ -274,3 +274,14 @@ def test_gat_conv_bipartite_error(): match="not supported for bipartite message passing"): conv = GATConv((8, 16), 32, heads=2, normalize=True) conv((x1, x2), edge_index) + + +def test_remove_diag_sparse_tensor(): + # Used in GAT Normalization + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_index2 = torch.tensor([[1, 2, 3], [0, 1, 1]]) + + adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4)) + + assert torch_sparse.remove_diag(adj1.t()) == adj2.t() diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index c6e91f4f1db0..3698674dd6be 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -183,9 +183,6 @@ def from_dense(self, mat: Tensor, has_value: bool = True) -> 'SparseTensor': raise ImportError("'SparseTensor' requires 'torch-sparse'") - def copy(self) -> 'SparseTensor': - raise ImportError("'SparseTensor' requires 'torch-sparse'") - def size(self, dim: int) -> int: raise ImportError("'SparseTensor' requires 'torch-sparse'") @@ -221,9 +218,6 @@ def to_torch_sparse_csr_tensor( ) -> 
Tensor: raise ImportError("'SparseTensor' requires 'torch-sparse'") - def to_dense(self) -> torch.Tensor: - raise ImportError("'SparseTensor' requires 'torch-sparse'") - class torch_sparse: # type: ignore @staticmethod def matmul(src: SparseTensor, other: Tensor, From 6197e8f3445f99e6207ec45cd2baddf1ba89e4ab Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Wed, 11 Dec 2024 21:51:16 +0100 Subject: [PATCH 14/16] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..e1969848e4aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `normalize` parameter to `GATConv` and `GATv2Conv` ([#9840](https://github.com/pyg-team/pytorch_geometric/pull/9840)) - Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) From 9c5c19e32dac9cad51c4069a2222e4db6faf84db Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Wed, 11 Dec 2024 21:59:31 +0100 Subject: [PATCH 15/16] Update tests --- test/nn/conv/test_gat_conv.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index b5b59385166e..eeeee2b31145 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -281,7 +281,8 @@ def test_remove_diag_sparse_tensor(): edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) edge_index2 = torch.tensor([[1, 2, 3], [0, 1, 1]]) - adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) - adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4)) + if torch_geometric.typing.WITH_TORCH_SPARSE: + adj1 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + adj2 = SparseTensor.from_edge_index(edge_index2, sparse_sizes=(4, 4)) - assert torch_sparse.remove_diag(adj1.t()) == adj2.t() + assert torch_sparse.remove_diag(adj1.t()) == adj2.t() From d4864eedbbd668d19a05990faedb871afc97f807 Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Fri, 13 Dec 2024 07:42:54 +0100 Subject: [PATCH 16/16] Update flow --- torch_geometric/nn/conv/gat_conv.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index e8c9f28d7dbc..1ef86100d9e6 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -77,7 +77,10 @@ def gat_norm( # noqa: F811 if not adj_t.has_value(): adj_t = adj_t.fill_value(1., dtype=dtype) - deg = torch_sparse.sum(adj_t, dim=1) + # idx = col if flow == 'source_to_target' else row + dim = 1 if flow == 'source_to_target' else 0 + deg = torch_sparse.sum(adj_t, dim=dim) + att_mat = adj_t.set_value(edge_weight) # repeat_interleave for any dtype tensor num_heads = edge_weight.shape[1] if edge_weight.dim() > 1 else 1 @@ -487,6 +490,8 @@ def forward( # noqa: F811 alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) alpha = (alpha_src, alpha_dst) + # Skipping add_self_loops when normalize + # as we are performing it on gat_norm function if self.add_self_loops and not self.normalize: if isinstance(edge_index, Tensor): # We only want to add self-loops for nodes that appear both as
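
[Reviewer note, not part of the patch] For reference, an end-to-end usage sketch of the new `normalize` flag, mirroring the shapes used in the tests above (hyper-parameters are arbitrary):

    import torch
    from torch_geometric.nn import GATConv, GATv2Conv

    x = torch.randn(4, 8)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

    # Attention symmetrically normalized as in "Bag of Tricks"; self-loop
    # handling is deferred to `gat_norm` instead of `add_self_loops`:
    conv = GATConv(8, 32, heads=2, normalize=True)
    out = conv(x, edge_index)        # -> [4, 64]

    conv2 = GATv2Conv(8, 32, heads=2, normalize=True)
    out2 = conv2(x, edge_index)      # -> [4, 64]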