Skip to content

Commit

Permalink
Fix torch.jit.load for MessagePassing modules (#8870)
Browse files Browse the repository at this point in the history
Fixes #8867
  • Loading branch information
rusty1s authored Feb 6, 2024
1 parent 1f93077 commit 64a1268
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 49 deletions.
10 changes: 10 additions & 0 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import os.path as osp
from typing import Tuple, Union

import pytest
Expand Down Expand Up @@ -184,6 +185,15 @@ def test_my_conv_jit():
jit.fuse = True


def test_my_conv_jit_save(tmp_path):
path = osp.join(tmp_path, 'model.pt')

conv = MyConv(8, 32)
conv = torch.jit.script(conv)
torch.jit.save(conv, path)
conv = torch.jit.load(path)


@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean', 'min', 'max', 'mul'])
def test_my_conv_aggr(aggr):
x = torch.randn(4, 8)
Expand Down
58 changes: 31 additions & 27 deletions torch_geometric/nn/conv/collect.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -30,34 +30,43 @@ def {{collect_name}}(
i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)

# Collect special arguments:
if isinstance(edge_index, Tensor) and is_torch_sparse_tensor(edge_index):
if isinstance(edge_index, Tensor):
if is_torch_sparse_tensor(edge_index):
{%- if 'edge_index' in collect_param_dict %}
raise ValueError("Cannot collect 'edge_indices' for sparse matrices")
raise ValueError("Cannot collect 'edge_indices' for sparse matrices")
{%- endif %}
adj_t = edge_index
if adj_t.layout == torch.sparse_coo:
edge_index_i = adj_t.indices()[0]
edge_index_j = adj_t.indices()[1]
ptr = None
elif adj_t.layout == torch.sparse_csr:
ptr = adj_t.crow_indices()
edge_index_j = adj_t.col_indices()
edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())
else:
raise ValueError(f"Received invalid layout '{adj_t.layout}'")
adj_t = edge_index
if adj_t.layout == torch.sparse_coo:
edge_index_i = adj_t.indices()[0]
edge_index_j = adj_t.indices()[1]
ptr = None
elif adj_t.layout == torch.sparse_csr:
ptr = adj_t.crow_indices()
edge_index_j = adj_t.col_indices()
edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())
else:
raise ValueError(f"Received invalid layout '{adj_t.layout}'")

{%- if 'edge_weight' in collect_param_dict %}
if edge_weight is None:
edge_weight = adj_t.values()
if edge_weight is None:
edge_weight = adj_t.values()
{%- elif 'edge_attr' in collect_param_dict %}
if edge_attr is None:
_value = adj_t.values()
edge_attr = None if _value.dim() == 1 else _value
if edge_attr is None:
_value = adj_t.values()
edge_attr = None if _value.dim() == 1 else _value
{%- elif 'edge_type' in collect_param_dict %}
if edge_type is None:
edge_type = adj_t.values()
if edge_type is None:
edge_type = adj_t.values()
{%- endif %}

else:
{%- if 'adj_t' in collect_param_dict %}
raise ValueError("Cannot collect 'adj_t' for edge indices")
{%- endif %}
edge_index_i = edge_index[i]
edge_index_j = edge_index[j]
ptr = None

elif isinstance(edge_index, SparseTensor):
{%- if 'edge_index' in collect_param_dict %}
raise ValueError("Cannot collect 'edge_indices' for sparse matrices")
Expand All @@ -77,13 +86,8 @@ def {{collect_name}}(
edge_type = _value
{%- endif %}

elif isinstance(edge_index, Tensor):
{%- if 'adj_t' in collect_param_dict %}
raise ValueError("Cannot collect 'adj_t' for edge indices")
{%- endif %}
edge_index_i = edge_index[i]
edge_index_j = edge_index[j]
ptr = None
else:
raise NotImplementedError

{%- if 'edge_weight' in collect_param_dict and
collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}
Expand Down
29 changes: 7 additions & 22 deletions torch_geometric/nn/conv/gen_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,7 @@
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import reset
from torch_geometric.nn.norm import MessageNorm
from torch_geometric.typing import (
Adj,
OptPairTensor,
OptTensor,
Size,
SparseTensor,
)
from torch_geometric.utils import is_torch_sparse_tensor, to_edge_index
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size


class MLP(Sequential):
Expand Down Expand Up @@ -216,20 +209,6 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
if hasattr(self, 'lin_src'):
x = (self.lin_src(x[0]), x[1])

if isinstance(edge_index, SparseTensor):
edge_attr = edge_index.storage.value()
elif is_torch_sparse_tensor(edge_index):
_, value = to_edge_index(edge_index)
if value.dim() > 1 or not value.all():
edge_attr = value

if edge_attr is not None and hasattr(self, 'lin_edge'):
edge_attr = self.lin_edge(edge_attr)

# Node and edge feature dimensionalites need to match.
if edge_attr is not None:
assert x[0].size(-1) == edge_attr.size(-1)

# propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)

Expand All @@ -250,6 +229,12 @@ def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
return self.mlp(out)

def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:
if edge_attr is not None and hasattr(self, 'lin_edge'):
edge_attr = self.lin_edge(edge_attr)

if edge_attr is not None:
assert x_j.size(-1) == edge_attr.size(-1)

msg = x_j if edge_attr is None else x_j + edge_attr
return msg.relu() + self.eps

Expand Down

0 comments on commit 64a1268

Please sign in to comment.