From b46459504e2325f7754b01cd14d03c72993ca938 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 27 May 2024 20:21:44 +0200 Subject: [PATCH] Fix `pickle.load` in combination with `torch.jit.script` for `MessagePassing` modules (#9368) --- CHANGELOG.md | 1 + test/nn/conv/test_message_passing.py | 14 +- torch_geometric/nn/conv/message_passing.py | 148 ++++++++++++--------- 3 files changed, 97 insertions(+), 66 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71628423e6d5..10c3fd37ffa8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `pickle.load` for jittable `MessagePassing` modules ([#9368](https://github.com/pyg-team/pytorch_geometric/pull/9368)) - Fixed batching of sparse tensors saved via `data.edge_index` ([#9317](https://github.com/pyg-team/pytorch_geometric/pull/9317)) - Fixed arbitrary keyword ordering in `MessagePassing.propgate` ([#9245](https://github.com/pyg-team/pytorch_geometric/pull/9245)) - Fixed node mapping bug in `RCDD` dataset ([#9234](https://github.com/pyg-team/pytorch_geometric/pull/9234)) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index b1d7857ab7ba..76ff02dc750b 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -9,7 +9,7 @@ import torch_geometric.typing from torch_geometric import EdgeIndex -from torch_geometric.nn import MessagePassing, aggr +from torch_geometric.nn import GATConv, MessagePassing, aggr from torch_geometric.typing import ( Adj, OptPairTensor, @@ -728,3 +728,15 @@ def test_traceable_my_conv_with_self_loops(num_nodes): assert torch.allclose(out, traced_out) assert torch.allclose(out, scripted_out) + + +def test_pickle(tmp_path): + path = osp.join(tmp_path, 'model.pt') + model = GATConv(16, 32) + torch.save(model, path) + + GATConv.propagate = GATConv._orig_propagate + GATConv.edge_updater = GATConv._orig_edge_updater + + model = torch.load(path) + torch.jit.script(model) diff --git a/torch_geometric/nn/conv/message_passing.py b/torch_geometric/nn/conv/message_passing.py index ccf90ba16068..aa73dbaeb087 100644 --- a/torch_geometric/nn/conv/message_passing.py +++ b/torch_geometric/nn/conv/message_passing.py @@ -167,71 +167,8 @@ def __init__( self._edge_update_forward_pre_hooks: HookDict = OrderedDict() self._edge_update_forward_hooks: HookDict = OrderedDict() - root_dir = osp.dirname(osp.realpath(__file__)) - jinja_prefix = f'{self.__module__}_{self.__class__.__name__}' - # Optimize `propagate()` via `*.jinja` templates: - if not self.propagate.__module__.startswith(jinja_prefix): - try: - if 'propagate' in self.__class__.__dict__: - raise ValueError("Cannot compile custom 'propagate' " - "method") - - module = module_from_template( - module_name=f'{jinja_prefix}_propagate', - template_path=osp.join(root_dir, 'propagate.jinja'), - tmp_dirname='message_passing', - # Keyword arguments: - modules=self.inspector._modules, - collect_name='collect', - signature=self._get_propagate_signature(), - collect_param_dict=self.inspector.get_flat_param_dict( - ['message', 'aggregate', 'update']), - message_args=self.inspector.get_param_names('message'), - aggregate_args=self.inspector.get_param_names('aggregate'), - message_and_aggregate_args=self.inspector.get_param_names( - 'message_and_aggregate'), - update_args=self.inspector.get_param_names('update'), - fuse=self.fuse, - ) - - self.__class__._orig_propagate = self.__class__.propagate - self.__class__._jinja_propagate = module.propagate - - self.__class__.propagate = module.propagate - self.__class__.collect = module.collect - except Exception: # pragma: no cover - self.__class__._orig_propagate = self.__class__.propagate - self.__class__._jinja_propagate = self.__class__.propagate - - # Optimize `edge_updater()` via `*.jinja` templates (if implemented): - if (self.inspector.implements('edge_update') - and not self.edge_updater.__module__.startswith(jinja_prefix)): - try: - if 'edge_updater' in self.__class__.__dict__: - raise ValueError("Cannot compile custom 'edge_updater' " - "method") - - module = module_from_template( - module_name=f'{jinja_prefix}_edge_updater', - template_path=osp.join(root_dir, 'edge_updater.jinja'), - tmp_dirname='message_passing', - # Keyword arguments: - modules=self.inspector._modules, - collect_name='edge_collect', - signature=self._get_edge_updater_signature(), - collect_param_dict=self.inspector.get_param_dict( - 'edge_update'), - ) - - self.__class__._orig_edge_updater = self.__class__.edge_updater - self.__class__._jinja_edge_updater = module.edge_updater - - self.__class__.edge_updater = module.edge_updater - self.__class__.edge_collect = module.edge_collect - except Exception: # pragma: no cover - self.__class__._orig_edge_updater = self.__class__.edge_updater - self.__class__._jinja_edge_updater = ( - self.__class__.edge_updater) + # Set jittable `propagate` and `edge_updater` function templates: + self._set_jittable_templates() # Explainability: self._explain: Optional[bool] = None @@ -248,6 +185,12 @@ def reset_parameters(self) -> None: if self.aggr_module is not None: self.aggr_module.reset_parameters() + def __setstate__(self, data: Dict[str, Any]) -> None: + self.inspector = data['inspector'] + self.fuse = data['fuse'] + self._set_jittable_templates() + super().__setstate__(data) + def __repr__(self) -> str: channels_repr = '' if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'): @@ -981,6 +924,81 @@ def register_edge_update_forward_hook( # TorchScript Support ##################################################### + def _set_jittable_templates(self, raise_on_error: bool = False) -> None: + root_dir = osp.dirname(osp.realpath(__file__)) + jinja_prefix = f'{self.__module__}_{self.__class__.__name__}' + # Optimize `propagate()` via `*.jinja` templates: + if not self.propagate.__module__.startswith(jinja_prefix): + try: + if ('propagate' in self.__class__.__dict__ + and self.__class__.__dict__['propagate'] + != MessagePassing.propagate): + raise ValueError("Cannot compile custom 'propagate' " + "method") + + module = module_from_template( + module_name=f'{jinja_prefix}_propagate', + template_path=osp.join(root_dir, 'propagate.jinja'), + tmp_dirname='message_passing', + # Keyword arguments: + modules=self.inspector._modules, + collect_name='collect', + signature=self._get_propagate_signature(), + collect_param_dict=self.inspector.get_flat_param_dict( + ['message', 'aggregate', 'update']), + message_args=self.inspector.get_param_names('message'), + aggregate_args=self.inspector.get_param_names('aggregate'), + message_and_aggregate_args=self.inspector.get_param_names( + 'message_and_aggregate'), + update_args=self.inspector.get_param_names('update'), + fuse=self.fuse, + ) + + self.__class__._orig_propagate = self.__class__.propagate + self.__class__._jinja_propagate = module.propagate + + self.__class__.propagate = module.propagate + self.__class__.collect = module.collect + except Exception as e: # pragma: no cover + if raise_on_error: + raise e + self.__class__._orig_propagate = self.__class__.propagate + self.__class__._jinja_propagate = self.__class__.propagate + + # Optimize `edge_updater()` via `*.jinja` templates (if implemented): + if (self.inspector.implements('edge_update') + and not self.edge_updater.__module__.startswith(jinja_prefix)): + try: + if ('edge_updater' in self.__class__.__dict__ + and self.__class__.__dict__['edge_updater'] + != MessagePassing.edge_updater): + raise ValueError("Cannot compile custom 'edge_updater' " + "method") + + module = module_from_template( + module_name=f'{jinja_prefix}_edge_updater', + template_path=osp.join(root_dir, 'edge_updater.jinja'), + tmp_dirname='message_passing', + # Keyword arguments: + modules=self.inspector._modules, + collect_name='edge_collect', + signature=self._get_edge_updater_signature(), + collect_param_dict=self.inspector.get_param_dict( + 'edge_update'), + ) + + self.__class__._orig_edge_updater = self.__class__.edge_updater + self.__class__._jinja_edge_updater = module.edge_updater + + self.__class__.edge_updater = module.edge_updater + self.__class__.edge_collect = module.edge_collect + except Exception as e: # pragma: no cover + if raise_on_error: + raise e + self.__class__._orig_edge_updater = self.__class__.edge_updater + self.__class__._jinja_edge_updater = ( + self.__class__.edge_updater) + def _get_propagate_signature(self) -> Signature: param_dict = self.inspector.get_params_from_method_call( 'propagate', exclude=[0, 'edge_index', 'size'])