Use message_and_aggregate in MessagePassing for EdgeIndex (#9131)
Also enforce sort by `row` in spmm

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
3 people authored Apr 8, 2024
1 parent 38bb5f2 commit e213c29
Showing 5 changed files with 82 additions and 6 deletions.
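In short, message-passing layers that opt in via `SUPPORTS_FUSED_EDGE_INDEX` now take the fused `message_and_aggregate` path when called with a column-sorted `EdgeIndex`. A minimal sketch of what this enables (the tensors and sizes are illustrative, not taken from the diff):

```python
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import GraphConv

x = torch.randn(4, 16)
# The fused path requires column-sorted indices (see `propagate` below):
edge_index = EdgeIndex(
    torch.tensor([[1, 2, 3, 0], [0, 1, 2, 3]]),
    sort_order='col',
    sparse_size=(4, 4),
)

conv = GraphConv(16, 32)
out = conv(x, edge_index)  # Dispatches to the fused `message_and_aggregate`.
assert out.size() == (4, 32)
```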
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))
- Added `CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))
- Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173))
- Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075))
34 changes: 34 additions & 0 deletions test/nn/conv/test_graph_conv.py
@@ -1,6 +1,7 @@
import torch

import torch_geometric.typing
from torch_geometric import EdgeIndex
from torch_geometric.nn import GraphConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
@@ -22,6 +23,11 @@ def test_graph_conv():
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out1, atol=1e-6)
assert torch.allclose(conv(x1, adj1.t()), out1, atol=1e-6)

assert conv(
x1,
EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)),
).allclose(out1, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
assert torch.allclose(conv(x1, adj3.t()), out1, atol=1e-6)
@@ -32,6 +38,12 @@ def test_graph_conv():
atol=1e-6)
assert torch.allclose(conv(x1, adj2.t()), out2, atol=1e-6)

assert conv(
x1,
EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)),
value,
).allclose(out2, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))
assert torch.allclose(conv(x1, adj4.t()), out2, atol=1e-6)
@@ -58,19 +70,41 @@ def test_graph_conv():
assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)
assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)

assert conv(
(x1, x2),
EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),
).allclose(out1, atol=1e-6)

out2 = conv((x1, None), edge_index, size=(4, 2))
assert out2.size() == (2, 32)
assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)

assert conv(
(x1, None),
EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),
).allclose(out2, atol=1e-6)

out3 = conv((x1, x2), edge_index, value)
assert out3.size() == (2, 32)
assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out3)
assert torch.allclose(conv((x1, x2), adj2.t()), out3, atol=1e-6)

assert conv(
(x1, x2),
EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),
value,
).allclose(out3, atol=1e-6)

out4 = conv((x1, None), edge_index, value, size=(4, 2))
assert out4.size() == (2, 32)
assert torch.allclose(conv((x1, None), adj2.t()), out4, atol=1e-6)

assert conv(
(x1, None),
EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),
value,
).allclose(out4, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))
adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2))
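The tests above always construct the `EdgeIndex` with `sort_order='col'`, since the fused path is only taken when `edge_index.is_sorted_by_col` holds (see `propagate` below). A hedged sketch of preparing an arbitrary index for the fused path, assuming `EdgeIndex.sort_by` behaves as in current PyG:

```python
import torch

from torch_geometric import EdgeIndex

edge_index = EdgeIndex(
    torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
    sparse_size=(3, 3),
)
assert not edge_index.is_sorted_by_col

# Re-sort by column; `perm` maps edge-level attributes to the new order:
edge_index, perm = edge_index.sort_by('col')
assert edge_index.is_sorted_by_col
```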
24 changes: 21 additions & 3 deletions torch_geometric/nn/conv/graph_conv.py
@@ -1,7 +1,9 @@
-from typing import Tuple, Union
+from typing import Final, Tuple, Union

import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size
@@ -44,6 +46,8 @@ class GraphConv(MessagePassing):
- **output:** node features :math:`(|\mathcal{V}|, F_{out})` or
:math:`(|\mathcal{V}_t|, F_{out})` if bipartite
"""
SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = True

def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
@@ -90,5 +94,19 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

-def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:
-    return spmm(adj_t, x[0], reduce=self.aggr)
+def message_and_aggregate(
+    self,
+    edge_index: Adj,
+    x: OptPairTensor,
+    edge_weight: OptTensor,
+) -> Tensor:
+
+    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
+        return edge_index.matmul(
+            other=x[0],
+            input_value=edge_weight,
+            reduce=self.aggr,
+            transpose=True,
+        )
+
+    return spmm(edge_index, x[0], reduce=self.aggr)
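The new `message_and_aggregate` defers to `EdgeIndex.matmul`, which performs the sparse-dense product directly on the `EdgeIndex` representation. A standalone sketch of that call, mirroring the keyword arguments used in the diff above (the tensors are made up):

```python
import torch

from torch_geometric import EdgeIndex

x = torch.randn(4, 8)
edge_index = EdgeIndex(
    torch.tensor([[1, 2, 3, 0], [0, 1, 2, 3]]),
    sort_order='col',
    sparse_size=(4, 4),
)
edge_weight = torch.rand(4)

# `transpose=True` multiplies with A^T, so that for a column-sorted index
# messages flow from source to destination nodes:
out = edge_index.matmul(
    other=x,
    input_value=edge_weight,
    reduce='sum',
    transpose=True,
)
assert out.size() == (4, 8)
```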
19 changes: 17 additions & 2 deletions torch_geometric/nn/conv/message_passing.py
@@ -6,6 +6,7 @@
Any,
Callable,
Dict,
Final,
List,
Optional,
OrderedDict,
@@ -102,6 +103,10 @@ class MessagePassing(torch.nn.Module):
'size_i', 'size_j', 'ptr', 'index', 'dim_size'
}

# Supports `message_and_aggregate` via `EdgeIndex`.
# TODO Remove once migration is finished.
SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = False

def __init__(
self,
aggr: Optional[Union[str, List[str], Aggregation]] = 'sum',
@@ -517,7 +522,17 @@ def propagate(
mutable_size = self._check_input(edge_index, size)

# Run "fused" message and aggregation (if applicable).
-if is_sparse(edge_index) and self.fuse and not self.explain:
+fuse = False
+if self.fuse and not self.explain:
+    if is_sparse(edge_index):
+        fuse = True
+    elif (not torch.jit.is_scripting()
+          and isinstance(edge_index, EdgeIndex)):
+        if (self.SUPPORTS_FUSED_EDGE_INDEX
+                and edge_index.is_sorted_by_col):
+            fuse = True
+
+if fuse:
coll_dict = self._collect(self._fused_user_args, edge_index,
mutable_size, kwargs)

@@ -636,7 +651,7 @@ def aggregate(
dim=self.node_dim)

@abstractmethod
-def message_and_aggregate(self, adj_t: Adj) -> Tensor:
+def message_and_aggregate(self, edge_index: Adj) -> Tensor:
r"""Fuses computations of :func:`message` and :func:`aggregate` into a
single function.
If applicable, this saves both time and memory since messages do not
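For custom layers, opting into the fused path means setting the class attribute and accepting an `EdgeIndex` in `message_and_aggregate`. A minimal sketch following the `GraphConv` pattern above; the layer itself (`SumConv`) is hypothetical:

```python
from typing import Final

import torch
from torch import Tensor

from torch_geometric import EdgeIndex
from torch_geometric.nn import MessagePassing
from torch_geometric.typing import Adj
from torch_geometric.utils import spmm


class SumConv(MessagePassing):
    # Opt into the fused EdgeIndex path introduced in this commit:
    SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = True

    def __init__(self):
        super().__init__(aggr='sum')

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        return self.propagate(edge_index, x=x)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j  # Non-fused fallback.

    def message_and_aggregate(self, edge_index: Adj, x: Tensor) -> Tensor:
        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
            return edge_index.matmul(other=x, reduce=self.aggr,
                                     transpose=True)
        return spmm(edge_index, x, reduce=self.aggr)
```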
10 changes: 9 additions & 1 deletion torch_geometric/nn/conv/propagate.jinja
@@ -42,7 +42,15 @@ def propagate(
# End Propagate Forward Pre Hook ###########################################

mutable_size = self._check_input(edge_index, size)
-fuse = is_sparse(edge_index) and self.fuse
+
+# Run "fused" message and aggregation (if applicable).
+fuse = False
+if self.fuse:
+    if is_sparse(edge_index):
+        fuse = True
+    elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):
+        if self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col:
+            fuse = True

if fuse:

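Taken together, the Python and Jinja code paths now agree, and the new tests boil down to the following equivalence check (a paraphrase of `test_graph_conv` above):

```python
import torch

from torch_geometric import EdgeIndex
from torch_geometric.nn import GraphConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[1, 2, 3, 0], [0, 1, 2, 3]])

conv = GraphConv(16, 32)
out_plain = conv(x, edge_index)  # Scatter-based message/aggregate.
out_fused = conv(
    x, EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)))
assert torch.allclose(out_plain, out_fused, atol=1e-6)
```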
