Improve the accuracy to 63% for papers100m with graphSage model (#9386)
In this PR, we add a script that reaches 63% test accuracy on the ogbn-papers100M dataset with a GraphSAGE model.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rishi Puri <[email protected]>
Co-authored-by: rusty1s <[email protected]>
1 parent ff85e57 · commit 4d7c4ed · 2 changed files with 90 additions and 98 deletions.
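The 63% figure is plain top-1 accuracy on the held-out OGB split, which the script in the diff below computes by hand from `argmax` predictions. As an aside (not part of this commit), OGB also ships an official `Evaluator` for `ogbn-papers100M` that reports the same metric; a minimal sketch, with toy labels standing in for real predictions:

```python
# Minimal sketch (not part of this commit): the official OGB evaluator
# computes the same top-1 accuracy that the script derives by hand.
import torch
from ogb.nodeproppred import Evaluator

evaluator = Evaluator(name='ogbn-papers100M')
y_true = torch.tensor([[0], [1], [2]])  # ground-truth labels, shape [N, 1]
y_pred = torch.tensor([[0], [1], [1]])  # predicted labels, shape [N, 1]
acc = evaluator.eval({'y_true': y_true, 'y_pred': y_pred})['acc']
print(f'{acc:.4f}')  # -> 0.6667
```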
@@ -1,135 +1,126 @@
-# Reaches around 0.7870 ± 0.0036 test accuracy.
 import argparse
-import os
-import time
-from typing import Optional
+import os.path as osp

 import torch
 import torch.nn.functional as F
 from ogb.nodeproppred import PygNodePropPredDataset
+from tqdm import tqdm

-import torch_geometric
 from torch_geometric.loader import NeighborLoader
+from torch_geometric.nn import SAGEConv
+from torch_geometric.utils import to_undirected

 parser = argparse.ArgumentParser()
-parser.add_argument('--hidden_channels', type=int, default=256)
-parser.add_argument('--num_layers', type=int, default=2)
-parser.add_argument('--lr', type=float, default=0.001)
-parser.add_argument('--epochs', type=int, default=20)
+parser.add_argument('--device', type=str, default='cuda')
+parser.add_argument('--epochs', type=int, default=10)
+parser.add_argument('--num_layers', type=int, default=3)
 parser.add_argument('--batch_size', type=int, default=1024)
-parser.add_argument('--fan_out', type=int, default=30)
-parser.add_argument(
-    "--use_gat_conv",
-    action='store_true',
-    help="Wether or not to use GATConv. (Defaults to using GCNConv)",
-)
-parser.add_argument(
-    "--n_gat_conv_heads",
-    type=int,
-    default=4,
-    help="If using GATConv, number of attention heads to use",
-)
+parser.add_argument('--num_neighbors', type=int, default=10)
+parser.add_argument('--channels', type=int, default=256)
+parser.add_argument('--lr', type=float, default=0.003)
+parser.add_argument('--dropout', type=float, default=0.5)
+parser.add_argument('--num_workers', type=int, default=12)
 args = parser.parse_args()
-wall_clock_start = time.perf_counter()
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

-dataset = PygNodePropPredDataset(name='ogbn-papers100M',
-                                 root='/datasets/ogb_datasets')
+root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'papers100')
+dataset = PygNodePropPredDataset('ogbn-papers100M', root)
 split_idx = dataset.get_idx_split()

+data = dataset[0]
+data.edge_index = to_undirected(data.edge_index, reduce="mean")

-def get_num_workers() -> int:
-    try:
-        return len(os.sched_getaffinity(0)) // 2
-    except Exception:
-        return os.cpu_count() // 2
-
-
-kwargs = dict(
-    num_neighbors=[args.fan_out] * args.num_layers,
+train_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['train'],
+    num_neighbors=[args.num_neighbors] * args.num_layers,
     batch_size=args.batch_size,
+    shuffle=True,
+    num_workers=args.num_workers,
+    persistent_workers=args.num_workers > 0,
 )
+val_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['valid'],
+    num_neighbors=[args.num_neighbors] * args.num_layers,
+    batch_size=args.batch_size,
+    num_workers=args.num_workers,
+    persistent_workers=args.num_workers > 0,
+)
+test_loader = NeighborLoader(
+    data,
+    input_nodes=split_idx['test'],
+    num_neighbors=[args.num_neighbors] * args.num_layers,
+    batch_size=args.batch_size,
+    num_workers=args.num_workers,
+    persistent_workers=args.num_workers > 0,
+)
-# Set Up Neighbor Loading
-data = dataset[0]
-num_work = get_num_workers()
-train_loader = NeighborLoader(data=data, input_nodes=split_idx['train'],
-                              num_workers=num_work, drop_last=True,
-                              shuffle=False, **kwargs)
-val_loader = NeighborLoader(data=data, input_nodes=split_idx['valid'],
-                            num_workers=num_work, **kwargs)
-test_loader = NeighborLoader(data=data, input_nodes=split_idx['test'],
-                             num_workers=num_work, **kwargs)

-if args.use_gat_conv:
-    model = torch_geometric.nn.models.GAT(
-        dataset.num_features, args.hidden_channels, args.num_layers,
-        dataset.num_classes, heads=args.n_gat_conv_heads).to(device)
-else:
-    model = torch_geometric.nn.models.GCN(
-        dataset.num_features,
-        args.hidden_channels,
-        args.num_layers,
-        dataset.num_classes,
-    ).to(device)
-
-optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
-                             weight_decay=0.0005)
-
-warmup_steps = 20
-
-
+class SAGE(torch.nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.convs = torch.nn.ModuleList()
+        self.convs.append(SAGEConv(in_channels, args.channels))
+        for _ in range(args.num_layers - 2):
+            self.convs.append(SAGEConv(args.channels, args.channels))
+        self.convs.append(SAGEConv(args.channels, out_channels))
+
+    def forward(self, x, edge_index):
+        for i, conv in enumerate(self.convs):
+            x = conv(x, edge_index)
+            if i != args.num_layers - 1:
+                x = x.relu()
+                x = F.dropout(x, p=args.dropout, training=self.training)
+        return x
+
+
+model = SAGE(dataset.num_features, dataset.num_classes).to(args.device)
+optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)


-def train():
+def train(epoch):
     model.train()
-    for i, batch in enumerate(train_loader):
-        if i == warmup_steps:
-            torch.cuda.synchronize()
-            start_avg_time = time.perf_counter()
-        batch = batch.to(device)

+    total_loss = total_correct = total_examples = 0
+    for batch in tqdm(train_loader):
+        batch = batch.to(args.device)
         optimizer.zero_grad()
-        batch_size = batch.num_sampled_nodes[0]
-        out = model(batch.x, batch.edge_index)[:batch_size]
-        y = batch.y[:batch_size].view(-1).to(torch.long)
+        out = model(batch.x, batch.edge_index)[:batch.batch_size]
+        y = batch.y[:batch.batch_size].view(-1).to(torch.long)
         loss = F.cross_entropy(out, y)
         loss.backward()
         optimizer.step()

-        if i % 10 == 0:
-            print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')
-    torch.cuda.synchronize()
-    print(f'Average Training Iteration Time (s/iter): \
-        {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}')
+        total_loss += float(loss) * y.size(0)
+        total_correct += int(out.argmax(dim=-1).eq(y).sum())
+        total_examples += y.size(0)
+
+    return total_loss / total_examples, total_correct / total_examples


 @torch.no_grad()
-def test(loader: NeighborLoader, val_steps: Optional[int] = None):
+def test(loader):
     model.eval()

     total_correct = total_examples = 0
-    for i, batch in enumerate(loader):
-        if val_steps is not None and i >= val_steps:
-            break
-        batch = batch.to(device)
-        batch_size = batch.num_sampled_nodes[0]
-        out = model(batch.x, batch.edge_index)[:batch_size]
-        pred = out.argmax(dim=-1)
-        y = batch.y[:batch_size].view(-1).to(torch.long)
-
-        total_correct += int((pred == y).sum())
+    for batch in tqdm(loader):
+        batch = batch.to(args.device)
+        out = model(batch.x, batch.edge_index)[:batch.batch_size]
+        y = batch.y[:batch.batch_size].view(-1).to(torch.long)
+
+        total_correct += int(out.argmax(dim=-1).eq(y).sum())
         total_examples += y.size(0)

     return total_correct / total_examples


-torch.cuda.synchronize()
-prep_time = round(time.perf_counter() - wall_clock_start, 2)
-print("Total time before training begins (prep_time)=", prep_time, "seconds")
-print("Beginning training...")
-for epoch in range(1, 1 + args.epochs):
-    train()
-    val_acc = test(val_loader, val_steps=100)
-    print(f'Val Acc: ~{val_acc:.4f}')
-
-test_acc = test(test_loader)
-print(f'Test Acc: {test_acc:.4f}')
-total_time = round(time.perf_counter() - wall_clock_start, 2)
-print("Total Program Runtime (total_time) =", total_time, "seconds")
-print("total_time - prep_time =", total_time - prep_time, "seconds")
+for epoch in range(1, args.epochs + 1):
+    loss, train_acc = train(epoch)
+    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}')
+    val_acc = test(val_loader)
+    print(f'Epoch {epoch:02d}, Val Acc: {val_acc:.4f}')
+    test_acc = test(test_loader)
+    print(f'Epoch {epoch:02d}, Test Acc: {test_acc:.4f}')
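One detail the rewritten train() and test() loops rely on: `NeighborLoader` orders each mini-batch so that the seed (input) nodes come first and stores their count in `batch.batch_size`, which is why slicing with `[:batch.batch_size]` keeps exactly the predictions for the seed nodes. A minimal sketch on a toy graph (hypothetical data, not from this commit):

```python
# Minimal sketch: seed nodes come first in every NeighborLoader mini-batch,
# and `batch.batch_size` counts them, which is what `out[:batch.batch_size]`
# relies on above.
import torch
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

data = Data(
    x=torch.randn(6, 4),  # 6 nodes with 4 features each
    edge_index=torch.tensor([[0, 1, 2, 3, 4],
                             [1, 2, 3, 4, 5]]),
)
loader = NeighborLoader(data, num_neighbors=[2], batch_size=2,
                        input_nodes=torch.arange(6))
batch = next(iter(loader))
print(batch.batch_size)               # 2: number of seed nodes in this batch
print(batch.n_id[:batch.batch_size])  # global ids of the seeds, listed first
```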