Add GATv3Conv implementation and tests #9937

Open
wants to merge 3 commits into base: master
151 changes: 151 additions & 0 deletions test/nn/conv/test_gatv3_conv.py
@@ -0,0 +1,151 @@
import pytest
import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.nn import GATv3Conv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import Adj, PairTensor, SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


@pytest.mark.parametrize("share_weights", [False, True])
def test_gatv3_conv(share_weights):
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

# Create a CSC adjacency matrix for tests:
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))

# Instantiate the layer:
conv = GATv3Conv(
in_channels=8,
out_channels=32,
heads=2,
share_weights=share_weights,
)
assert str(conv) == "GATv3Conv(8, 32, heads=2)"

# Standard usage:
out = conv(x1, edge_index)
assert out.size() == (4, 64) # (num_nodes, heads * out_channels)
# Check consistency on repeated calls:
assert torch.allclose(conv(x1, edge_index), out)

    # Check CSC adjacency (passed transposed, as in the GATv2 tests):
assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)

# If torch_sparse is available, also check with a SparseTensor:
if torch_geometric.typing.WITH_TORCH_SPARSE:
s_adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
assert torch.allclose(conv(x1, s_adj.t()), out, atol=1e-6)

# Test scripting/JIT if in "full test" mode:
if is_full_test():

class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = conv

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
return self.conv(x, edge_index)

jit = torch.jit.script(MyModule())
assert torch.allclose(jit(x1, edge_index), out)

if torch_geometric.typing.WITH_TORCH_SPARSE:
assert torch.allclose(jit(x1, s_adj.t()), out, atol=1e-6)

# Test return_attention_weights:
result = conv(x1, edge_index, return_attention_weights=True)
out_att, (ei_att, alpha_att) = result
assert torch.allclose(out_att, out)
    # GATv3Conv adds self-loops for non-bipartite inputs, so the returned
    # edge_index contains one extra edge per node:
    assert ei_att.size(0) == 2 and ei_att.size(1) == edge_index.size(1) + 4

assert alpha_att.size() == (ei_att.size(1), 2) # (num_edges, heads)
assert alpha_att.min() >= 0 and alpha_att.max() <= 1

# Bipartite message passing:
# Use an adjacency shaped for (src_nodes=4, dst_nodes=2).
adj_bip = to_torch_csc_tensor(edge_index, size=(4, 2))

out_bip = conv((x1, x2), edge_index)
assert out_bip.size() == (2, 64)
assert torch.allclose(conv((x1, x2), edge_index), out_bip)
assert torch.allclose(conv((x1, x2), adj_bip.t()), out_bip, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
s_adj_bip = SparseTensor.from_edge_index(edge_index,
sparse_sizes=(4, 2))
assert torch.allclose(conv((x1, x2), s_adj_bip.t()), out_bip,
atol=1e-6)

if is_full_test():

class MyBipModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = conv

            def forward(self, x: PairTensor, edge_index: Adj) -> Tensor:
                return self.conv(x, edge_index)

jit = torch.jit.script(MyBipModule())
assert torch.allclose(jit((x1, x2), edge_index), out_bip)


def test_gatv3_conv_with_edge_attr():
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]])
    edge_weight = torch.randn(edge_index.size(1))  # scalar edge features
    edge_attr = torch.randn(edge_index.size(1), 4)  # 4-dim. edge features

# Case 1: 1D edge features with fill_value=0.5
conv = GATv3Conv(
in_channels=8,
out_channels=32,
heads=2,
edge_dim=1,
fill_value=0.5,
)
out = conv(x, edge_index, edge_weight)
assert out.size() == (4, 64)

# Case 2: 1D edge features with fill_value='mean'
conv = GATv3Conv(
in_channels=8,
out_channels=32,
heads=2,
edge_dim=1,
fill_value="mean",
)
out = conv(x, edge_index, edge_weight)
assert out.size() == (4, 64)

# Case 3: 4D edge features with fill_value=0.5
conv = GATv3Conv(
in_channels=8,
out_channels=32,
heads=2,
edge_dim=4,
fill_value=0.5,
)
out = conv(x, edge_index, edge_attr)
assert out.size() == (4, 64)

# Case 4: 4D edge features with fill_value='mean'
conv = GATv3Conv(
in_channels=8,
out_channels=32,
heads=2,
edge_dim=4,
fill_value="mean",
)
out = conv(x, edge_index, edge_attr)
assert out.size() == (4, 64)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
@@ -12,6 +12,7 @@
from .cugraph.gat_conv import CuGraphGATConv
from .fused_gat_conv import FusedGATConv
from .gatv2_conv import GATv2Conv
from .gatv3_conv import GATv3Conv
from .transformer_conv import TransformerConv
from .agnn_conv import AGNNConv
from .tag_conv import TAGConv
@@ -79,6 +80,7 @@
'CuGraphGATConv',
'FusedGATConv',
'GATv2Conv',
'GATv3Conv',
'TransformerConv',
'AGNNConv',
'TAGConv',
211 changes: 211 additions & 0 deletions torch_geometric/nn/conv/gatv3_conv.py
@@ -0,0 +1,211 @@
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor, SparseTensor
from torch_geometric.utils import add_self_loops, remove_self_loops, softmax


class GATv3Conv(MessagePassing):
"""GATv3Conv implements a context-aware attention mechanism inspired by the
GATher framework [1], extending GATv2 with element-wise feature multiplication,
optional weight sharing, and scaling.

Reference:
[1] "GATher: Graph Attention Based Predictions of Gene-Disease Links,"
Narganes-Carlon et al., 2024 (arXiv:2409.16327).
"""

_alpha: OptTensor

def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
out_channels: int,
heads: int = 1,
dropout: float = 0.0,
edge_dim: Optional[int] = None,
fill_value: Union[float, Tensor, str] = "mean",
name: str = "unnamed",
*,
concat: bool = True,
bias: bool = True,
share_weights: bool = False,
**kwargs,
):
super().__init__(node_dim=0, aggr="sum", flow="source_to_target",
**kwargs)

self.in_channels = in_channels
self.out_channels = out_channels
self.heads = heads
self.concat = concat
self.dropout = dropout
self.edge_dim = edge_dim
self.fill_value = fill_value
self.share_weights = share_weights
self.sigmoid = torch.nn.Sigmoid()
self.attention_sigmoid = torch.nn.Sigmoid()
self.name = name

        # Linear layer that scores the element-wise feature interaction:
self.context_attention_layer = torch.nn.Linear(
self.out_channels,
1,
bias=True,
)

if isinstance(in_channels, int):
self.lin_l = Linear(
in_channels,
heads * out_channels,
bias=bias,
weight_initializer="glorot",
)
if share_weights:
self.lin_r = self.lin_l
else:
self.lin_r = Linear(
in_channels,
heads * out_channels,
bias=bias,
weight_initializer="glorot",
)
else:
self.lin_l = Linear(
in_channels[0],
heads * out_channels,
bias=bias,
weight_initializer="glorot",
)
if share_weights:
self.lin_r = self.lin_l
else:
self.lin_r = Linear(
in_channels[1],
heads * out_channels,
bias=bias,
weight_initializer="glorot",
)

if edge_dim is not None:
self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,
weight_initializer="glorot")
else:
self.lin_edge = None

        if bias and concat:
            self.bias = Parameter(torch.empty(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias", None)

self._alpha = None
self.reset_parameters()

def reset_parameters(self):
self.lin_l.reset_parameters()
self.lin_r.reset_parameters()
if self.lin_edge is not None:
self.lin_edge.reset_parameters()
zeros(self.bias)

def forward(
self,
x: Union[Tensor, PairTensor],
edge_index: Adj,
edge_attr: OptTensor = None,
*,
return_attention_weights: bool = None,
):

x_l: OptTensor = None
x_r: OptTensor = None
if isinstance(x, Tensor):
assert x.dim() == 2
x_l = self.lin_l(x).view(-1, self.heads, self.out_channels)
if self.share_weights:
x_r = x_l
else:
x_r = self.lin_r(x).view(-1, self.heads, self.out_channels)
else:
x_l, x_r = x[0], x[1]
assert x[0].dim() == 2
x_l = self.lin_l(x_l).view(-1, self.heads, self.out_channels)
if x_r is not None:
x_r = self.lin_r(x_r).view(-1, self.heads, self.out_channels)

        # Remove and re-add self-loops for non-bipartite inputs only,
        # keeping edge_attr in sync so added self-loops receive `fill_value`:
        if isinstance(x, Tensor):
            edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
            edge_index, edge_attr = add_self_loops(
                edge_index, edge_attr, fill_value=self.fill_value,
                num_nodes=x.size(0))

out = self.propagate(
edge_index,
x=(x_l, x_r),
edge_attr=edge_attr,
size=(x_l.shape[0], x_r.shape[0]),
)

alpha = self._alpha
self._alpha = None

        # Concatenate (or average) the per-head output representations:
        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

if self.bias is not None:
out = out + self.bias

if isinstance(return_attention_weights, bool):
assert alpha is not None
if isinstance(edge_index, Tensor):
return out, (edge_index, alpha)
if isinstance(edge_index, SparseTensor):
return out, edge_index.set_value(alpha, layout="coo")
return out

    def message(self, x_j: Tensor, x_i: Tensor, index: Tensor,
                ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        # Element-wise feature interaction, scored by a shared linear layer
        # (scaled dot-product style attention, cf. Vaswani et al., 2017,
        # https://arxiv.org/abs/1706.03762):
        alpha = x_i * x_j
        alpha = self.context_attention_layer(alpha).squeeze(-1)

        # Scale by sqrt(d), then normalize over each target node:
        alpha = alpha / math.sqrt(x_i.size(-1))
        alpha = softmax(alpha, index, ptr, size_i)

        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        return x_j * alpha.unsqueeze(-1)
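
    # For reference, the per-edge score computed above is (a sketch of the
    # intent, not taken verbatim from the paper):
    #
    #     e_ij     = w^T (W_r x_i * W_l x_j) / sqrt(out_channels)
    #     alpha_ij = softmax(e_ij)   # over the incoming edges of target i
    #
    # where `*` is element-wise multiplication and `w` is the weight of
    # `context_attention_layer`, applied per head along the last dimension.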

def aggregate(
self,
inputs: Tensor,
index: Tensor,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None,
) -> Tensor:
return super().aggregate(inputs, index, ptr, dim_size)

def update(self, inputs: Tensor) -> Tensor:
return super().update(inputs)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')