Skip to content

Commit

Permalink
Update test condition to check bipartite graph in forward pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincenzooo authored and Vincenzooo committed Dec 11, 2024
1 parent d22db1d commit c19ada3
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 19 deletions.
6 changes: 3 additions & 3 deletions test/nn/conv/test_gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ def test_gat_norm_csc_error():

def test_gat_conv_bipartite_error():
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))

with pytest.raises(NotImplementedError,
match="not supported for bipartite message passing"):
conv = GATConv(8, 32, heads=2, normalize=True)
_ = conv(x1, adj1.t())
conv = GATConv((8, 16), 32, heads=2, normalize=True)
conv((x1, x2), edge_index)
6 changes: 3 additions & 3 deletions test/nn/conv/test_gatv2_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def test_gatv2_conv_with_edge_attr():

def test_gat_conv_bipartite_error():
x1 = torch.randn(4, 8)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))

with pytest.raises(NotImplementedError,
match="not supported for bipartite message passing"):
conv = GATv2Conv(8, 32, heads=2, normalize=True)
_ = conv(x1, adj1.t())
conv = GATv2Conv((8, 16), 32, heads=2, normalize=True)
conv((x1, x2), edge_index)
12 changes: 5 additions & 7 deletions torch_geometric/nn/conv/gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,10 @@ def forward( # noqa: F811
"'edge_index' in a 'SparseTensor' form")

if self.normalize:
if isinstance(edge_index,
SparseTensor) or is_torch_sparse_tensor(edge_index):
if edge_index.size(0) != edge_index.size(1):
raise NotImplementedError(
"The usage of 'normalize' is not supported "
"for bipartite message passing.")
if not isinstance(self.in_channels, int):
raise NotImplementedError(
"The usage of 'normalize' is not supported "
"for bipartite message passing.")

if isinstance(edge_index, Tensor):
edge_index, edge_attr = remove_self_loops(
Expand All @@ -520,7 +518,7 @@ def forward( # noqa: F811
size=size)

if self.normalize:
num_nodes = None
num_nodes: Optional[int] = None
if isinstance(edge_index, Tensor):
num_nodes = x_src.size(0)
if x_dst is not None:
Expand Down
10 changes: 4 additions & 6 deletions torch_geometric/nn/conv/gatv2_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,10 @@ def forward( # noqa: F811
"'edge_index' in a 'SparseTensor' form")

if self.normalize:
if isinstance(edge_index,
SparseTensor) or is_torch_sparse_tensor(edge_index):
if edge_index.size(0) != edge_index.size(1):
raise NotImplementedError(
"The usage of 'normalize' is not supported "
"for bipartite message passing.")
if not isinstance(self.in_channels, int):
raise NotImplementedError(
"The usage of 'normalize' is not supported "
"for bipartite message passing.")

if isinstance(edge_index, Tensor):
edge_index, edge_attr = remove_self_loops(
Expand Down

0 comments on commit c19ada3

Please sign in to comment.