Skip to content

Commit

Permalink
srgan evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhsin Fatih Yorulmaz committed Dec 4, 2019
1 parent eb2c468 commit 2aa4ea8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
28 changes: 23 additions & 5 deletions srgan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,22 +164,40 @@ def evaluate():
###====================== PRE-LOAD DATA ===========================###
# train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))
# train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))
# valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
# valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

## if your machine have enough memory, please pre-load the whole train set.
# train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
# for im in train_hr_imgs:
# print(im.shape)
valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
# valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
# for im in valid_lr_imgs:
# print(im.shape)
valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
# valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)
# for im in valid_hr_imgs:
# print(im.shape)

videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
generator = videodataset.FrameGenerator(videoPaths, iteration_size=12)

def _map_fn_train(img):
hr_patch = tf.image.random_crop(img, [384, 384, 3])
hr_patch = hr_patch / (255. / 2.)
hr_patch = hr_patch - 1.
hr_patch = tf.image.random_flip_left_right(hr_patch)
lr_patch = tf.image.resize(hr_patch, size=[96, 96])
return lr_patch, hr_patch

ds_lowres = tf.data.Dataset.from_generator(generator.call, output_types=(tf.float32))
ds_highres = tf.data.Dataset.from_generator(generator.call, output_types=(tf.float32))
ds_lowres = ds_lowres.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
valid_hr_imgs = [next(iter(ds_highres))[1].numpy()] #[img[1].numpy() for img in next(iter(ds_highres.take(65)))]
valid_lr_imgs = [next(iter(ds_lowres))[1].numpy()] #[img[1].numpy() for img in next(iter(ds_lowres.take(65)))]
print('valid_lr_imgs: ', valid_lr_imgs)

###========================== DEFINE MODEL ============================###
imid = 64 # 0: 企鹅 81: 蝴蝶 53: 鸟 64: 古堡
imid = 0 # 0: 企鹅 81: 蝴蝶 53: 鸟 64: 古堡
valid_lr_img = valid_lr_imgs[imid]
valid_hr_img = valid_hr_imgs[imid]
# valid_lr_img = get_imgs_fn('test.png', 'data2017/') # if you want to test your own image
Expand Down
26 changes: 13 additions & 13 deletions videodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ def __init__(self, videoPaths, iteration_size):
self.totalFrames = np.array([vid.get(cv2.CAP_PROP_FRAME_COUNT) for vid in self.videos]).astype(np.int)
self.iteration_size = iteration_size
def call(self):
# while True:
# let's just use the first video for testing:
i_vid = 0
vid = self.videos[i_vid]
idx_frame = np.random.choice(self.totalFrames[i_vid], self.totalFrames[i_vid], replace=True)
for i in range(self.iteration_size):
vid.set(cv2.CAP_PROP_POS_FRAMES, idx_frame[i]) # set video to this frame
# yield {
# 'video_index': i_vid,
# 'video_path': self.videoPaths[i_vid],
# 'frame': vid.read()[1]
# }
yield vid.read()[1]
while True:
# let's just use the first video for testing:
i_vid = 0
vid = self.videos[i_vid]
idx_frame = np.random.choice(self.totalFrames[i_vid], self.totalFrames[i_vid], replace=True)
for i in range(self.iteration_size):
vid.set(cv2.CAP_PROP_POS_FRAMES, idx_frame[i]) # set video to this frame
# yield {
# 'video_index': i_vid,
# 'video_path': self.videoPaths[i_vid],
# 'frame': vid.read()[1]
# }
yield vid.read()[1]

if __name__ == "__main__":
videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
Expand Down

0 comments on commit 2aa4ea8

Please sign in to comment.