Improve the accuracy to 63% for papers100M with GraphSAGE model (#9386)
In this PR, we add a script that reaches 63% test accuracy on the
ogbn-papers100M dataset with a GraphSAGE model.
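The new training loop relies on `NeighborLoader` placing the seed nodes first in every sampled mini-batch, which is why predictions and labels are sliced to the first `batch.batch_size` rows in the diff below. A minimal, self-contained sketch of that pattern (the toy graph and shapes are illustrative, not taken from this PR):

```python
import torch

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

# Toy graph: 6 nodes on a path, random 8-dim features, 3 classes.
data = Data(
    x=torch.randn(6, 8),
    edge_index=torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]),
    y=torch.randint(0, 3, (6, )),
)
loader = NeighborLoader(data, num_neighbors=[2, 2], batch_size=2,
                        input_nodes=torch.tensor([0, 1, 2, 3]))

batch = next(iter(loader))
# The first `batch.batch_size` nodes are the seeds; the remaining nodes are
# sampled neighbors that only provide message-passing context.
seed_x = batch.x[:batch.batch_size]
seed_y = batch.y[:batch.batch_size]
```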

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishi Puri <[email protected]>
Co-authored-by: rusty1s <[email protected]>
4 people authored Jun 17, 2024
1 parent ff85e57 commit 4d7c4ed
Showing 2 changed files with 90 additions and 98 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -33,6 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed
 
+- Improved model performance of the `examples/ogbn_papers_100m.py` script ([#9386](https://github.com/pyg-team/pytorch_geometric/pull/9386))
 - Added the `fmt` arg to `Dataset.get_summary` ([#9408](https://github.com/pyg-team/pytorch_geometric/pull/9408))
 - Skipped zero atom molecules in `MoleculeNet` ([#9318](https://github.com/pyg-team/pytorch_geometric/pull/9318))
 - Ensure proper parallelism in `OnDiskDataset` for multi-threaded `get` calls ([#9140](https://github.com/pyg-team/pytorch_geometric/pull/9140))
187 changes: 89 additions & 98 deletions examples/ogbn_papers_100m.py
@@ -1,135 +1,126 @@
-# Reaches around 0.7870 ± 0.0036 test accuracy.
-
 import argparse
-import os
-import time
-from typing import Optional
+import os.path as osp
 
 import torch
 import torch.nn.functional as F
 from ogb.nodeproppred import PygNodePropPredDataset
+from tqdm import tqdm
 
-import torch_geometric
 from torch_geometric.loader import NeighborLoader
+from torch_geometric.nn import SAGEConv
+from torch_geometric.utils import to_undirected
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--hidden_channels', type=int, default=256)
-parser.add_argument('--num_layers', type=int, default=2)
-parser.add_argument('--lr', type=float, default=0.001)
-parser.add_argument('--epochs', type=int, default=20)
+parser.add_argument('--device', type=str, default='cuda')
+parser.add_argument('--epochs', type=int, default=10)
+parser.add_argument('--num_layers', type=int, default=3)
 parser.add_argument('--batch_size', type=int, default=1024)
-parser.add_argument('--fan_out', type=int, default=30)
-parser.add_argument(
-    "--use_gat_conv",
-    action='store_true',
-    help="Whether or not to use GATConv. (Defaults to using GCNConv)",
-)
-parser.add_argument(
-    "--n_gat_conv_heads",
-    type=int,
-    default=4,
-    help="If using GATConv, number of attention heads to use",
-)
+parser.add_argument('--num_neighbors', type=int, default=10)
+parser.add_argument('--channels', type=int, default=256)
+parser.add_argument('--lr', type=float, default=0.003)
+parser.add_argument('--dropout', type=float, default=0.5)
+parser.add_argument('--workers', type=int, default=12)
 args = parser.parse_args()
-wall_clock_start = time.perf_counter()
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
-dataset = PygNodePropPredDataset(name='ogbn-papers100M',
-                                 root='/datasets/ogb_datasets')
+root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'papers100')
+dataset = PygNodePropPredDataset('ogbn-papers100M', root)
 split_idx = dataset.get_idx_split()
 
+data = dataset[0]
+data.edge_index = to_undirected(data.edge_index, reduce="mean")
 
-def get_num_workers() -> int:
-    try:
-        return len(os.sched_getaffinity(0)) // 2
-    except Exception:
-        return os.cpu_count() // 2
-
-
-kwargs = dict(
-    num_neighbors=[args.fan_out] * args.num_layers,
+train_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['train'],
+    num_neighbors=[args.num_neighbors] * args.num_layers,
     batch_size=args.batch_size,
+    shuffle=True,
+    num_workers=args.workers,
+    persistent_workers=args.workers > 0,
 )
+val_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['valid'],
+    num_neighbors=[args.num_neighbors] * args.num_layers,
+    batch_size=args.batch_size,
+    num_workers=args.workers,
+    persistent_workers=args.workers > 0,
+)
+test_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['test'],
+    num_neighbors=[args.num_neighbors] * args.num_layers,
+    batch_size=args.batch_size,
+    num_workers=args.workers,
+    persistent_workers=args.workers > 0,
+)
-# Set Up Neighbor Loading
-data = dataset[0]
-num_work = get_num_workers()
-train_loader = NeighborLoader(data=data, input_nodes=split_idx['train'],
-                              num_workers=num_work, drop_last=True,
-                              shuffle=False, **kwargs)
-val_loader = NeighborLoader(data=data, input_nodes=split_idx['valid'],
-                            num_workers=num_work, **kwargs)
-test_loader = NeighborLoader(data=data, input_nodes=split_idx['test'],
-                             num_workers=num_work, **kwargs)
 
-if args.use_gat_conv:
-    model = torch_geometric.nn.models.GAT(
-        dataset.num_features, args.hidden_channels, args.num_layers,
-        dataset.num_classes, heads=args.n_gat_conv_heads).to(device)
-else:
-    model = torch_geometric.nn.models.GCN(
-        dataset.num_features,
-        args.hidden_channels,
-        args.num_layers,
-        dataset.num_classes,
-    ).to(device)
-
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
-                             weight_decay=0.0005)
-
-warmup_steps = 20
+
+class SAGE(torch.nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        self.convs.append(SAGEConv(in_channels, args.channels))
+        for _ in range(args.num_layers - 2):
+            self.convs.append(SAGEConv(args.channels, args.channels))
+        self.convs.append(SAGEConv(args.channels, out_channels))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i != args.num_layers - 1:
+                x = x.relu()
+                x = F.dropout(x, p=args.dropout, training=self.training)
+        return x
+
+
+model = SAGE(dataset.num_features, dataset.num_classes).to(args.device)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
 
 
-def train():
+def train(epoch):
     model.train()
-    for i, batch in enumerate(train_loader):
-        if i == warmup_steps:
-            torch.cuda.synchronize()
-            start_avg_time = time.perf_counter()
-        batch = batch.to(device)
 
+    total_loss = total_correct = total_examples = 0
+    for batch in tqdm(train_loader):
+        batch = batch.to(args.device)
         optimizer.zero_grad()
-        batch_size = batch.num_sampled_nodes[0]
-        out = model(batch.x, batch.edge_index)[:batch_size]
-        y = batch.y[:batch_size].view(-1).to(torch.long)
+        out = model(batch.x, batch.edge_index)[:batch.batch_size]
+        y = batch.y[:batch.batch_size].view(-1).to(torch.long)
         loss = F.cross_entropy(out, y)
         loss.backward()
         optimizer.step()
 
-        if i % 10 == 0:
-            print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')
-    torch.cuda.synchronize()
-    print(f'Average Training Iteration Time (s/iter): \
-        {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}')
+        total_loss += float(loss) * y.size(0)
+        total_correct += int(out.argmax(dim=-1).eq(y).sum())
+        total_examples += y.size(0)
+
+    return total_loss / total_examples, total_correct / total_examples
 
 
 @torch.no_grad()
-def test(loader: NeighborLoader, val_steps: Optional[int] = None):
+def test(loader):
     model.eval()
 
     total_correct = total_examples = 0
-    for i, batch in enumerate(loader):
-        if val_steps is not None and i >= val_steps:
-            break
-        batch = batch.to(device)
-        batch_size = batch.num_sampled_nodes[0]
-        out = model(batch.x, batch.edge_index)[:batch_size]
-        pred = out.argmax(dim=-1)
-        y = batch.y[:batch_size].view(-1).to(torch.long)
-
-        total_correct += int((pred == y).sum())
+    for batch in tqdm(loader):
+        batch = batch.to(args.device)
+        out = model(batch.x, batch.edge_index)[:batch.batch_size]
+        y = batch.y[:batch.batch_size].view(-1).to(torch.long)
+
+        total_correct += int(out.argmax(dim=-1).eq(y).sum())
         total_examples += y.size(0)
 
     return total_correct / total_examples
 
 
-torch.cuda.synchronize()
-prep_time = round(time.perf_counter() - wall_clock_start, 2)
-print("Total time before training begins (prep_time)=", prep_time, "seconds")
-print("Beginning training...")
-for epoch in range(1, 1 + args.epochs):
-    train()
-    val_acc = test(val_loader, val_steps=100)
-    print(f'Val Acc: ~{val_acc:.4f}')
-
-test_acc = test(test_loader)
-print(f'Test Acc: {test_acc:.4f}')
-total_time = round(time.perf_counter() - wall_clock_start, 2)
-print("Total Program Runtime (total_time) =", total_time, "seconds")
-print("total_time - prep_time =", total_time - prep_time, "seconds")
+for epoch in range(1, args.epochs + 1):
+    loss, train_acc = train(epoch)
+    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}')
+    val_acc = test(val_loader)
+    print(f'Epoch {epoch:02d}, Val Acc: {val_acc:.4f}')
+    test_acc = test(test_loader)
+    print(f'Epoch {epoch:02d}, Test Acc: {test_acc:.4f}')
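
A note on the preprocessing step added above: `to_undirected(data.edge_index, reduce="mean")` symmetrizes the citation graph so that neighbor sampling can traverse citations in both directions. The `reduce` argument only matters when duplicate edges carry attributes; `ogbn-papers100M` stores a bare `edge_index`, so it has no effect here. A small illustration (toy tensor, not from this commit):

```python
import torch

from torch_geometric.utils import to_undirected

edge_index = torch.tensor([[0, 1], [1, 2]])  # directed edges 0->1 and 1->2
print(to_undirected(edge_index))
# tensor([[0, 1, 1, 2],
#         [1, 0, 2, 1]])
```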
