Skip to content

Commit

Permalink
Add module_headers property to nn.Sequential models (#8093)
Browse files Browse the repository at this point in the history
Fixes #8082.
  • Loading branch information
rusty1s authored Sep 29, 2023
1 parent 68552e7 commit 1e12d41
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `module_headers` property to `nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093)
- Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092))
- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938)
- Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894)
Expand Down
16 changes: 11 additions & 5 deletions test/nn/test_sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def test_sequential():
assert len(model) == 5
assert str(model) == (
'Sequential(\n'
' (0): GCNConv(16, 64)\n'
' (1): ReLU(inplace=True)\n'
' (2): GCNConv(64, 64)\n'
' (3): ReLU(inplace=True)\n'
' (4): Linear(in_features=64, out_features=7, bias=True)\n'
' (0) - GCNConv(16, 64): x, edge_index -> x\n'
' (1) - ReLU(inplace=True): x -> x\n'
' (2) - GCNConv(64, 64): x, edge_index -> x\n'
' (3) - ReLU(inplace=True): x -> x\n'
' (4) - Linear(in_features=64, out_features=7, bias=True): x -> x\n'
')')

assert isinstance(model[0], GCNConv)
Expand All @@ -47,6 +47,12 @@ def test_sequential():
assert isinstance(model[3], ReLU)
assert isinstance(model[4], Linear)

assert model.module_headers[0] == (['x', 'edge_index'], ['x'])
assert model.module_headers[1] == (['x'], ['x'])
assert model.module_headers[2] == (['x', 'edge_index'], ['x'])
assert model.module_headers[3] == (['x'], ['x'])
assert model.module_headers[4] == (['x'], ['x'])

out = model(x, edge_index)
assert out.size() == (4, 7)

Expand Down
9 changes: 7 additions & 2 deletions torch_geometric/nn/sequential.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,10 @@ class {{cls_name}}(torch.nn.Module):
return {{calls|length}}

def __repr__(self) -> str:
return 'Sequential(\n{}\n)'.format('\n'.join(
[f' ({idx}): ' + str(self[idx]) for idx in range(len(self))]))
module_reprs = [
(f" ({i}) - {self[i]}: {', '.join(self.module_headers[i].args)} "
f"-> {', '.join(self.module_headers[i].output)}")
for i in range(len(self))
]

return 'Sequential(\n{}\n)'.format('\n'.join(module_reprs))
10 changes: 9 additions & 1 deletion torch_geometric/nn/sequential.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
import os
import os.path as osp
from typing import Callable, List, Tuple, Union
from typing import Callable, List, NamedTuple, Tuple, Union
from uuid import uuid1

import torch

from torch_geometric.nn.conv.utils.jit import class_from_module_repr


class HeaderDesc(NamedTuple):
args: List[str]
output: List[str]


def Sequential(
input_args: str,
modules: List[Union[Tuple[Callable, str], Callable]],
Expand Down Expand Up @@ -110,6 +115,9 @@ def Sequential(

# Instantiate a class from the rendered module representation.
module = class_from_module_repr(cls_name, module_repr)()
module.module_headers = [
HeaderDesc(in_desc, out_desc) for _, _, in_desc, out_desc in calls
]
module._names = list(modules.keys())
for name, submodule, _, _ in calls:
setattr(module, name, submodule)
Expand Down

0 comments on commit 1e12d41

Please sign in to comment.