From e4cdd0f6c0bfec3a56d3e60acf107a987c5f4dcb Mon Sep 17 00:00:00 2001 From: Vincenzooo Date: Tue, 10 Dec 2024 23:09:47 +0100 Subject: [PATCH] Update test condition to check bipartite graph in forward pass --- test/nn/conv/test_gat_conv.py | 6 +++--- test/nn/conv/test_gatv2_conv.py | 6 +++--- torch_geometric/nn/conv/gat_conv.py | 12 +++++------- torch_geometric/nn/conv/gatv2_conv.py | 10 ++++------ 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/test/nn/conv/test_gat_conv.py b/test/nn/conv/test_gat_conv.py index 52054aa99f0e..3c1e08210134 100644 --- a/test/nn/conv/test_gat_conv.py +++ b/test/nn/conv/test_gat_conv.py @@ -267,10 +267,10 @@ def test_gat_norm_csc_error(): def test_gat_conv_bipartite_error(): x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) - adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) with pytest.raises(NotImplementedError, match="not supported for bipartite message passing"): - conv = GATConv(8, 32, heads=2, normalize=True) - _ = conv(x1, adj1.t()) + conv = GATConv((8, 16), 32, heads=2, normalize=True) + conv((x1, x2), edge_index) diff --git a/test/nn/conv/test_gatv2_conv.py b/test/nn/conv/test_gatv2_conv.py index d8a94e72a9f7..3a0477d8df37 100644 --- a/test/nn/conv/test_gatv2_conv.py +++ b/test/nn/conv/test_gatv2_conv.py @@ -205,10 +205,10 @@ def test_gatv2_conv_with_edge_attr(): def test_gat_conv_bipartite_error(): x1 = torch.randn(4, 8) + x2 = torch.randn(2, 16) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) - adj1 = to_torch_csc_tensor(edge_index, size=(4, 2)) with pytest.raises(NotImplementedError, match="not supported for bipartite message passing"): - conv = GATv2Conv(8, 32, heads=2, normalize=True) - _ = conv(x1, adj1.t()) + conv = GATv2Conv((8, 16), 32, heads=2, normalize=True) + conv((x1, x2), edge_index) diff --git a/torch_geometric/nn/conv/gat_conv.py b/torch_geometric/nn/conv/gat_conv.py index f364533ddff8..998a4fe1ce24 100644 --- a/torch_geometric/nn/conv/gat_conv.py +++ b/torch_geometric/nn/conv/gat_conv.py @@ -502,12 +502,10 @@ def forward( # noqa: F811 "'edge_index' in a 'SparseTensor' form") if self.normalize: - if isinstance(edge_index, - SparseTensor) or is_torch_sparse_tensor(edge_index): - if edge_index.size(0) != edge_index.size(1): - raise NotImplementedError( - "The usage of 'normalize' is not supported " - "for bipartite message passing.") + if not isinstance(self.in_channels, int): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops( @@ -520,7 +518,7 @@ def forward( # noqa: F811 size=size) if self.normalize: - num_nodes = None + num_nodes: Optional[int] = None if isinstance(edge_index, Tensor): num_nodes = x_src.size(0) if x_dst is not None: diff --git a/torch_geometric/nn/conv/gatv2_conv.py b/torch_geometric/nn/conv/gatv2_conv.py index a368453adce3..a95117f28feb 100644 --- a/torch_geometric/nn/conv/gatv2_conv.py +++ b/torch_geometric/nn/conv/gatv2_conv.py @@ -331,12 +331,10 @@ def forward( # noqa: F811 "'edge_index' in a 'SparseTensor' form") if self.normalize: - if isinstance(edge_index, - SparseTensor) or is_torch_sparse_tensor(edge_index): - if edge_index.size(0) != edge_index.size(1): - raise NotImplementedError( - "The usage of 'normalize' is not supported " - "for bipartite message passing.") + if not isinstance(self.in_channels, int): + raise NotImplementedError( + "The usage of 'normalize' is not supported " + "for bipartite message passing.") if isinstance(edge_index, Tensor): edge_index, edge_attr = remove_self_loops(