train_model.py (forked from csehong/SSPP-DAN)
import os

import numpy as np
import tensorflow as tf

from DAN import Dom_Adapt_Net as Network
from util.Logger import Logger

# Set flags for the experiment
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_float('learning_rate', 1e-5, 'Initial learning rate.')
flags.DEFINE_float('keep_prob', 0.5, 'Dropout keep probability.')
flags.DEFINE_integer('max_steps', 10000, 'Maximum number of training steps.')
flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
flags.DEFINE_integer('test_batch_size', 64, 'Test batch size.')
flags.DEFINE_integer('display_step', 25, 'Interval (in steps) between training log entries.')
flags.DEFINE_integer('test_step', 16, 'Interval (in steps) between test evaluations.')
flags.DEFINE_integer('save_step', 50, 'Interval (in steps) between checkpoint saves.')
flags.DEFINE_string('summaries_dir', 'expr/0.5_1e-5_FC6_FC6', 'Directory for summaries, logs, and checkpoints.')
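
# Flags can be overridden on the command line when the script is launched via
# tf.app.run(); the values below are illustrative, not recommendations:
#   python train_model.py --learning_rate 1e-4 --batch_size 32 --summaries_dir expr/my_run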


def main(_):
    # Build the domain adaptation network
    net_opts = Network.OPTS()
    net_opts.network_name = 'dom_adapt_net'
    net_opts.weight_path = 'pretrained/vgg-face.mat'  # download: http://www.vlfeat.org/matconvnet/models/vgg-face.mat
    net_opts.num_class = 30
    net = Network(net_opts)
    net.construct()
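
    # construct() is expected to expose the placeholders and tensors used
    # below (net.x, net.y_, net.d_, net.with_class_idx, net.keep_prob, net.l,
    # the loss/accuracy ops, and net.vgg_net); see the DAN module for details.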

    # Set up the dataset manager
    from data.data_manager import Manager as DataManager
    dataset = DataManager('./data', net_opts.num_class)

    # Set up the optimizer: fine-tune only fc6 and the domain/class heads of VGG-Face
    with tf.variable_scope('optimizer'):
        net.trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        net.trainable_var_names = [v.name for v in net.trainable_vars]
        to_select_names = ('fc6', 'dom', 'class')
        net.sel_vars = []
        for name, var in zip(net.trainable_var_names, net.trainable_vars):
            if name.startswith(to_select_names):
                net.sel_vars.append(var)
        net.adam = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate).minimize(net.loss, var_list=net.sel_vars)
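
        # Everything below fc6 (the convolutional layers) keeps its pretrained
        # VGG-Face weights; only the selected variables receive gradient
        # updates. A quick optional sanity check of the selection:
        #   for v in net.sel_vars:
        #       print(v.name, v.get_shape())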

    # Start the session
    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        # Load the pretrained model (VGG-Face)
        sess.run(tf.global_variables_initializer())
        net.vgg_net.load_pretrained(sess)

        # Set up the summary writer, logger, and checkpoint folder
        train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train', sess.graph)
        logger = Logger(FLAGS.summaries_dir)
        logger.write(str(FLAGS.__flags))
        checkpoint_dir = os.path.join(FLAGS.summaries_dir, 'checkpoints')
        checkpoint_prefix = os.path.join(checkpoint_dir, 'model.ckpt')
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        # Restore the latest checkpoint, if one exists
        step = 0
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
            step = int(ckpt.model_checkpoint_path.split('-')[-1])
            print('Session restored successfully. step: {0}'.format(step))
            step += 1
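
        # Checkpoint paths have the form <checkpoint_prefix>-<global_step>
        # (tf.train.Saver appends the step passed to saver.save below), so
        # split('-')[-1] above recovers the step to resume from.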

        # Start mini-batch generator threads
        train_batch = dataset.batch_generator_thread(FLAGS.batch_size, 'train')
        test_batch = dataset.batch_generator_thread(FLAGS.test_batch_size, 'test')
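
        # Each generator appears to yield (images, identity labels, indices of
        # the samples that carry identity labels, domain labels); the exact
        # semantics of `idx` are defined in data/data_manager.py.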

        # Run the training loop
        for i in range(step, FLAGS.max_steps):
            # Anneal the gradient-reversal coefficient lambda from 0 to 1 over
            # training, following the DANN schedule of Ganin & Lempitsky (2015).
            p = float(i) / FLAGS.max_steps
            lamb = 2. / (1. + np.exp(-10. * p)) - 1.
            x_batch, y_batch, idx, dom_label = next(train_batch)
            sess.run(net.adam,
                     feed_dict={net.x: x_batch, net.y_: y_batch, net.d_: dom_label,
                                net.with_class_idx: idx, net.keep_prob: FLAGS.keep_prob, net.l: lamb})
            if (i + 1) % FLAGS.display_step == 0:
                # Re-evaluate the current batch without dropout (keep_prob=1) for logging
                loss, d_loss, c_loss, d_acc, c_acc = sess.run(
                    [net.loss, net.dom_loss, net.class_loss, net.dom_accuracy, net.class_accuracy],
                    feed_dict={net.x: x_batch, net.y_: y_batch, net.d_: dom_label,
                               net.with_class_idx: idx, net.keep_prob: 1., net.l: lamb})
                logger.write("[iter %d] costs(a,d,c)=(%4.4g,%4.4g,%4.4g) dom_acc: %.6f, class_acc: %.6f"
                             % (i + 1, loss, d_loss, c_loss, d_acc, c_acc))
                short_summary = tf.Summary(value=[
                    tf.Summary.Value(tag="loss/loss", simple_value=float(loss)),
                    tf.Summary.Value(tag="loss/dom", simple_value=float(d_loss)),
                    tf.Summary.Value(tag="loss/cat", simple_value=float(c_loss)),
                    tf.Summary.Value(tag="acc/dom", simple_value=float(d_acc)),
                    tf.Summary.Value(tag="acc/cat", simple_value=float(c_acc)),
                    tf.Summary.Value(tag="lambda", simple_value=float(lamb)),
                ])
                train_writer.add_summary(short_summary, i)
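
            # Note: scalars are written as hand-built tf.Summary protos rather
            # than graph-side tf.summary ops, which avoids a merge_all() fetch
            # and makes Python-side values such as lamb easy to record.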
            if (i + 1) % FLAGS.test_step == 0:
                x_batch, y_batch, idx, dom_label = next(test_batch)
                loss, d_loss, c_loss, d_acc, c_acc = sess.run(
                    [net.loss, net.dom_loss, net.class_loss, net.dom_accuracy, net.class_accuracy],
                    feed_dict={net.x: x_batch, net.y_: y_batch, net.d_: dom_label,
                               net.with_class_idx: idx, net.keep_prob: 1., net.l: lamb})
                logger.write("[Test iter %d] costs(a,d,c)=(%4.4g,%4.4g,%4.4g) dom_acc: %.6f, class_acc: %.6f"
                             % (i + 1, loss, d_loss, c_loss, d_acc, c_acc))
            if (i + 1) % FLAGS.save_step == 0:
                saver.save(sess, checkpoint_prefix, global_step=i + 1)


if __name__ == '__main__':
    tf.app.run()