-
Notifications
You must be signed in to change notification settings - Fork 100
/
Copy pathmodel.py
130 lines (99 loc) · 4.37 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from keras.models import Model
from keras.layers import Input, Add, PReLU, Conv2DTranspose, Concatenate, MaxPooling2D, UpSampling2D, Dropout
from keras.layers.convolutional import Conv2D
from keras.layers.normalization import BatchNormalization
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
class L0Loss:
    """Factory for an annealed power loss: ``(|y_true - y_pred| + 1e-8) ** gamma``.

    ``gamma`` is stored as a Keras backend variable (initial value 2.0) so that
    its value can be changed with ``K.set_value`` without rebuilding the loss.
    Calling an instance returns the actual loss function to hand to Keras.
    """

    def __init__(self):
        # Backend variable rather than a Python float: the loss graph reads it
        # by reference, so later updates to its value are picked up.
        self.gamma = K.variable(2.)

    def __call__(self):
        """Return a ``(y_true, y_pred) -> tensor`` loss closed over ``self.gamma``."""

        def calc_loss(y_true, y_pred):
            # The small epsilon keeps the base of the power strictly positive.
            abs_error = K.abs(y_true - y_pred)
            return K.pow(abs_error + 1e-8, self.gamma)

        return calc_loss
class UpdateAnnealingParameter(Callback):
    """Callback that linearly decays a backend variable over the training run.

    At the start of epoch ``e`` (0-based) the variable is set to
    ``2.0 * (nb_epochs - e) / nb_epochs``, i.e. 2.0 at epoch 0 and decreasing
    by ``2.0 / nb_epochs`` each epoch.

    Args:
        gamma: backend variable to overwrite (written with ``K.set_value``).
        nb_epochs: epoch count used as the length of the decay schedule.
        verbose: when greater than 0, print the new value every epoch.
    """

    def __init__(self, gamma, nb_epochs, verbose=0):
        super(UpdateAnnealingParameter, self).__init__()
        self.verbose = verbose
        self.nb_epochs = nb_epochs
        self.gamma = gamma

    def on_epoch_begin(self, epoch, logs=None):
        """Write the scheduled value for ``epoch`` into ``self.gamma``."""
        # Multiply by 2.0 first so the division is always a float division.
        scaled_remaining = 2.0 * (self.nb_epochs - epoch)
        new_gamma = scaled_remaining / self.nb_epochs
        K.set_value(self.gamma, new_gamma)

        if self.verbose > 0:
            report = (epoch + 1, new_gamma)
            print('\nEpoch %05d: UpdateAnnealingParameter reducing gamma to %s.' % report)
def tf_log10(x):
    """Return the element-wise base-10 logarithm of tensor ``x``.

    Computed as ``ln(x) / ln(10)``.

    The natural log goes through the Keras backend (``K.log``) instead of
    ``tf.log``: ``tf.log`` was removed from the top-level TensorFlow namespace
    in TF 2.x (it lives at ``tf.math.log`` there), whereas ``K.log`` is
    provided by the backend module this file already imports.

    Args:
        x: floating-point tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    import math  # function-scope import keeps the file-level import block untouched

    # ln(10) is a plain Python constant; dividing a tensor by it lets the
    # backend cast it to the tensor's own dtype.
    return K.log(x) / math.log(10.0)
def PSNR(y_true, y_pred):
    """Peak signal-to-noise ratio in dB, assuming a peak pixel value of 255.

    ``y_pred`` is clipped to [0, 255] before comparison; ``y_true`` is used
    as given. The squared error is averaged with ``K.mean`` without an axis
    argument, so a single value is produced for the whole input.

    Args:
        y_true: reference tensor.
        y_pred: predicted tensor.

    Returns:
        ``10 * log10(255**2 / MSE)`` as a tensor.
    """
    max_pixel = 255.0
    clipped_pred = K.clip(y_pred, 0.0, 255.0)
    mse = K.mean(K.square(clipped_pred - y_true))
    signal_to_noise = (max_pixel ** 2) / mse
    return 10.0 * tf_log10(signal_to_noise)
def get_model(model_name="srresnet"):
    """Build one of the supported networks by name.

    Args:
        model_name: ``"srresnet"`` (default) or ``"unet"``.

    Returns:
        The model from ``get_srresnet_model()`` with its defaults, or from
        ``get_unet_model(out_ch=3)``.

    Raises:
        ValueError: if ``model_name`` is neither of the supported names.
    """
    if model_name == "srresnet":
        return get_srresnet_model()

    if model_name == "unet":
        return get_unet_model(out_ch=3)

    raise ValueError("model_name should be 'srresnet'or 'unet'")
# SRResNet
def get_srresnet_model(input_channel_num=3, feature_dim=64, resunit_num=16):
    """Build an SRResNet-style fully convolutional model.

    Layout: 3x3 conv + PReLU head, ``resunit_num`` residual blocks, a 3x3
    conv + batch norm whose output is added back to the head activation
    (long skip), and a final 3x3 conv mapping back to ``input_channel_num``
    channels. All convolutions use "same" padding and ``he_normal``
    initialisation; the input's spatial dimensions are left unspecified.

    Args:
        input_channel_num: channel count of the input and of the output.
        feature_dim: number of filters in every hidden convolution.
        resunit_num: number of residual blocks.

    Returns:
        An uncompiled ``Model``.
    """

    def _conv3x3(filters):
        # Every convolution in this network shares these settings.
        return Conv2D(filters, (3, 3), padding="same", kernel_initializer="he_normal")

    def _residual_block(block_in):
        # conv -> BN -> PReLU -> conv -> BN, then identity shortcut.
        hidden = _conv3x3(feature_dim)(block_in)
        hidden = BatchNormalization()(hidden)
        hidden = PReLU(shared_axes=[1, 2])(hidden)
        hidden = _conv3x3(feature_dim)(hidden)
        hidden = BatchNormalization()(hidden)
        return Add()([hidden, block_in])

    image = Input(shape=(None, None, input_channel_num))

    head = _conv3x3(feature_dim)(image)
    head = PReLU(shared_axes=[1, 2])(head)

    trunk = head
    for _ in range(resunit_num):
        trunk = _residual_block(trunk)

    trunk = _conv3x3(feature_dim)(trunk)
    trunk = BatchNormalization()(trunk)
    trunk = Add()([trunk, head])  # long skip around the whole residual stack

    restored = _conv3x3(input_channel_num)(trunk)
    return Model(inputs=image, outputs=restored)
# UNet: code from https://github.com/pietz/unet-keras
def get_unet_model(input_channel_num=3, out_ch=3, start_ch=64, depth=4, inc_rate=2., activation='relu',
                   dropout=0.5, batchnorm=False, maxpool=True, upconv=True, residual=False):
    """Build a configurable U-Net as a recursive encoder/decoder.

    Args:
        input_channel_num: channel count of the input tensor.
        out_ch: channel count produced by the final 1x1 convolution.
        start_ch: number of filters at the top level.
        depth: number of down/up-sampling levels above the bottleneck.
        inc_rate: filter multiplier applied per level going down
            (``int(inc_rate * filters)``).
        activation: activation used by the convolutions that take one.
        dropout: dropout rate; applied only inside the bottleneck block.
        batchnorm: add ``BatchNormalization`` after each block convolution.
        maxpool: downsample with ``MaxPooling2D`` when true, otherwise with a
            stride-2 3x3 convolution.
        upconv: upsample with ``UpSampling2D`` + 2x2 convolution when true,
            otherwise with a stride-2 ``Conv2DTranspose``.
        residual: when true, each conv block returns its input concatenated
            (not added) with its output.

    Returns:
        An uncompiled ``Model`` whose input has unspecified height and width.

    NOTE(review): presumably height/width should be multiples of
    ``2 ** depth`` so the skip concatenations line up - confirm with callers.
    """

    def _double_conv(tensor, filters, drop_rate=0):
        # Two 3x3 convolutions with optional batch norm; dropout, if any,
        # sits between them.
        out = Conv2D(filters, 3, activation=activation, padding='same')(tensor)
        if batchnorm:
            out = BatchNormalization()(out)
        if drop_rate:
            out = Dropout(drop_rate)(out)
        out = Conv2D(filters, 3, activation=activation, padding='same')(out)
        if batchnorm:
            out = BatchNormalization()(out)
        if residual:
            return Concatenate()([tensor, out])
        return out

    def _build_level(tensor, filters, remaining):
        # Bottleneck: the only place where dropout is applied.
        if remaining <= 0:
            return _double_conv(tensor, filters, dropout)

        # Encoder side of this level; `skip` is reused by the decoder below.
        skip = _double_conv(tensor, filters)
        if maxpool:
            lower = MaxPooling2D()(skip)
        else:
            lower = Conv2D(filters, 3, strides=2, padding='same')(skip)

        lower = _build_level(lower, int(inc_rate * filters), remaining - 1)

        # Decoder side: bring the deeper result back to this level's resolution.
        if upconv:
            lower = UpSampling2D()(lower)
            lower = Conv2D(filters, 2, activation=activation, padding='same')(lower)
        else:
            lower = Conv2DTranspose(filters, 3, strides=2, activation=activation, padding='same')(lower)

        merged = Concatenate()([skip, lower])
        return _double_conv(merged, filters)

    net_input = Input(shape=(None, None, input_channel_num))
    features = _build_level(net_input, start_ch, depth)
    net_output = Conv2D(out_ch, 1)(features)
    return Model(inputs=net_input, outputs=net_output)
def main():
    """Build the U-Net variant and print its layer summary."""
    # Calling get_model() with no argument would build the SRResNet instead.
    get_model("unet").summary()


if __name__ == "__main__":
    main()