This repository has been archived by the owner on Mar 23, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 145
/
callback.py
168 lines (151 loc) · 6.37 KB
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import json
import keras.backend as kb
import numpy as np
import os
import shutil
import warnings
from keras.callbacks import Callback
from sklearn.metrics import roc_auc_score
class MultipleClassAUROC(Callback):
    """
    Monitor mean AUROC and update model

    At the end of every epoch the callback predicts on ``sequence``,
    computes one AUROC per entry of ``class_names``, and, when the mean
    improves on the best value seen so far, copies ``weights_path`` to a
    ``best_``-prefixed file in the same directory, appends a line to
    ``best_auroc.log`` and rewrites ``.training_stats.json``.
    """

    def __init__(self, sequence, class_names, weights_path, stats=None, workers=1):
        """
        Args:
            sequence: evaluation data; passed to ``model.predict_generator``
                and must provide ``get_y_true()``.
            class_names: one name per output column, in column order.
            weights_path: path of the weights file that is copied whenever
                the mean AUROC improves.
            stats: stats dict from a previous run (for resuming); it is read
                via the ``"best_mean_auroc"`` key. A falsy value starts fresh.
            workers: forwarded to ``model.predict_generator``.
        """
        # Zero-argument super() runs Callback.__init__; the previous
        # super(Callback, self) started the lookup *after* Callback and
        # therefore skipped it.
        super().__init__()
        self.sequence = sequence
        self.workers = workers
        self.class_names = class_names
        self.weights_path = weights_path

        # All output files live next to the monitored weights file.
        output_dir, weights_file = os.path.split(weights_path)
        self.best_weights_path = os.path.join(output_dir, f"best_{weights_file}")
        self.best_auroc_log_path = os.path.join(output_dir, "best_auroc.log")
        self.stats_output_path = os.path.join(output_dir, ".training_stats.json")

        # for resuming previous training
        self.stats = stats if stats else {"best_mean_auroc": 0}

        # aurocs log: per-class history, one score appended per epoch
        self.aurocs = {c: [] for c in self.class_names}

    def on_epoch_end(self, epoch, logs=None):
        """
        Calculate the average AUROC and save the best model weights according
        to this metric.

        Args:
            epoch: zero-based epoch index (reported as ``epoch + 1``).
            logs: unused; accepted for the Keras callback signature.
        """
        print("\n*********************************")
        self.stats["lr"] = float(kb.eval(self.model.optimizer.lr))
        print(f"current learning rate: {self.stats['lr']}")

        # Both arrays are indexed as [:, i] below, i.e. one column per class:
        # y_hat shape: (#samples, len(class_names)); y is expected to match.
        y_hat = self.model.predict_generator(self.sequence, workers=self.workers)
        y = self.sequence.get_y_true()

        print(f"*** epoch#{epoch + 1} dev auroc ***")
        current_auroc = []
        for i, class_name in enumerate(self.class_names):
            try:
                score = roc_auc_score(y[:, i], y_hat[:, i])
            except ValueError:
                # Deliberate best effort: a class whose AUROC cannot be
                # computed counts as 0 instead of aborting the epoch.
                score = 0
            self.aurocs[class_name].append(score)
            current_auroc.append(score)
            print(f"{i+1}. {class_name}: {score}")
        print("*********************************")

        # customize your multiple class metrics here
        mean_auroc = np.mean(current_auroc)
        print(f"mean auroc: {mean_auroc}")
        if mean_auroc > self.stats["best_mean_auroc"]:
            print(f"update best auroc from {self.stats['best_mean_auroc']} to {mean_auroc}")

            # 1. copy best model
            shutil.copy(self.weights_path, self.best_weights_path)

            # 2. update log file
            print(f"update log file: {self.best_auroc_log_path}")
            with open(self.best_auroc_log_path, "a") as f:
                f.write(f"(epoch#{epoch + 1}) auroc: {mean_auroc}, lr: {self.stats['lr']}\n")

            # 3. record the new best BEFORE dumping, so the stats file used
            #    for resuming holds the current best rather than the previous
            #    one; stored as a builtin float so JSON gets a plain number.
            self.stats["best_mean_auroc"] = float(mean_auroc)
            with open(self.stats_output_path, 'w') as f:
                json.dump(self.stats, f)

            print(f"update model file: {self.weights_path} -> {self.best_weights_path}")

        print("*********************************")
class MultiGPUModelCheckpoint(Callback):
    """
    Checkpointing callback for multi_gpu_model
    copy from https://github.com/keras-team/keras/issues/8463

    Unlike a checkpoint on the training model, every save goes through
    ``base_model`` (the object handed to the constructor).
    """

    def __init__(self, filepath, base_model, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1):
        """
        Args:
            filepath: destination path; formatted each save with ``epoch``
                and the entries of ``logs`` (``str.format`` fields).
            base_model: model whose ``save`` / ``save_weights`` is called.
            monitor: key looked up in ``logs`` when ``save_best_only`` is set.
            verbose: values > 0 print a message for each decision.
            save_best_only: save only when the monitored value improves.
            save_weights_only: call ``save_weights`` instead of ``save``.
            mode: ``'min'``, ``'max'`` or ``'auto'``; in auto mode, monitors
                containing ``'acc'`` or starting with ``'fmeasure'`` are
                maximised and everything else is minimised.
            period: number of epochs between checkpoints.
        """
        # Zero-argument super() runs Callback.__init__; the previous
        # super(Callback, self) started the lookup *after* Callback and
        # therefore skipped it.
        super().__init__()
        self.base_model = base_model
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0

        if mode not in ('auto', 'min', 'max'):
            warnings.warn('ModelCheckpoint mode %s is unknown, '
                          'fallback to auto mode.' % (mode),
                          RuntimeWarning)
            mode = 'auto'

        if mode == 'auto':
            maximize = ('acc' in self.monitor
                        or self.monitor.startswith('fmeasure'))
        else:
            maximize = mode == 'max'

        # np.inf (lowercase) exists in every NumPy release; the np.Inf alias
        # was removed in NumPy 2.0.
        if maximize:
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            self.monitor_op = np.less
            self.best = np.inf

    def _save_base_model(self, filepath):
        """Write the base model (or only its weights) to *filepath*."""
        if self.save_weights_only:
            self.base_model.save_weights(filepath, overwrite=True)
        else:
            self.base_model.save(filepath, overwrite=True)

    def on_epoch_end(self, epoch, logs=None):
        """
        Save a checkpoint once ``period`` epochs have passed since the last
        checkpoint attempt.

        Args:
            epoch: zero-based epoch index (reported as ``epoch + 1``).
            logs: metrics dict; supplies the monitored value and the
                ``filepath`` format fields.
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return

        # The counter resets even when nothing ends up being saved below.
        self.epochs_since_last_save = 0
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('Epoch %05d: saving model to %s' % (epoch + 1, filepath))
            self._save_base_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
            return

        if not self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('Epoch %05d: %s did not improve' %
                      (epoch + 1, self.monitor))
            return

        if self.verbose > 0:
            print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                  ' saving model to %s'
                  % (epoch + 1, self.monitor, self.best,
                     current, filepath))
        self.best = current
        self._save_base_model(filepath)