forked from GarrettChristian/spvnas
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
107 lines (90 loc) · 3.39 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import random
import sys
import numpy as np
import torch
import torch.backends.cudnn
import torch.cuda
import torch.nn
import torch.utils.data
from torchpack import distributed as dist
from torchpack.callbacks import InferenceRunner, MaxSaver, Saver
from torchpack.environ import auto_set_run_dir, set_run_dir
from torchpack.utils.config import configs
from torchpack.utils.logging import logger
from core import builder
from core.callbacks import MeanIoU
from core.trainers import SemanticKITTITrainer
def main() -> None:
    """Run a SemanticKITTI training job configured from the command line.

    The positional ``config`` argument names the config file to load; the
    optional ``--run-dir`` selects the run directory. Any other command-line
    arguments are passed to ``configs.update`` as overrides.

    The function seeds the Python, NumPy and PyTorch RNGs, builds one data
    loader per dataset split, moves the model to CUDA (wrapping it in
    ``DistributedDataParallel`` when ``configs.distributed`` is set), and then
    trains with a ``MeanIoU`` inference callback on the ``test`` split plus
    the ``MaxSaver('iou/test')`` and ``Saver`` callbacks.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('config', metavar='FILE', help='config file')
    arg_parser.add_argument('--run-dir', metavar='DIR', help='run directory')
    # Unrecognised arguments are kept and applied on top of the config file.
    args, overrides = arg_parser.parse_known_args()

    configs.load(args.config, recursive=True)
    configs.update(overrides)

    if configs.distributed:
        dist.init()

    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(dist.local_rank())

    if args.run_dir is not None:
        set_run_dir(args.run_dir)
    else:
        args.run_dir = auto_set_run_dir()

    logger.info(' '.join([sys.executable] + sys.argv))
    logger.info(f'Experiment started: "{args.run_dir}".' + '\n' + f'{configs}')

    # seed
    seed_unset = 'seed' not in configs.train or configs.train.seed is None
    if seed_unset:
        configs.train.seed = torch.initial_seed() % (2 ** 32 - 1)

    # Shift the base seed by a rank-dependent offset so each process differs.
    rank_offset = dist.rank() * configs.workers_per_gpu * configs.num_epochs
    seed = configs.train.seed + rank_offset
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed):
        seed_fn(seed)

    dataset = builder.make_dataset()
    dataflow = {}
    for split_name in dataset:
        split_data = dataset[split_name]
        # Only the training split is shuffled.
        split_sampler = torch.utils.data.distributed.DistributedSampler(
            split_data,
            num_replicas=dist.size(),
            rank=dist.rank(),
            shuffle=(split_name == 'train'),
        )
        dataflow[split_name] = torch.utils.data.DataLoader(
            split_data,
            batch_size=configs.batch_size,
            sampler=split_sampler,
            num_workers=configs.workers_per_gpu,
            pin_memory=True,
            collate_fn=split_data.collate_fn,
        )

    model = builder.make_model().cuda()
    if configs.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.local_rank()],
            find_unused_parameters=True,
        )

    criterion = builder.make_criterion()
    optimizer = builder.make_optimizer(model)
    scheduler = builder.make_scheduler(optimizer)

    trainer = SemanticKITTITrainer(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_workers=configs.workers_per_gpu,
        seed=seed,
        amp_enabled=configs.amp_enabled,
    )

    train_loader = dataflow['train']

    # One inference runner per evaluated split (currently only 'test'),
    # followed by the two saver callbacks.
    callbacks = []
    for eval_split in ['test']:
        iou_metric = MeanIoU(
            name=f'iou/{eval_split}',
            num_classes=configs.data.num_classes,
            ignore_label=configs.data.ignore_label,
        )
        callbacks.append(
            InferenceRunner(dataflow[eval_split], callbacks=[iou_metric]))
    callbacks.append(MaxSaver('iou/test'))
    callbacks.append(Saver())

    trainer.train_with_defaults(
        train_loader,
        num_epochs=configs.num_epochs,
        callbacks=callbacks,
    )
if __name__ == '__main__':
    # Start training only when this file is executed as a script.
    main()