Skip to content

Commit

Permalink
Fix pickle.load in combination with torch.jit.script for `Message…
Browse files Browse the repository at this point in the history
…Passing` modules (#9368)
  • Loading branch information
rusty1s authored May 27, 2024
1 parent ca3b0f4 commit b464595
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `pickle.load` for jittable `MessagePassing` modules ([#9368](https://github.com/pyg-team/pytorch_geometric/pull/9368))
- Fixed batching of sparse tensors saved via `data.edge_index` ([#9317](https://github.com/pyg-team/pytorch_geometric/pull/9317))
- Fixed arbitrary keyword ordering in `MessagePassing.propgate` ([#9245](https://github.com/pyg-team/pytorch_geometric/pull/9245))
- Fixed node mapping bug in `RCDD` dataset ([#9234](https://github.com/pyg-team/pytorch_geometric/pull/9234))
Expand Down
14 changes: 13 additions & 1 deletion test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch_geometric.typing
from torch_geometric import EdgeIndex
from torch_geometric.nn import MessagePassing, aggr
from torch_geometric.nn import GATConv, MessagePassing, aggr
from torch_geometric.typing import (
Adj,
OptPairTensor,
Expand Down Expand Up @@ -728,3 +728,15 @@ def test_traceable_my_conv_with_self_loops(num_nodes):

assert torch.allclose(out, traced_out)
assert torch.allclose(out, scripted_out)


def test_pickle(tmp_path):
path = osp.join(tmp_path, 'model.pt')
model = GATConv(16, 32)
torch.save(model, path)

GATConv.propagate = GATConv._orig_propagate
GATConv.edge_updater = GATConv._orig_edge_updater

model = torch.load(path)
torch.jit.script(model)
148 changes: 83 additions & 65 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,71 +167,8 @@ def __init__(
self._edge_update_forward_pre_hooks: HookDict = OrderedDict()
self._edge_update_forward_hooks: HookDict = OrderedDict()

root_dir = osp.dirname(osp.realpath(__file__))
jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
# Optimize `propagate()` via `*.jinja` templates:
if not self.propagate.__module__.startswith(jinja_prefix):
try:
if 'propagate' in self.__class__.__dict__:
raise ValueError("Cannot compile custom 'propagate' "
"method")

module = module_from_template(
module_name=f'{jinja_prefix}_propagate',
template_path=osp.join(root_dir, 'propagate.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='collect',
signature=self._get_propagate_signature(),
collect_param_dict=self.inspector.get_flat_param_dict(
['message', 'aggregate', 'update']),
message_args=self.inspector.get_param_names('message'),
aggregate_args=self.inspector.get_param_names('aggregate'),
message_and_aggregate_args=self.inspector.get_param_names(
'message_and_aggregate'),
update_args=self.inspector.get_param_names('update'),
fuse=self.fuse,
)

self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = module.propagate

self.__class__.propagate = module.propagate
self.__class__.collect = module.collect
except Exception: # pragma: no cover
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = self.__class__.propagate

# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
if (self.inspector.implements('edge_update')
and not self.edge_updater.__module__.startswith(jinja_prefix)):
try:
if 'edge_updater' in self.__class__.__dict__:
raise ValueError("Cannot compile custom 'edge_updater' "
"method")

module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)

self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = module.edge_updater

self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
except Exception: # pragma: no cover
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = (
self.__class__.edge_updater)
# Set jittable `propagate` and `edge_updater` function templates:
self._set_jittable_templates()

# Explainability:
self._explain: Optional[bool] = None
Expand All @@ -248,6 +185,12 @@ def reset_parameters(self) -> None:
if self.aggr_module is not None:
self.aggr_module.reset_parameters()

def __setstate__(self, data: Dict[str, Any]) -> None:
self.inspector = data['inspector']
self.fuse = data['fuse']
self._set_jittable_templates()
super().__setstate__(data)

def __repr__(self) -> str:
channels_repr = ''
if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'):
Expand Down Expand Up @@ -981,6 +924,81 @@ def register_edge_update_forward_hook(

# TorchScript Support #####################################################

def _set_jittable_templates(self, raise_on_error: bool = False) -> None:
root_dir = osp.dirname(osp.realpath(__file__))
jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
# Optimize `propagate()` via `*.jinja` templates:
if not self.propagate.__module__.startswith(jinja_prefix):
try:
if ('propagate' in self.__class__.__dict__
and self.__class__.__dict__['propagate']
!= MessagePassing.propagate):
raise ValueError("Cannot compile custom 'propagate' "
"method")

module = module_from_template(
module_name=f'{jinja_prefix}_propagate',
template_path=osp.join(root_dir, 'propagate.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='collect',
signature=self._get_propagate_signature(),
collect_param_dict=self.inspector.get_flat_param_dict(
['message', 'aggregate', 'update']),
message_args=self.inspector.get_param_names('message'),
aggregate_args=self.inspector.get_param_names('aggregate'),
message_and_aggregate_args=self.inspector.get_param_names(
'message_and_aggregate'),
update_args=self.inspector.get_param_names('update'),
fuse=self.fuse,
)

self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = module.propagate

self.__class__.propagate = module.propagate
self.__class__.collect = module.collect
except Exception as e: # pragma: no cover
if raise_on_error:
raise e
self.__class__._orig_propagate = self.__class__.propagate
self.__class__._jinja_propagate = self.__class__.propagate

# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
if (self.inspector.implements('edge_update')
and not self.edge_updater.__module__.startswith(jinja_prefix)):
try:
if ('edge_updater' in self.__class__.__dict__
and self.__class__.__dict__['edge_updater']
!= MessagePassing.edge_updater):
raise ValueError("Cannot compile custom 'edge_updater' "
"method")

module = module_from_template(
module_name=f'{jinja_prefix}_edge_updater',
template_path=osp.join(root_dir, 'edge_updater.jinja'),
tmp_dirname='message_passing',
# Keyword arguments:
modules=self.inspector._modules,
collect_name='edge_collect',
signature=self._get_edge_updater_signature(),
collect_param_dict=self.inspector.get_param_dict(
'edge_update'),
)

self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = module.edge_updater

self.__class__.edge_updater = module.edge_updater
self.__class__.edge_collect = module.edge_collect
except Exception as e: # pragma: no cover
if raise_on_error:
raise e
self.__class__._orig_edge_updater = self.__class__.edge_updater
self.__class__._jinja_edge_updater = (
self.__class__.edge_updater)

def _get_propagate_signature(self) -> Signature:
param_dict = self.inspector.get_params_from_method_call(
'propagate', exclude=[0, 'edge_index', 'size'])
Expand Down

0 comments on commit b464595

Please sign in to comment.