-
Notifications
You must be signed in to change notification settings - Fork 76
/
train.py
125 lines (104 loc) · 4.05 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import importlib
import os
import pickle
import shutil
import sys
import time
from os.path import dirname

import torch
import torch.backends.cudnn as cudnn
import tqdm

cudnn.benchmark = True
cudnn.enabled = True
def parse_command_line():
    """Build the CLI parser and return the parsed argument namespace."""
    parser = argparse.ArgumentParser()
    # Register each option from a declarative table; flags/defaults/help
    # strings are identical to the original hand-written calls.
    option_table = (
        (('-t', '--task'), dict(type=str, default='pose', help='task to be trained')),
        (('-c', '--continue_exp'), dict(type=str, help='continue exp')),
        (('-e', '--exp'), dict(type=str, default='pose', help='experiments name')),
        (('-m', '--mode'), dict(type=str, default='single', help='scale mode')),
    )
    for flags, kwargs in option_table:
        parser.add_argument(*flags, **kwargs)
    return parser.parse_args()
def reload(config):
    """
    Load or initialize the model's parameters from config['opt'].continue_exp.

    config['train']['epoch'] records the epoch num
    config['inference']['net'] is the model
    config['train']['optimizer'] is the optimizer whose state is restored

    Exits the process with a non-zero status when a resume experiment was
    requested but its checkpoint file does not exist.
    """
    opt = config['opt']
    if opt.continue_exp:
        resume = os.path.join('exp', opt.continue_exp)
        resume_file = os.path.join(resume, 'checkpoint.pth.tar')
        if os.path.isfile(resume_file):
            print("=> loading checkpoint '{}'".format(resume))
            checkpoint = torch.load(resume_file)
            config['inference']['net'].load_state_dict(checkpoint['state_dict'])
            config['train']['optimizer'].load_state_dict(checkpoint['optimizer'])
            config['train']['epoch'] = checkpoint['epoch']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(resume))
            # Bug fix: the original called exit(0), reporting *success* to the
            # shell even though resuming failed; exit non-zero instead.
            sys.exit(1)

    # Fresh run (nothing resumed): start counting epochs from zero.
    if 'epoch' not in config['train']:
        config['train']['epoch'] = 0
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """
    Persist a training checkpoint (adapted from pytorch/examples).

    state: picklable dict handed to torch.save (net/optimizer state, epoch).
    is_best: when True, also copy the checkpoint to 'model_best.pth.tar'.
    filename: destination path; parent directories are created on demand.
    """
    basename = dirname(filename)
    # Bug fix: for a bare filename (the default) dirname() returns '', and
    # os.makedirs('') raises FileNotFoundError. exist_ok also removes the
    # check-then-create race the original exists()/makedirs() pair had.
    if basename:
        os.makedirs(basename, exist_ok=True)
    torch.save(state, filename)
    if is_best:
        # NOTE(review): the best model lands in the current working directory,
        # not next to `filename` — confirm this is intended.
        shutil.copyfile(filename, 'model_best.pth.tar')
def save(config):
    """Write the current net/optimizer/epoch state into this experiment's directory."""
    experiment_dir = os.path.join('exp', config['opt'].exp)
    checkpoint_path = os.path.join(experiment_dir, 'checkpoint.pth.tar')
    state = {
        'state_dict': config['inference']['net'].state_dict(),
        'optimizer': config['train']['optimizer'].state_dict(),
        'epoch': config['train']['epoch'],
    }
    save_checkpoint(state, False, filename=checkpoint_path)
    print('=> save checkpoint')
def train(train_func, data_func, config, post_epoch=None):
    """
    Main optimization loop: alternate a 'train' and a 'valid' phase each
    epoch, checkpointing after every epoch, until the epoch counter passes
    config['train']['epoch_num'] (or forever if that key is absent).

    train_func: callable(step_id, config, phase, **batch) — does one step.
    data_func: callable(phase) returning a batch generator for that phase.
    config: shared dict; config['train']['epoch'] is advanced and saved here.
    post_epoch: unused in this loop; accepted for API compatibility.
    """
    while True:
        print('epoch: ', config['train']['epoch'])
        if 'epoch_num' in config['train']:
            # NOTE(review): '>' lets epoch == epoch_num still run, i.e.
            # epoch_num + 1 epochs in total — confirm this is intended.
            if config['train']['epoch'] > config['train']['epoch_num']:
                break
        for phase in ['train', 'valid']:
            # '{phase}_iters' sets how many batches this phase consumes.
            num_step = config['train']['{}_iters'.format(phase)]
            generator = data_func(phase)
            print('start', phase, config['opt'].exp)
            show_range = range(num_step)
            show_range = tqdm.tqdm(show_range, total = num_step, ascii=True)
            # Global step id: continues monotonically across epochs.
            batch_id = num_step * config['train']['epoch']
            for i in show_range:
                datas = next(generator)
                outs = train_func(batch_id + i, config, phase, **datas)
        config['train']['epoch'] += 1
        # Persist net/optimizer/epoch after each full epoch.
        save(config)
def init():
    """
    Load the task module named on the command line and assemble its config.

    task.__config__ contains the variables that control the training and testing
    make_network builds a function which can do forward and backward propagation
    please check task/base.py
    """
    options = parse_command_line()
    task_module = importlib.import_module('task.' + options.task)

    # Make sure the experiment directory exists before anything is written.
    experiment_dir = os.path.join('exp', options.exp)
    try:
        os.makedirs(experiment_dir)
    except FileExistsError:
        pass

    configuration = task_module.__config__
    configuration['opt'] = options
    configuration['data_provider'] = importlib.import_module(configuration['data_provider'])

    network_func = task_module.make_network(configuration)
    reload(configuration)
    return network_func, configuration
def main():
    """Entry point: build the network, wire up the data provider, and train."""
    train_func, config = init()
    batch_provider = config['data_provider'].init(config)
    train(train_func, batch_provider, config)


if __name__ == '__main__':
    main()