-
Notifications
You must be signed in to change notification settings - Fork 1
/
training.py
157 lines (137 loc) · 5.92 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Training functions.
Author: Herman Kamper
Date: 2016, 2018
"""
from __future__ import division
from __future__ import print_function
from datetime import datetime
import numpy as np
import sys
import tensorflow as tf
import timeit
def train_fixed_epochs(n_epochs, optimizer, train_loss_tensor,
        train_feed_iterator, feed_placeholders, validation_loss_tensor=None,
        validation_feed_iterator=None, load_model_fn=None, save_model_fn=None,
        save_best_val_model_fn=None, config=None, epoch_offset=0):
    """
    Train a model for a fixed number of epochs.

    Parameters
    ----------
    n_epochs : int
        Number of epochs to train for.
    optimizer : Operation
        The TensorFlow training op that updates the model parameters.
    train_loss_tensor : Tensor or list of Tensor
        The function that is optimized; should match the feed specified
        through `train_feed_iterator` and `feed_placeholders`. If a list of
        Tensors is given, the training loss is output as an array.
    train_feed_iterator : iterable
        Generates the values for the `feed_placeholders` for each training
        batch. NOTE(review): this is iterated once per epoch, so it should be
        a re-iterable object rather than a one-shot generator — confirm
        against callers.
    feed_placeholders : list of placeholder
        The placeholders that are required for the `train_loss_tensor` (and
        optionally `validation_loss_tensor`) feeds.
    validation_loss_tensor : Tensor or list of Tensor, optional
        If provided, evaluated on `validation_feed_iterator` after each
        training epoch.
    validation_feed_iterator : iterable, optional
        Generates the values for the `feed_placeholders` for each validation
        batch.
    load_model_fn : str, optional
        If provided, initialize session from this file.
    save_model_fn : str, optional
        If provided, save final session to this file.
    save_best_val_model_fn : str, optional
        If provided, save the best validation session to this file. If the
        `validation_loss_tensor` is a list of Tensors, the last value is
        taken as the current validation loss.
    config : ConfigProto, optional
        Passed through to `tf.Session`.
    epoch_offset : int
        Added to the epoch number in the printed log (useful when resuming);
        the epochs recorded in `record_dict` are not offset.

    Return
    ------
    record_dict : dict
        Statistics tracked during training. Each key describes the statistic,
        while the value is a list of (epoch, value) tuples.
    """
    assert save_best_val_model_fn is None or validation_loss_tensor is not None

    # Statistics
    record_dict = {}
    record_dict["epoch_time"] = []
    record_dict["train_loss"] = []
    if validation_loss_tensor is not None:
        record_dict["validation_loss"] = []
        best_validation_loss = np.inf

    print(datetime.now())

    def feed_dict(vals):
        return {key: val for key, val in zip(feed_placeholders, vals)}

    # Launch the graph
    saver = tf.train.Saver()
    if load_model_fn is None:
        init = tf.global_variables_initializer()
    with tf.Session(config=config) as session:

        # Start or restore session
        if load_model_fn is None:
            session.run(init)
        else:
            saver.restore(session, load_model_fn)

        # Train
        # Bug fix: `xrange` is Python 2-only; `range` works on both given the
        # file's `__future__` imports.
        for i_epoch in range(n_epochs):
            # Bug fix: the original used the Python 2 trailing-comma idiom,
            # which under `print_function` just built a throwaway tuple;
            # `end=" "` restores the intended same-line log output.
            print("Epoch {}:".format(epoch_offset + i_epoch), end=" ")
            start_time = timeit.default_timer()

            # Train model
            train_losses = []
            if not isinstance(train_loss_tensor, (list, tuple)):
                for cur_feed in train_feed_iterator:
                    _, cur_loss = session.run(
                        [optimizer, train_loss_tensor],
                        feed_dict=feed_dict(cur_feed)
                        )
                    train_losses.append(cur_loss)
                train_loss = np.mean(train_losses)
            else:
                for cur_feed in train_feed_iterator:
                    cur_loss = session.run(
                        [optimizer] + train_loss_tensor,
                        feed_dict=feed_dict(cur_feed)
                        )
                    cur_loss.pop(0)  # remove the optimizer
                    cur_loss = np.array(cur_loss)
                    train_losses.append(cur_loss)
                train_loss = np.mean(train_losses, axis=0)
            record_dict["train_loss"].append((i_epoch, train_loss))

            # Validate model
            if validation_loss_tensor is not None:
                validation_losses = []
                if not isinstance(validation_loss_tensor, (list, tuple)):
                    for cur_feed in validation_feed_iterator:
                        cur_loss = session.run(
                            [validation_loss_tensor],
                            feed_dict=feed_dict(cur_feed)
                            )
                        validation_losses.append(cur_loss)
                    validation_loss = np.mean(validation_losses)
                    cur_validation_loss = validation_loss
                else:
                    for cur_feed in validation_feed_iterator:
                        cur_loss = session.run(
                            validation_loss_tensor,
                            feed_dict=feed_dict(cur_feed)
                            )
                        cur_loss = np.array(cur_loss)
                        validation_losses.append(cur_loss)
                    validation_loss = np.mean(validation_losses, axis=0)
                    # With a list of tensors, the last one is the model
                    # selection criterion (see docstring).
                    cur_validation_loss = validation_loss[-1]
                record_dict["validation_loss"].append(
                    (i_epoch, validation_loss)
                    )

            # Statistics
            end_time = timeit.default_timer()
            epoch_time = end_time - start_time
            record_dict["epoch_time"].append((i_epoch, epoch_time))

            log = "{:.3f} sec".format(epoch_time)
            log += ", train loss: " + str(train_loss)
            # Bug fix: the original tested `validation_loss is not None`,
            # which raised a NameError whenever no validation tensor was
            # given, since `validation_loss` was never assigned in that case.
            if validation_loss_tensor is not None:
                log += ", val loss: " + str(validation_loss)
                if (save_best_val_model_fn is not None and
                        cur_validation_loss < best_validation_loss):
                    saver.save(session, save_best_val_model_fn)
                    best_validation_loss = cur_validation_loss
                    log += " *"
            print(log)
            sys.stdout.flush()

        if save_model_fn is not None:
            print("Writing: {}".format(save_model_fn))
            saver.save(session, save_model_fn)

    total_time = sum([i[1] for i in record_dict["epoch_time"]])
    print("Training time: {:.3f} min".format(total_time/60.))
    print(datetime.now())

    return record_dict