diff --git a/examples/multi_gpu/distributed_batching.py b/examples/multi_gpu/distributed_batching.py
index f5c05a176823..b242499d3e76 100644
--- a/examples/multi_gpu/distributed_batching.py
+++ b/examples/multi_gpu/distributed_batching.py
@@ -1,36 +1,37 @@
 import os
+import os.path as osp
 
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn.functional as F
-from ogb.graphproppred import Evaluator
-from ogb.graphproppred import PygGraphPropPredDataset as Dataset
+from ogb.graphproppred import Evaluator, PygGraphPropPredDataset
 from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
 from torch.nn import BatchNorm1d as BatchNorm
 from torch.nn import Linear, ReLU, Sequential
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data.distributed import DistributedSampler
+from torch_sparse import SparseTensor
 
 import torch_geometric.transforms as T
 from torch_geometric.loader import DataLoader
 from torch_geometric.nn import GINEConv, global_mean_pool
-from torch_geometric.typing import WITH_TORCH_SPARSE
-
-if not WITH_TORCH_SPARSE:
-    quit("This example requires 'torch-sparse'")
 
 
 class GIN(torch.nn.Module):
-    def __init__(self, hidden_channels, out_channels, num_layers=3,
-                 dropout=0.5):
+    def __init__(
+        self,
+        hidden_channels: int,
+        out_channels: int,
+        num_layers: int = 3,
+        dropout: float = 0.5,
+    ) -> None:
         super().__init__()
 
         self.dropout = dropout
 
         self.atom_encoder = AtomEncoder(hidden_channels)
         self.bond_encoder = BondEncoder(hidden_channels)
-
         self.convs = torch.nn.ModuleList()
         for _ in range(num_layers):
             nn = Sequential(
@@ -45,7 +46,12 @@ def __init__(self, hidden_channels, out_channels, num_layers=3,
         self.lin = Linear(hidden_channels, out_channels)
 
-    def forward(self, x, adj_t, batch):
+    def forward(
+        self,
+        x: torch.Tensor,
+        adj_t: SparseTensor,
+        batch: torch.Tensor,
+    ) -> torch.Tensor:
         x = self.atom_encoder(x)
 
         edge_attr = adj_t.coo()[2]
         adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo')
@@ -59,21 +65,29 @@ def forward(self, x, adj_t, batch):
         return x
 
 
-def run(rank, world_size: int, dataset_name: str, root: str):
+def run(rank: int, world_size: int, dataset_name: str, root: str) -> None:
     os.environ['MASTER_ADDR'] = 'localhost'
     os.environ['MASTER_PORT'] = '12355'
     dist.init_process_group('nccl', rank=rank, world_size=world_size)
 
-    dataset = Dataset(dataset_name, root,
-                      pre_transform=T.ToSparseTensor(attr='edge_attr'))
+    dataset = PygGraphPropPredDataset(
+        dataset_name,
+        root=root,
+        pre_transform=T.ToSparseTensor(attr='edge_attr'),
+    )
     split_idx = dataset.get_idx_split()
     evaluator = Evaluator(dataset_name)
 
     train_dataset = dataset[split_idx['train']]
-    train_sampler = DistributedSampler(train_dataset, num_replicas=world_size,
-                                       rank=rank)
-    train_loader = DataLoader(train_dataset, batch_size=128,
-                              sampler=train_sampler)
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=128,
+        sampler=DistributedSampler(
+            train_dataset,
+            shuffle=True,
+            drop_last=True,
+        ),
+    )
 
     torch.manual_seed(12345)
     model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank)
@@ -87,20 +101,22 @@ def run(rank, world_size: int, dataset_name: str, root: str):
     for epoch in range(1, 51):
         model.train()
-
-        total_loss = torch.zeros(2).to(rank)
+        train_loader.sampler.set_epoch(epoch)
+        total_loss = torch.zeros(2, device=rank)
         for data in train_loader:
             data = data.to(rank)
-            optimizer.zero_grad()
             logits = model(data.x, data.adj_t, data.batch)
             loss = criterion(logits, data.y.to(torch.float))
             loss.backward()
             optimizer.step()
-            total_loss[0] += float(loss) * logits.size(0)
-            total_loss[1] += data.num_graphs
+            optimizer.zero_grad()
 
-        dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
-        loss = float(total_loss[0] / total_loss[1])
+            with torch.no_grad():
+                total_loss[0] += loss * logits.size(0)
+                total_loss[1] += data.num_graphs
+
+        dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)
+        train_loss = total_loss[0] / total_loss[1]
 
         if rank == 0:  # We evaluate on a single GPU for now.
             model.eval()
 
@@ -127,8 +143,10 @@ def run(rank, world_size: int, dataset_name: str, root: str):
                 'y_true': torch.cat(y_true, dim=0),
             })['rocauc']
 
-            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '
-                  f'Val: {val_rocauc:.4f}, Test: {test_rocauc:.4f}')
+            print(f'Epoch: {epoch:03d}, '
+                  f'Loss: {train_loss:.4f}, '
+                  f'Val: {val_rocauc:.4f}, '
+                  f'Test: {test_rocauc:.4f}')
 
         dist.barrier()
 
@@ -137,11 +155,19 @@ def run(rank, world_size: int, dataset_name: str, root: str):
 
 if __name__ == '__main__':
     dataset_name = 'ogbg-molhiv'
-    root = '../../data/OGB'
-
+    root = osp.join(
+        osp.dirname(__file__),
+        '..',
+        '..',
+        'data',
+        'OGB',
+    )
     # Download and process the dataset on main process.
-    Dataset(dataset_name, root,
-            pre_transform=T.ToSparseTensor(attr='edge_attr'))
+    PygGraphPropPredDataset(
+        dataset_name,
+        root,
+        pre_transform=T.ToSparseTensor(attr='edge_attr'),
+    )
 
     world_size = torch.cuda.device_count()
     print('Let\'s use', world_size, 'GPUs!')
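
Note on the switch from `dist.ReduceOp.SUM` to `dist.ReduceOp.AVG`: the reported epoch loss is unchanged, because `AVG` divides both the weighted loss sum (`total_loss[0]`) and the graph count (`total_loss[1]`) by the same world-size factor, which cancels in the ratio. A standalone sanity check, using made-up per-rank totals (no GPUs required):

```python
# Each rank holds [sum of (loss * batch_size), number of graphs].
# The values below are made up for a hypothetical 2-GPU run.
rank_totals = [(12.5, 128.0), (11.0, 128.0)]
world_size = len(rank_totals)

# ReduceOp.AVG divides both entries by world_size ...
avg = [sum(t[i] for t in rank_totals) / world_size for i in (0, 1)]
# ... while ReduceOp.SUM leaves them as global sums.
tot = [sum(t[i] for t in rank_totals) for i in (0, 1)]

# The world-size factor cancels, so the epoch loss is identical either way.
assert abs(avg[0] / avg[1] - tot[0] / tot[1]) < 1e-12
```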
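The added `train_loader.sampler.set_epoch(epoch)` call is what makes `DistributedSampler(shuffle=True)` reshuffle: the sampler seeds its permutation with the stored epoch number, so without the call every epoch would replay the epoch-0 ordering on all ranks.

The diff ends before the script's entry point, which sits below the last hunk and is not shown. Assuming it is unchanged, launching presumably still follows the usual `mp.spawn` pattern for the `run()` signature above; a hypothetical sketch:

```python
# Hypothetical launch block (not part of this diff): spawn one process
# per GPU; mp.spawn passes the process rank as the first argument to run().
mp.spawn(run, args=(world_size, dataset_name, root), nprocs=world_size,
         join=True)
```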