diff --git a/config2.py b/config2.py new file mode 100644 index 0000000..37176e6 --- /dev/null +++ b/config2.py @@ -0,0 +1,9 @@ +import os +from utils import * + +data_dir = os.path.join(os.path.dirname(__file__), 'data') + +virat = new(dir = os.path.join(data_dir, 'VIRAT/Public Dataset/VIRAT Video Dataset Release 2.0/')) +virat.aerial = new(dir = os.path.join(virat.dir, 'VIRAT Aerial Dataset')) +virat.ground = new(dir = os.path.join(virat.dir, 'VIRAT Ground Dataset')) +virat.ground.video = new(dir = os.path.join(virat.ground.dir, 'videos_original')) diff --git a/srgan/train.py b/srgan/train.py index 399b976..1f77212 100755 --- a/srgan/train.py +++ b/srgan/train.py @@ -1,5 +1,6 @@ #! /usr/bin/python # -*- coding: utf8 -*- +#%% import os import time @@ -11,6 +12,12 @@ from model import get_G, get_D from config import config +import sys +sys.path.append('..') +import videodataset +from config2 import * +import glob2 + ###====================== HYPER-PARAMETERS ===========================### ## Adam batch_size = config.TRAIN.batch_size # use 8 if your GPU memory is small, and change [4, 4] in tl.vis.save_images to [2, 4] @@ -34,13 +41,13 @@ def get_train_data(): # load dataset - train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))#[0:20] + # train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))#[0:20] # train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False)) # valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False)) # valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False)) ## If your machine have enough memory, please pre-load the entire train set. - train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32) + # train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32) # for im in train_hr_imgs: # print(im.shape) # valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32) @@ -51,9 +58,15 @@ def get_train_data(): # print(im.shape) # dataset API and augmentation - def generator_train(): - for img in train_hr_imgs: - yield img + # def generator_train(): + # for img in train_hr_imgs: + # yield img + + videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4')) + generator = videodataset.FrameGenerator(videoPaths) + + + def _map_fn_train(img): hr_patch = tf.image.random_crop(img, [384, 384, 3]) hr_patch = hr_patch / (255. / 2.) @@ -63,7 +76,10 @@ def _map_fn_train(img): return lr_patch, hr_patch - train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32)) + train_ds = tf.data.Dataset.from_generator(generator.call, output_types=(tf.float32)) + # train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32)) + # print(next(iter(train_ds)).numpy()) + # return train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count()) # train_ds = train_ds.repeat(n_epoch_init + n_epoch) train_ds = train_ds.shuffle(shuffle_buffer_size) diff --git a/videodataset.py b/videodataset.py index 051a1f0..4028c28 100644 --- a/videodataset.py +++ b/videodataset.py @@ -34,11 +34,12 @@ def call(self): idx_frame = np.random.choice(self.totalFrames[i_vid], self.totalFrames[i_vid], replace=True) for i_frame in idx_frame: vid.set(cv2.CAP_PROP_POS_FRAMES, i_frame) # set video to this frame - yield { - 'video_index': i_vid, - 'video_path': self.videoPaths[i_vid], - 'frame': vid.read()[1] - } + # yield { + # 'video_index': i_vid, + # 'video_path': self.videoPaths[i_vid], + # 'frame': vid.read()[1] + # } + yield vid.read()[1] if __name__ == "__main__": videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))