-
Notifications
You must be signed in to change notification settings - Fork 7
/
mnist_ddp_profiler.py
120 lines (95 loc) · 3.75 KB
/
mnist_ddp_profiler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# Based on multiprocessing example from
# https://yangkky.github.io/2019/07/08/distributed-pytorch-tutorial.html
from datetime import datetime
import argparse
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.profiler import profile, record_function, ProfilerActivity
class ConvNet(nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Each stage is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2x2, stride 2), so a stage keeps the spatial size through the
    convolution and halves it in the pooling step. The final linear layer
    takes 7*7*32 flattened features, which matches 28x28 inputs.

    Args:
        num_classes: number of output logits produced by the linear head.
    """

    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # Stages are built in order (layer1, layer2, fc) so parameter
        # initialisation consumes the RNG in the same sequence every time.
        self.layer1 = self._conv_stage(1, 16)
        self.layer2 = self._conv_stage(16, 32)
        self.fc = nn.Linear(7*7*32, num_classes)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one conv -> batch-norm -> ReLU -> max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for input x."""
        features = self.layer2(self.layer1(x))
        # Collapse everything except the batch dimension for the linear head.
        flattened = features.reshape(features.size(0), -1)
        return self.fc(flattened)
def train(num_epochs):
    """Train ConvNet on MNIST with DistributedDataParallel under the profiler.

    One process per GPU is expected. The process group is initialised with
    the NCCL backend and the GPU is selected from the ``LOCAL_RANK``
    environment variable, so the script must be started by a launcher that
    sets it. Profiler traces are written to ``./logs/profiler`` and progress
    is printed only on global rank 0.

    Args:
        num_epochs: number of passes over the training set.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not present in the environment.
    """
    dist.init_process_group(backend='nccl')
    try:
        torch.manual_seed(0)
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        verbose = dist.get_rank() == 0  # print only on global_rank==0

        prof = profile(
            # Skip 10 steps, then idle for 5, warm up for 1, record 3; once.
            schedule=torch.profiler.schedule(
                skip_first=10,
                wait=5,
                warmup=1,
                active=3,
                repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                './logs/profiler'),
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,   # record shapes of operator inputs
            profile_memory=True,  # track tensor memory allocation/deallocation
            with_stack=True,      # record source code information
            with_flops=True,      # estimate FLOPS of operators
        )

        model = ConvNet().cuda()
        batch_size = 100
        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.SGD(model.parameters(), 1e-4)
        model = DistributedDataParallel(model, device_ids=[local_rank])

        train_dataset = MNIST(root='./data', train=True,
                              transform=transforms.ToTensor(), download=True)
        train_sampler = DistributedSampler(train_dataset)
        # shuffle must stay False: ordering is delegated to the sampler.
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                                  shuffle=False, num_workers=0,
                                  pin_memory=True, sampler=train_sampler)

        start = datetime.now()
        # The context manager starts the profiler and guarantees it is
        # stopped even if a training step raises.
        with prof:
            for epoch in range(num_epochs):
                # Without this the sampler reuses the same shuffle order in
                # every epoch.
                train_sampler.set_epoch(epoch)
                tot_loss = 0
                num_batches = 0
                for num_batches, (images, labels) in enumerate(train_loader,
                                                               start=1):
                    images = images.cuda(non_blocking=True)
                    labels = labels.cuda(non_blocking=True)

                    outputs = model(images)
                    loss = criterion(outputs, labels)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Advance the profiler schedule once per training step.
                    prof.step()
                    tot_loss += loss.item()

                if verbose:
                    # max(..., 1) avoids dividing by zero on an empty loader.
                    print('Epoch [{}/{}], average loss: {:.4f}'.format(
                        epoch + 1,
                        num_epochs,
                        tot_loss / max(num_batches, 1)))

        if verbose:
            print("Training completed in: " + str(datetime.now() - start))
    finally:
        # Release the process group's resources on both success and failure.
        dist.destroy_process_group()
def main():
    """Read ``--epochs`` from the command line and run training."""
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--epochs',
        type=int,
        default=2,
        metavar='N',
        help='number of total epochs to run',
    )
    options = cli.parse_args()
    train(options.epochs)
# Run training only when executed as a script, not when imported.
if __name__ == "__main__":
    main()