Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement "Scaling Graph Neural Networks with Approximate PageRank" #9847

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))
- Allowed users to pass `edge_weight` to `GraphUNet` models ([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))
- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))
- Added `PPRGo` implementation and example ([#9847](https://github.com/pyg-team/pytorch_geometric/pull/9847))
- Allow top-k sparsification in `utils.get_ppr` and `transforms.GDC` ([#9847](https://github.com/pyg-team/pytorch_geometric/pull/9847))

### Changed

Expand Down
112 changes: 112 additions & 0 deletions examples/pprgo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Example usage of PPRGo (from the paper "Scaling Graph Neural Networks
# with Approximate PageRank). This trains on a small subset of the Reddit
# graph dataset and predicts on the rest.

import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

from torch_geometric.data import Data
from torch_geometric.datasets import Reddit
from torch_geometric.nn.models import PPRGo, pprgo_prune_features
from torch_geometric.transforms import GDC
from torch_geometric.utils import subgraph

# Parameters
parser = argparse.ArgumentParser()
parser.add_argument('--alpha', type=float, default=0.5)
parser.add_argument('--eps', type=float, default=1e-3)
parser.add_argument('--topk', type=int, default=32)
parser.add_argument('--n_train', type=int, default=1000)
parser.add_argument('--hidden_size', type=int, default=64)
parser.add_argument('--n_layers', type=int, default=2)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--n_epochs', type=int, default=50)
parser.add_argument('--n_power_iters', type=int, default=2)
parser.add_argument('--frac_predict', type=float, default=0.2)
args = parser.parse_args()

# Load Reddit dataset
s = time.time()
name = 'Reddit'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
data = Reddit(path)[0]
print(f'Loaded Reddit dataset in {round(time.time() - s, 2)}s')

# Get random split of train/test nodes (|train| << |test|)
s = time.time()
ind = torch.randperm(len(data.x), dtype=torch.long)
train_idx = ind[:args.n_train]
test_idx = ind[args.n_train:]

train_edge_index, _ = subgraph(train_idx, data.edge_index, relabel_nodes=True)
train = Data(x=data.x[train_idx], edge_index=train_edge_index,
y=data.y[train_idx])
print(f'Split data into {len(train_idx)} train and {len(test_idx)} test nodes '
f'in {round(time.time() - s, 2)}s')

# Set up ppr transform via gdc
s = time.time()
ppr = GDC(
exact=False, normalization_in='sym', normalization_out=None,
diffusion_kwargs=dict(
method='ppr',
alpha=args.alpha,
eps=args.eps,
), sparsification_kwargs=dict(method='topk', topk=args.topk))
train = ppr(train)
print(f'Ran PPR on {args.n_train} train nodes in {round(time.time() - s, 2)}s')

# Set up model and optimizer
model = PPRGo(num_features=data.x.shape[1], num_classes=data.y.max() + 1,
hidden_size=args.hidden_size, n_layers=args.n_layers,
dropout=args.dropout)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)

# Run training, single-batch full graph
s = time.time()
print()
print('=' * 40)
model.train()
train = pprgo_prune_features(train)
for i in range(1, args.n_epochs + 1):
optimizer.zero_grad()
pred = model(train.x, train.edge_index, train.edge_attr)

loss = F.cross_entropy(pred, train.y)
top1 = torch.argmax(pred, dim=1)
acc = (torch.sum(top1 == train.y) / len(train.y)).item()

loss.backward()
optimizer.step()
if i % 10 == 0:
print(f'Epoch {i}: loss = {round(loss.item(), 3)}, '
f'acc = {round(acc, 3)}')

print('=' * 40)
print(f'Finished training in {round(time.time() - s, 2)}s')

# Do inference on the test graph and report performance
s = time.time()
preds = model.predict_power_iter(data.x, data.edge_index, args.n_power_iters,
args.alpha, args.frac_predict)
preds = preds.argmax(1)
labels = data.y

acc_train = 100 * accuracy_score(labels[train_idx], preds[train_idx])
acc_test = 100 * accuracy_score(labels[test_idx], preds[test_idx])

f1_train = f1_score(labels[train_idx], preds[train_idx], average='macro')
f1_test = f1_score(labels[test_idx], preds[test_idx], average='macro')

print()
print('=' * 40)
print(f'Accuracy ... train: {acc_train:.1f}%, test: {acc_test:.1f}%')
print(f'F1 score ... train: {f1_train:.3f}, test: {f1_test:.3f}')
print('=' * 40)
print(f'Ran inference in {round(time.time() - s, 2)}s')
69 changes: 69 additions & 0 deletions test/nn/models/test_pprgo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest
import torch

from torch_geometric.datasets import KarateClub
from torch_geometric.nn.models.pprgo import PPRGo, pprgo_prune_features
from torch_geometric.testing import withPackage


@pytest.mark.parametrize('n_layers', [1, 4])
@pytest.mark.parametrize('dropout', [0.0, 0.2])
def test_pprgo_forward(n_layers, dropout):
num_nodes = 100
num_edges = 500
num_features = 64
num_classes = 16
hidden_size = 64

# Need to ensure edge_index is sorted + full so dim size checks are right
# edge_index should contain all num_node unique nodes (ie every node is
# connected to at least one destination, since we truncate by topk)
edge_index = torch.stack([
torch.arange(0, num_nodes).repeat(num_edges // num_nodes),
torch.randint(0, num_nodes, [num_edges])
], dim=0)

# Mimic the behavior of pprgo_prune_features manually
# i.e., we expect node_embeds to be |V| x d
node_embeds = torch.rand((num_nodes, num_features))
node_embeds = node_embeds[edge_index[1], :]

edge_weight = torch.rand(num_edges)

model = PPRGo(num_features, num_classes, hidden_size, n_layers, dropout)
pred = model(node_embeds, edge_index, edge_weight)
assert pred.size() == (num_nodes, num_classes)


def test_pprgo_karate():
data = KarateClub()[0]
num_nodes = data.num_nodes

data = pprgo_prune_features(data)
data.edge_weight = torch.ones((data.edge_index.shape[1], ))

assert data.x.shape[0] == data.edge_index.shape[1]
num_classes = 16
model = PPRGo(num_nodes, num_classes, hidden_size=64, n_layers=3,
dropout=0.0)
pred = model(data.x, data.edge_index, data.edge_weight)
assert pred.shape == (num_nodes, num_classes)


@pytest.mark.parametrize('n_power_iters', [1, 3])
@pytest.mark.parametrize('frac_predict', [1.0, 0.5])
@pytest.mark.parametrize('batch_size', [1, 9999])
@withPackage('torch_sparse')
def test_pprgo_inference(n_power_iters, frac_predict, batch_size):
data = KarateClub()[0]
num_nodes = data.num_nodes

data = pprgo_prune_features(data)
data.edge_weight = torch.rand(data.edge_index.shape[1])

num_classes = 16
model = PPRGo(num_nodes, num_classes, 64, 3, 0.0)
logits = model.predict_power_iter(data.x, data.edge_index, n_power_iters,
frac_predict, batch_size=batch_size)
assert torch.all(~torch.isnan(logits))
assert logits.shape == (data.x.shape[0], num_classes)
12 changes: 5 additions & 7 deletions test/utils/test_ppr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@

@withPackage('numba')
@pytest.mark.parametrize('target', [None, torch.tensor([0, 4, 5, 6])])
def test_get_ppr(target):
@pytest.mark.parametrize('topk', [None, 1, 3, 7])
def test_get_ppr(target, topk):
data = KarateClub()[0]

edge_index, edge_weight = get_ppr(
data.edge_index,
alpha=0.1,
eps=1e-5,
target=target,
)
edge_index, edge_weight = get_ppr(data.edge_index, alpha=0.1, eps=1e-5,
target=target, topk=topk)

assert edge_index.size(0) == 2
assert edge_index.size(1) == edge_weight.numel()
Expand All @@ -26,3 +23,4 @@ def test_get_ppr(target):
assert edge_index[0].min() == min_row and edge_index[0].max() == max_row
assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes
assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0
assert len(edge_weight) % (topk or 1) == 0
56 changes: 11 additions & 45 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,54 +32,20 @@
from .git_mol import GITMol
from .molecule_gpt import MoleculeGPT
from .glem import GLEM
from .pprgo import PPRGo, pprgo_prune_features
# Deprecated:
from torch_geometric.explain.algorithm.captum import (to_captum_input,
captum_output_to_dicts)

__all__ = classes = [
'MLP',
'GCN',
'GraphSAGE',
'GIN',
'GAT',
'PNA',
'EdgeCNN',
'JumpingKnowledge',
'HeteroJumpingKnowledge',
'MetaLayer',
'Node2Vec',
'DeepGraphInfomax',
'InnerProductDecoder',
'GAE',
'VGAE',
'ARGA',
'ARGVA',
'SignedGCN',
'RENet',
'GraphUNet',
'SchNet',
'DimeNet',
'DimeNetPlusPlus',
'to_captum_model',
'to_captum_input',
'captum_output_to_dicts',
'MetaPath2Vec',
'DeepGCNLayer',
'TGNMemory',
'LabelPropagation',
'CorrectAndSmooth',
'AttentiveFP',
'RECT_L',
'LINKX',
'LightGCN',
'MaskLabel',
'GroupAddRev',
'GNNFF',
'PMLP',
'NeuralFingerprint',
'ViSNet',
'GRetriever',
'GITMol',
'MoleculeGPT',
'GLEM',
'MLP', 'GCN', 'GraphSAGE', 'GIN', 'GAT', 'PNA', 'EdgeCNN',
'JumpingKnowledge', 'HeteroJumpingKnowledge', 'MetaLayer', 'Node2Vec',
'DeepGraphInfomax', 'InnerProductDecoder', 'GAE', 'VGAE', 'ARGA', 'ARGVA',
'SignedGCN', 'RENet', 'GraphUNet', 'SchNet', 'DimeNet', 'DimeNetPlusPlus',
'to_captum_model', 'to_captum_input', 'captum_output_to_dicts',
'MetaPath2Vec', 'DeepGCNLayer', 'TGNMemory', 'LabelPropagation',
'CorrectAndSmooth', 'AttentiveFP', 'RECT_L', 'LINKX', 'LightGCN',
'MaskLabel', 'GroupAddRev', 'GNNFF', 'PMLP', 'NeuralFingerprint', 'ViSNet',
'GRetriever', 'GITMol', 'MoleculeGPT', 'GLEM', 'PPRGo',
'pprgo_prune_features'
]
Loading
Loading