Skip to content

Commit

Permalink
Refactor Sequential: Allow pickling (#9369)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored May 28, 2024
1 parent b464595 commit 476e768
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 107 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- `Sequential` can now be properly pickled ([#9369](https://github.com/pyg-team/pytorch_geometric/pull/9369))
- Fixed `pickle.load` for jittable `MessagePassing` modules ([#9368](https://github.com/pyg-team/pytorch_geometric/pull/9368))
- Fixed batching of sparse tensors saved via `data.edge_index` ([#9317](https://github.com/pyg-team/pytorch_geometric/pull/9317))
- Fixed arbitrary keyword ordering in `MessagePassing.propgate` ([#9245](https://github.com/pyg-team/pytorch_geometric/pull/9245))
Expand Down
8 changes: 1 addition & 7 deletions test/nn/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from torch_geometric.typing import SparseTensor


def test_sequential():
def test_sequential_basic():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
batch = torch.zeros(4, dtype=torch.long)
Expand Down Expand Up @@ -47,12 +47,6 @@ def test_sequential():
assert isinstance(model[3], ReLU)
assert isinstance(model[4], Linear)

assert model._module_descs[0] == 'x, edge_index -> x'
assert model._module_descs[1] == 'x -> x'
assert model._module_descs[2] == 'x, edge_index -> x'
assert model._module_descs[3] == 'x -> x'
assert model._module_descs[4] == 'x -> x'

out = model(x, edge_index)
assert out.size() == (4, 7)

Expand Down
33 changes: 10 additions & 23 deletions torch_geometric/nn/sequential.jinja
Original file line number Diff line number Diff line change
@@ -1,35 +1,22 @@
import typing
from typing import *

import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.typing import *
{% for module in modules %}
from {{module}} import *
{%- endfor %}


class Sequential(torch.nn.Module):
def reset_parameters(self) -> None:
{%- for child in children %}
if hasattr(self.{{child.name}}, 'reset_parameters'):
self.{{child.name}}.reset_parameters()
def forward(
self,
{%- for param in signature.param_dict.values() %}
{{param.name}}: {{param.type_repr}},
{%- endfor %}
) -> {{signature.return_type_repr}}:

def forward(self, {{ input_types|join(', ') }}) -> {{return_type}}:
{%- for child in children %}
{{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
{{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})
{%- endfor %}
return {{children[-1].return_names|join(', ')}}

def __getitem__(self, idx: int) -> torch.nn.Module:
return getattr(self, self._module_names[idx])

def __len__(self) -> int:
return {{children|length}}

def __repr__(self) -> str:
module_reprs = [
f' ({i}) - {self[i]}: {self._module_descs[i]}'
for i in range(len(self))
]
return 'Sequential(\n{}\n)'.format('\n'.join(module_reprs))
return {{children[-1].return_names|join(', ')}}
280 changes: 203 additions & 77 deletions torch_geometric/nn/sequential.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
import copy
import inspect
import os.path as osp
import random
from typing import Callable, List, NamedTuple, Optional, Tuple, Union
import sys
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
Tuple,
Union,
)

import torch
from torch import Tensor

from torch_geometric.inspector import split, type_repr
from torch_geometric.inspector import Parameter, Signature, eval_type, split
from torch_geometric.template import module_from_template


class Child(NamedTuple):
name: str
module: Callable
param_names: List[str]
return_names: List[str]


def Sequential(
input_args: str,
modules: List[Union[Tuple[Callable, str], Callable]],
) -> torch.nn.Module:
class Sequential(torch.nn.Module):
r"""An extension of the :class:`torch.nn.Sequential` container in order to
define a sequential GNN model.
Expand Down Expand Up @@ -74,74 +82,192 @@ def Sequential(
:obj:`OrderedDict` of modules (and function header definitions) can
be passed.
"""
signature = input_args.split('->')
if len(signature) == 1:
input_args = signature[0]
return_type = type_repr(Tensor, globals())
elif len(signature) == 2:
input_args, return_type = signature[0], signature[1].strip()
else:
raise ValueError(f"Failed to parse arguments (got '{input_args}')")

input_types = split(input_args, sep=',')
if len(input_types) == 0:
raise ValueError(f"Failed to parse arguments (got '{input_args}')")

if not isinstance(modules, dict):
modules = {f'module_{i}': module for i, module in enumerate(modules)}
if len(modules) == 0:
raise ValueError("'Sequential' expected a non-empty list of modules")

children: List[Child] = []
for i, (name, module) in enumerate(modules.items()):
desc: Optional[str] = None
if isinstance(module, (tuple, list)):
if len(module) == 1:
module = module[0]
elif len(module) == 2:
module, desc = module
else:
raise ValueError(f"Expected tuple of length 2 (got {module})")

if i == 0 and desc is None:
raise ValueError("Requires signature for first module")
if not callable(module):
raise ValueError(f"Expected callable module (got {module})")
if desc is not None and not isinstance(desc, str):
raise ValueError(f"Expected type hint representation (got {desc})")

if desc is not None:
signature = desc.split('->')
if len(signature) != 2:
raise ValueError(f"Failed to parse arguments (got '{desc}')")
param_names = [v.strip() for v in signature[0].split(',')]
return_names = [v.strip() for v in signature[1].split(',')]
child = Child(name, module, param_names, return_names)
children: List[Child]

def __init__(
self,
input_args: str,
modules: List[Union[Tuple[Callable, str], Callable]],
) -> None:
super().__init__()

caller_path = inspect.stack()[1].filename
self._caller_module = osp.splitext(osp.basename(caller_path))[0]

_globals = copy.copy(globals())
_globals.update(sys.modules['__main__'].__dict__)
if self._caller_module in sys.modules:
_globals.update(sys.modules[self._caller_module].__dict__)

signature = input_args.split('->')
if len(signature) == 1:
args_repr = signature[0]
return_type_repr = 'Tensor'
return_type = Tensor
elif len(signature) == 2:
args_repr = signature[0]
return_type_repr = signature[1].strip()
return_type = eval_type(return_type_repr, _globals)
else:
param_names = children[-1].return_names
child = Child(name, module, param_names, param_names)

children.append(child)

uid = '%06x' % random.randrange(16**6)
root_dir = osp.dirname(osp.realpath(__file__))
module = module_from_template(
module_name=f'torch_geometric.nn.sequential_{uid}',
template_path=osp.join(root_dir, 'sequential.jinja'),
tmp_dirname='sequential',
# Keyword arguments:
input_types=input_types,
return_type=return_type,
children=children,
)

model = module.Sequential()
model._module_names = [child.name for child in children]
model._module_descs = [
f"{', '.join(child.param_names)} -> {', '.join(child.return_names)}"
for child in children
]
for child in children:
setattr(model, child.name, child.module)

return model
raise ValueError(f"Failed to parse arguments (got '{input_args}')")

param_dict: Dict[str, Parameter] = {}
for arg in split(args_repr, sep=','):
signature = arg.split(':')
if len(signature) == 1:
name = signature[0].strip()
param_dict[name] = Parameter(
name=name,
type=Tensor,
type_repr='Tensor',
default=inspect._empty,
)
elif len(signature) == 2:
name = signature[0].strip()
param_dict[name] = Parameter(
name=name,
type=eval_type(signature[1].strip(), _globals),
type_repr=signature[1].strip(),
default=inspect._empty,
)
else:
raise ValueError(f"Failed to parse argument "
f"(got '{arg.strip()}')")

self.signature = Signature(param_dict, return_type, return_type_repr)

if not isinstance(modules, dict):
modules = {
f'module_{i}': module
for i, module in enumerate(modules)
}
if len(modules) == 0:
raise ValueError(f"'{self.__class__.__name__}' expects a "
f"non-empty list of modules")

self.children: List[Child] = []
for i, (name, module) in enumerate(modules.items()):
desc: Optional[str] = None
if isinstance(module, (tuple, list)):
if len(module) == 1:
module = module[0]
elif len(module) == 2:
module, desc = module
else:
raise ValueError(f"Expected tuple of length 2 "
f"(got {module})")

if i == 0 and desc is None:
raise ValueError("Signature for first module required")
if not callable(module):
raise ValueError(f"Expected callable module (got {module})")
if desc is not None and not isinstance(desc, str):
raise ValueError(f"Expected type hint representation "
f"(got {desc})")

if desc is not None:
signature = desc.split('->')
if len(signature) != 2:
raise ValueError(
f"Failed to parse arguments (got '{desc}')")
param_names = [v.strip() for v in signature[0].split(',')]
return_names = [v.strip() for v in signature[1].split(',')]
child = Child(name, param_names, return_names)
else:
param_names = self.children[-1].return_names
child = Child(name, param_names, param_names)

setattr(self, name, module)
self.children.append(child)

self._set_jittable_template()

def reset_parameters(self) -> None:
r"""Resets all learnable parameters of the module."""
for child in self.children:
module = getattr(self, child.name)
if hasattr(module, 'reset_parameters'):
module.reset_parameters()

def __len__(self) -> int:
return len(self.children)

def __getitem__(self, idx: int) -> torch.nn.Module:
return getattr(self, self.children[idx].name)

def __setstate__(self, data: Dict[str, Any]) -> None:
super().__setstate__(data)
self._set_jittable_template()

def __repr__(self) -> str:
module_descs = [
f"{', '.join(c.param_names)} -> {', '.join(c.return_names)}"
for c in self.children
]
module_reprs = [
f' ({i}) - {self[i]}: {module_descs[i]}' for i in range(len(self))
]
return '{}(\n{}\n)'.format(
self.__class__.__name__,
'\n'.join(module_reprs),
)

def forward(self, *args: Any, **kwargs: Any) -> Any:
"""""" # noqa: D419
value_dict = {
name: arg
for name, arg in zip(self.signature.param_dict.keys(), args)
}
for key, arg in kwargs.items():
if key in value_dict:
raise TypeError(f"'{self.__class__.__name__}' got multiple "
f"values for argument '{key}'")
value_dict[key] = arg

for child in self.children:
args = [value_dict[name] for name in child.param_names]
outs = getattr(self, child.name)(*args)
if len(child.return_names) == 1:
value_dict[child.return_names[0]] = outs
else:
for name, out in zip(child.return_names, outs):
value_dict[name] = out

return outs

# TorchScript Support #####################################################

def _set_jittable_template(self, raise_on_error: bool = False) -> None:
try: # Optimize `forward()` via `*.jinja` templates:
if ('forward' in self.__class__.__dict__ and
self.__class__.__dict__['forward'] != Sequential.forward):
raise ValueError("Cannot compile custom 'forward' method")

root_dir = osp.dirname(osp.realpath(__file__))
uid = '%06x' % random.randrange(16**6)
jinja_prefix = f'{self.__module__}_{self.__class__.__name__}_{uid}'
module = module_from_template(
module_name=jinja_prefix,
template_path=osp.join(root_dir, 'sequential.jinja'),
tmp_dirname='sequential',
# Keyword arguments:
modules=[self._caller_module],
signature=self.signature,
children=self.children,
)

self.forward = module.forward.__get__(self)

# NOTE We override `forward` on the class level here in order to
# support `torch.jit.trace` - this is generally dangerous to do,
# and limits `torch.jit.trace` to a single `Sequential` module:
self.__class__.forward = module.forward
except Exception as e: # pragma: no cover
if raise_on_error:
raise e

def __prepare_scriptable__(self) -> 'Sequential':
# Prevent type sharing when scripting `Sequential` modules:
type_store = torch.jit._recursive.concrete_type_store.type_store
type_store.pop(self.__class__, None)
return self

0 comments on commit 476e768

Please sign in to comment.