Commit b566219

update
akihironitta committed Dec 27, 2024
1 parent ab2b458 commit b566219
Showing 1 changed file with 41 additions and 26 deletions.
67 changes: 41 additions & 26 deletions examples/multi_gpu/distributed_sampling.py
@@ -1,5 +1,7 @@
 import os
+import os.path as osp
 from math import ceil
+from tqdm import tqdm
 
 import torch
 import torch.distributed as dist
@@ -14,10 +16,14 @@
 
 
 class SAGE(torch.nn.Module):
-    def __init__(self, in_channels: int, hidden_channels: int,
-                 out_channels: int, num_layers: int = 2):
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_channels: int,
+        out_channels: int,
+        num_layers: int = 2,
+    ) -> None:
         super().__init__()
-
         self.convs = torch.nn.ModuleList()
         self.convs.append(SAGEConv(in_channels, hidden_channels))
         for _ in range(num_layers - 2):
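
Note on the hunk above: with the default num_layers=2, the `for _ in range(num_layers - 2)` loop body never runs, so the stack is just the input SAGEConv plus the final conv appended below the fold of this hunk. A hypothetical instantiation, with channel sizes chosen purely for illustration:

    # Assumes the SAGE class defined above; sizes are illustrative.
    # For num_layers=2 the hidden-layer loop runs zero times,
    # giving a two-layer GraphSAGE: 500 -> 256 -> 64.
    model = SAGE(in_channels=500, hidden_channels=256, out_channels=64)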
@@ -34,20 +40,25 @@ def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
 
 
 @torch.no_grad()
-def test(loader, model, rank):
+def test(
+    loader: NeighborLoader,
+    model: DistributedDataParallel,
+    rank: int,
+) -> Tensor:
     model.eval()
-
-    total_correct = total_examples = 0
+    total_correct = torch.tensor(0, dtype=torch.long, device=rank)
+    total_examples = 0
     for i, batch in enumerate(loader):
         out = model(batch.x, batch.edge_index.to(rank))
         pred = out[:batch.batch_size].argmax(dim=-1)
         y = batch.y[:batch.batch_size].to(rank)
-        total_correct += int((pred == y).sum())
+        total_correct += (pred == y).sum()
         total_examples += batch.batch_size
-    return torch.tensor(total_correct / total_examples, device=rank)
+
+    return total_correct / total_examples
 
 
-def run(rank, world_size, dataset):
+def run(rank: int, world_size: int, dataset: Reddit) -> None:
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
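
Note on the hunk above: `test` now accumulates `total_correct` as a long tensor on the rank's GPU and returns the accuracy as a device tensor, instead of collapsing to a Python int each step and wrapping the final ratio in `torch.tensor(...)`. The payoff is that the returned value can be fed straight into `dist.all_reduce` (used further down) without an extra host-device copy. A minimal sketch of the same pattern in isolation; the function name and guard here are illustrative, not part of the example:

    import torch
    import torch.distributed as dist

    def averaged_metric(correct: torch.Tensor, total: int) -> torch.Tensor:
        # `correct` already lives on this rank's GPU, so NCCL can
        # reduce the ratio across ranks without a host round-trip.
        acc = correct / total
        if dist.is_initialized() and dist.get_world_size() > 1:
            dist.all_reduce(acc, op=dist.ReduceOp.AVG)  # NCCL-only op
        return acc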
@@ -94,43 +105,47 @@ def run(rank, world_size, dataset):
 
     for epoch in range(1, 21):
         model.train()
-        for batch in train_loader:
-            optimizer.zero_grad()
+        for batch in tqdm(
+            train_loader,
+            desc=f'Epoch {epoch:02d}',
+            disable=rank != 0,
+        ):
             out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size]
             loss = F.cross_entropy(out, batch.y[:batch.batch_size])
             loss.backward()
             optimizer.step()
-
-        dist.barrier()
+            optimizer.zero_grad()
 
         if rank == 0:
-            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')
+            print(f'Epoch {epoch:02d}: Train loss: {loss:.4f}')
 
         if epoch % 5 == 0:
             train_acc = test(train_loader, model, rank)
             val_acc = test(val_loader, model, rank)
             test_acc = test(test_loader, model, rank)
 
             if world_size > 1:
-                dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
-                dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
-                dist.all_reduce(train_acc, op=dist.ReduceOp.SUM)
-                train_acc /= world_size
-                val_acc /= world_size
-                test_acc /= world_size
+                dist.all_reduce(train_acc, op=dist.ReduceOp.AVG)
+                dist.all_reduce(val_acc, op=dist.ReduceOp.AVG)
+                dist.all_reduce(test_acc, op=dist.ReduceOp.AVG)
 
             if rank == 0:
-                print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '
-                      f'Test: {test_acc:.4f}')
-
-        dist.barrier()
+                print(f'Train acc: {train_acc:.4f}, '
+                      f'Val acc: {val_acc:.4f}, '
+                      f'Test acc: {test_acc:.4f}')
 
     dist.destroy_process_group()
 
 
 if __name__ == '__main__':
-    dataset = Reddit('../../data/Reddit')
-
+    path = osp.join(
+        osp.dirname(__file__),
+        '..',
+        '..',
+        'data',
+        'Reddit',
+    )
+    dataset = Reddit(path)
     world_size = torch.cuda.device_count()
     print("Let's use", world_size, "GPUs!")
     mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)
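
Two details of the final hunk worth calling out. First, the old reduction block all-reduced `train_acc` three times (apparently a copy-paste slip), so `val_acc` and `test_acc` were divided by `world_size` without ever being summed across ranks; the new code reduces each metric once with `ReduceOp.AVG` and drops the manual division. Second, `ReduceOp.AVG` is only available on the NCCL backend, which this script initializes; on backends without it, the same result comes from the portable pattern below (assuming `t` is a floating-point device tensor and the process group is initialized):

    # Portable equivalent of dist.all_reduce(t, op=dist.ReduceOp.AVG):
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    t /= dist.get_world_size()

The dataset path is also now resolved relative to the script via `osp.dirname(__file__)`, so the example no longer depends on the current working directory.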
