#! /usr/bin/python
# -*- coding: utf8 -*-
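"""Train and evaluate SRGAN (based on tensorlayer/SRGAN).

Training runs in two stages:
  1. initialize the generator G with a pixel-wise MSE loss only;
  2. adversarial learning: train G with MSE + VGG19 (pool4) perceptual
     loss + GAN loss, against the discriminator D.

Usage:
    python train.py --mode srgan      # train G and D
    python train.py --mode evaluate   # super-resolve a validation image with the saved G
"""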
import os
import time
import multiprocessing

import numpy as np
import tensorflow as tf
import tensorlayer as tl
from PIL import Image  # used for the bicubic baseline; scipy.misc.imresize was removed in SciPy >= 1.3

from model import get_G, get_D
from config import config
###====================== HYPER-PARAMETERS ===========================###
## Adam
batch_size = config.TRAIN.batch_size  # the [2, 4] grids in tl.vis.save_images below assume batch_size == 8; adjust both together if your GPU memory is small
lr_init = config.TRAIN.lr_init
beta1 = config.TRAIN.beta1
## initialize G
n_epoch_init = config.TRAIN.n_epoch_init
## adversarial learning (SRGAN)
n_epoch = config.TRAIN.n_epoch
lr_decay = config.TRAIN.lr_decay
decay_every = config.TRAIN.decay_every
shuffle_buffer_size = 128
# ni = int(np.sqrt(batch_size))
# create folders to save result images and trained models
save_dir = "samples"
tl.files.exists_or_mkdir(save_dir)
checkpoint_dir = "models"
tl.files.exists_or_mkdir(checkpoint_dir)
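# Fields this script reads from config.py:
#   config.TRAIN: batch_size, lr_init, beta1, n_epoch_init, n_epoch,
#                 lr_decay, decay_every, hr_img_path
#   config.VALID: hr_img_path, lr_img_path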
def get_train_data():
    # load the list of HR training image file names
    train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))  # slice with [0:20] for quick debugging
    # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
    # valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    # valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

    ## if your machine has enough memory, pre-load the entire train set
    train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)

    # dataset API and augmentation
    def generator_train():
        for img in train_hr_imgs:
            yield img

    def _map_fn_train(img):
        hr_patch = tf.image.random_crop(img, [384, 384, 3])
        hr_patch = hr_patch / (255. / 2.)  # rescale [0, 255] -> [0, 2]
        hr_patch = hr_patch - 1.           # then shift to [-1, 1]
        hr_patch = tf.image.random_flip_left_right(hr_patch)
        lr_patch = tf.image.resize(hr_patch, size=[96, 96])  # 4x-downsampled LR input
        return lr_patch, hr_patch

    train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32))
    train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
    # train_ds = train_ds.repeat(n_epoch_init + n_epoch)
    train_ds = train_ds.shuffle(shuffle_buffer_size)
    train_ds = train_ds.prefetch(buffer_size=2)
    train_ds = train_ds.batch(batch_size)
    # also return the number of training images so train() can report steps per epoch
    return train_ds, len(train_hr_imgs)
def train():
    G = get_G((batch_size, 96, 96, 3))
    D = get_D((batch_size, 384, 384, 3))
    VGG = tl.models.vgg19(pretrained=True, end_with='pool4', mode='static')

    lr_v = tf.Variable(lr_init)
    g_optimizer_init = tf.optimizers.Adam(lr_v, beta_1=beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)
    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=beta1)

    G.train()
    D.train()
    VGG.train()

    train_ds, n_train_imgs = get_train_data()
    n_step_epoch = n_train_imgs // batch_size  # steps per epoch (the original divided the epoch count by batch_size here, which mislabeled the progress print)

    ## initialize learning (G): pre-train the generator with MSE only
    for epoch in range(n_epoch_init):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape() as tape:
                fake_hr_patchs = G(lr_patchs)
                mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
            grad = tape.gradient(mse_loss, G.trainable_weights)
            g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f}".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, float(mse_loss)))
        if (epoch != 0) and (epoch % 10 == 0):
            tl.vis.save_images(fake_hr_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_init_{}.png'.format(epoch)))

    ## adversarial learning (G, D)
    for epoch in range(n_epoch):
        for step, (lr_patchs, hr_patchs) in enumerate(train_ds):
            if lr_patchs.shape[0] != batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            # a persistent tape lets us take gradients twice: once for G, once for D
            with tf.GradientTape(persistent=True) as tape:
                fake_patchs = G(lr_patchs)
                logits_fake = D(fake_patchs)
                logits_real = D(hr_patchs)
                feature_fake = VGG((fake_patchs + 1) / 2.)  # the pre-trained VGG uses the input range of [0, 1]
                feature_real = VGG((hr_patchs + 1) / 2.)
                d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real))
                d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake))
                d_loss = d_loss1 + d_loss2
                g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake))
                mse_loss = tl.cost.mean_squared_error(fake_patchs, hr_patchs, is_mean=True)
                vgg_loss = 2e-6 * tl.cost.mean_squared_error(feature_fake, feature_real, is_mean=True)
                g_loss = mse_loss + vgg_loss + g_gan_loss
            grad = tape.gradient(g_loss, G.trainable_weights)
            g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
            grad = tape.gradient(d_loss, D.trainable_weights)
            d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss(mse:{:.3f}, vgg:{:.3f}, adv:{:.3f}) d_loss: {:.3f}".format(
                epoch, n_epoch, step, n_step_epoch, time.time() - step_time,
                float(mse_loss), float(vgg_loss), float(g_gan_loss), float(d_loss)))

        # update the learning rate
        if epoch != 0 and (epoch % decay_every == 0):
            new_lr_decay = lr_decay**(epoch // decay_every)
            lr_v.assign(lr_init * new_lr_decay)
            print(" ** new learning rate: %f (for GAN)" % (lr_init * new_lr_decay))

        if (epoch != 0) and (epoch % 10 == 0):
            tl.vis.save_images(fake_patchs.numpy(), [2, 4], os.path.join(save_dir, 'train_g_{}.png'.format(epoch)))
        G.save_weights(os.path.join(checkpoint_dir, 'g.h5'))
        D.save_weights(os.path.join(checkpoint_dir, 'd.h5'))
def evaluate():
    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
    valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
    ## if your machine has enough memory, pre-load the whole validation set
    valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
    valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)

    ###========================== DEFINE MODEL ============================###
    imid = 64  # 0: penguin, 81: butterfly, 53: bird, 64: castle
    valid_lr_img = valid_lr_imgs[imid]
    valid_hr_img = valid_hr_imgs[imid]
    # valid_lr_img = get_imgs_fn('test.png', 'data2017/')  # if you want to test your own image
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    # build G with flexible spatial dimensions so an LR image of any size can be fed in
    G = get_G([1, None, None, 3])
    G.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
    G.eval()

    valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)
    valid_lr_img = valid_lr_img[np.newaxis, :, :, :]
    size = [valid_lr_img.shape[1], valid_lr_img.shape[2]]

    out = G(valid_lr_img).numpy()
    print("LR size: %s / generated HR size: %s" % (size, out.shape))  # e.g. LR (339, 510, 3) -> HR (1, 1356, 2040, 3)
    print("[*] save images")
    tl.vis.save_image(out[0], os.path.join(save_dir, 'valid_gen.png'))
    tl.vis.save_image(valid_lr_img[0], os.path.join(save_dir, 'valid_lr.png'))
    tl.vis.save_image(valid_hr_img, os.path.join(save_dir, 'valid_hr.png'))

    # bicubic baseline: scipy.misc.imresize was removed in SciPy >= 1.3, so use PIL instead;
    # convert the [-1, 1] float image back to uint8 [0, 255] first
    lr_uint8 = ((valid_lr_img[0] + 1) * 127.5).astype(np.uint8)
    out_bicu = np.asarray(Image.fromarray(lr_uint8).resize((size[1] * 4, size[0] * 4), Image.BICUBIC))
    tl.vis.save_image(out_bicu, os.path.join(save_dir, 'valid_bicubic.png'))
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='srgan', help='srgan, evaluate')
    args = parser.parse_args()

    tl.global_flag['mode'] = args.mode

    if tl.global_flag['mode'] == 'srgan':
        train()
    elif tl.global_flag['mode'] == 'evaluate':
        evaluate()
    else:
        raise Exception("Unknown --mode")