update
akihironitta committed Dec 27, 2024
1 parent ab2b458 commit 63156b1
Showing 1 changed file with 56 additions and 32 deletions.
examples/multi_gpu/distributed_batching.py: 88 changes (56 additions, 32 deletions)
@@ -1,36 +1,35 @@
import os
+import os.path as osp

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
-from ogb.graphproppred import Evaluator
-from ogb.graphproppred import PygGraphPropPredDataset as Dataset
+from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch.nn import BatchNorm1d as BatchNorm
from torch.nn import Linear, ReLU, Sequential
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
+from torch_sparse import SparseTensor

import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, global_mean_pool
from torch_geometric.typing import WITH_TORCH_SPARSE

if not WITH_TORCH_SPARSE:
quit("This example requires 'torch-sparse'")


class GIN(torch.nn.Module):
-    def __init__(self, hidden_channels, out_channels, num_layers=3,
-                 dropout=0.5):
+    def __init__(
+        self,
+        hidden_channels: int,
+        out_channels: int,
+        num_layers: int = 3,
+        dropout: float = 0.5,
+    ) -> None:
super().__init__()

self.dropout = dropout

self.atom_encoder = AtomEncoder(hidden_channels)
self.bond_encoder = BondEncoder(hidden_channels)

self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
nn = Sequential(
@@ -45,7 +44,12 @@ def __init__(self, hidden_channels, out_channels, num_layers=3,

self.lin = Linear(hidden_channels, out_channels)

-    def forward(self, x, adj_t, batch):
+    def forward(
+        self,
+        x: torch.Tensor,
+        adj_t: SparseTensor,
+        batch: torch.Tensor,
+    ) -> torch.Tensor:
x = self.atom_encoder(x)
edge_attr = adj_t.coo()[2]
adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo')
@@ -59,21 +63,29 @@ def forward(self, x, adj_t, batch):
return x


-def run(rank, world_size: int, dataset_name: str, root: str):
+def run(rank: int, world_size: int, dataset_name: str, root: str) -> None:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group('nccl', rank=rank, world_size=world_size)
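    # After init_process_group(), every spawned worker has joined one NCCL
    # process group (rendezvous via MASTER_ADDR/MASTER_PORT), so DDP gradient
    # syncs and the all_reduce call below can communicate across GPUs.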

-    dataset = Dataset(dataset_name, root,
-                      pre_transform=T.ToSparseTensor(attr='edge_attr'))
+    dataset = PygGraphPropPredDataset(
+        dataset_name,
+        root=root,
+        pre_transform=T.ToSparseTensor(attr='edge_attr'),
+    )
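    # The ToSparseTensor pre-transform stores each graph's connectivity as a
    # SparseTensor adjacency (data.adj_t) with the edge features as its
    # values, which is the layout GIN.forward() consumes above.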
split_idx = dataset.get_idx_split()
evaluator = Evaluator(dataset_name)

train_dataset = dataset[split_idx['train']]
-    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
-                                       rank=rank)
-    train_loader = DataLoader(train_dataset, batch_size=128,
-                              sampler=train_sampler)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=128,
+        sampler=DistributedSampler(
+            train_dataset,
+            shuffle=True,
+            drop_last=True,
+        ),
+    )
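    # DistributedSampler splits the training set into world_size disjoint
    # shards (one per rank); drop_last=True keeps the number of batches equal
    # across ranks so no process runs ahead of the others.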

torch.manual_seed(12345)
model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank)
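    # A fixed seed keeps parameter initialization reproducible and identical
    # on every rank.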
@@ -87,20 +99,22 @@ def run(rank, world_size: int, dataset_name: str, root: str):

for epoch in range(1, 51):
model.train()

-        total_loss = torch.zeros(2).to(rank)
+        train_loader.sampler.set_epoch(epoch)
+        total_loss = torch.zeros(2, device=rank)
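        # set_epoch() re-seeds the DistributedSampler so each epoch sees a
        # different shuffle; total_loss accumulates [loss sum, graph count]
        # on the GPU so it can be all-reduced across ranks after the loop.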
for data in train_loader:
data = data.to(rank)
-            optimizer.zero_grad()
logits = model(data.x, data.adj_t, data.batch)
loss = criterion(logits, data.y.to(torch.float))
loss.backward()
optimizer.step()
-            total_loss[0] += float(loss) * logits.size(0)
-            total_loss[1] += data.num_graphs
+            optimizer.zero_grad()

-        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
-        loss = float(total_loss[0] / total_loss[1])
+            with torch.no_grad():
+                total_loss[0] += loss * logits.size(0)
+                total_loss[1] += data.num_graphs

+        dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)
+        train_loss = total_loss[0] / total_loss[1]
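        # Averaging both accumulator entries across ranks leaves their ratio
        # unchanged, so train_loss is the loss averaged over every graph
        # processed by all ranks this epoch.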

if rank == 0: # We evaluate on a single GPU for now.
model.eval()
@@ -127,8 +141,10 @@ def run(rank, world_size: int, dataset_name: str, root: str):
'y_true': torch.cat(y_true, dim=0),
})['rocauc']

-            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
-                  f'Val: {val_rocauc:.4f}, Test: {test_rocauc:.4f}')
+            print(f'Epoch: {epoch:03d}, '
+                  f'Loss: {train_loss:.4f}, '
+                  f'Val: {val_rocauc:.4f}, '
+                  f'Test: {test_rocauc:.4f}')

dist.barrier()
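        # Every rank blocks here until rank 0 finishes evaluation, keeping all
        # processes aligned before the next epoch begins.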

@@ -137,11 +153,19 @@ def run(rank, world_size: int, dataset_name: str, root: str):

if __name__ == '__main__':
dataset_name = 'ogbg-molhiv'
-    root = '../../data/OGB'

+    root = osp.join(
+        osp.dirname(__file__),
+        '..',
+        '..',
+        'data',
+        'OGB',
+    )
# Download and process the dataset on main process.
-    Dataset(dataset_name, root,
-            pre_transform=T.ToSparseTensor(attr='edge_attr'))
+    PygGraphPropPredDataset(
+        dataset_name,
+        root,
+        pre_transform=T.ToSparseTensor(attr='edge_attr'),
+    )
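    # Triggering the download/processing once in the main process avoids a
    # race in which every spawned worker tries to write the dataset cache at
    # the same time.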

world_size = torch.cuda.device_count()
print('Let\'s use', world_size, 'GPUs!')
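The remainder of the file is collapsed in this view. As a rough, hypothetical sketch of how such multi-GPU examples are usually launched (the exact call in the file may differ), one worker process is spawned per visible GPU, with the rank passed to run() as its first argument:

    # Hypothetical launch sketch (not shown in the diff): one process per GPU.
    mp.spawn(
        run,
        args=(world_size, dataset_name, root),
        nprocs=world_size,
        join=True,
    )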
