-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
46 lines (33 loc) · 1.27 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
import caffe
from lib.vdbc.dataset_factory import VDBC
from tools.solverwrapper import SolverWrapper
# Project root; every other path below is built relative to it.
ROOT = '.'

# Dataset to train on. Frames and ground truth are read from the same
# directory, so the ground-truth path simply aliases the data path.
dbtype = 'VOT'
dbpath = gtpath = os.path.join(ROOT, 'data', dbtype)

# Model directory: holds the solver definition and the initial weights, and
# is also passed to SolverWrapper as its `output_dir` argument.
output_dir = os.path.join(ROOT, 'model')
solver_prototxt, pretrained_model = (
    os.path.join(output_dir, fname)
    for fname in ('solver.prototxt', 'vggm.caffemodel')
)

# Sequence names to drop from the training set, keyed by benchmark edition.
EXCLUDE_SET = {
    'vot2014': [
        'Basketball', 'Bolt', 'David', 'Diving',
        'MotorRolling', 'Skating1', 'Trellis', 'Woman',
    ],
}
def train_net(pretrained_model, snapshot_iters=None):
    """Build the training dataset and train the network with Caffe.

    Args:
        pretrained_model: Path to the weights file used to initialise the
            network; forwarded unchanged to ``SolverWrapper``.
        snapshot_iters: Iteration interval forwarded to
            ``SolverWrapper.train_model`` (presumably how often a snapshot
            is written). When ``None`` (the default) it is derived from the
            dataset size as ``64 * num_frame``. Previously any value passed
            here was silently overwritten by that derived value.
    """
    vdbc = VDBC(dbtype=dbtype, dbpath=dbpath, gtpath=gtpath, flush=True)
    # Remove the sequences listed for 'vot2014' in EXCLUDE_SET
    # (assumed to be dropped from training - confirm in VDBC.del_exclude).
    vdbc.del_exclude(EXCLUDE_SET['vot2014'])
    print('VDBC instance built.')

    num_frame = vdbc.get_frame_count()
    # Training length scales with the dataset: 64 * 4 iterations per frame.
    max_iters = 64 * 4 * num_frame
    # Only derive the interval when the caller did not supply one, so an
    # explicit argument is honoured instead of being discarded.
    if snapshot_iters is None:
        snapshot_iters = 64 * num_frame
    print('Total number of frames: {}'.format(num_frame))
    print('Max iterations: {}'.format(max_iters))

    sw = SolverWrapper(solver_prototxt, vdbc, output_dir, pretrained_model)
    print('Initialization of SolverWrapper finished.')
    sw.train_model(max_iters, snapshot_iters)
def main():
    """Entry point: switch Caffe to GPU mode, then train from the default weights.

    Uses the module-level ``pretrained_model`` path as the initial weights
    and prints a message before and after training.
    """
    caffe.set_mode_gpu()
    print('Model training begins.')
    train_net(pretrained_model=pretrained_model)
    print('Model training finished.')


# Only start training when the file is executed directly, not on import.
if __name__ == '__main__':
    main()