Skip to content

Commit

Permalink
training works
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhsin Fatih Yorulmaz committed Dec 4, 2019
1 parent 3d10f17 commit 3571732
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 11 deletions.
9 changes: 9 additions & 0 deletions config2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import os
from utils import *

data_dir = os.path.join(os.path.dirname(__file__), 'data')

virat = new(dir = os.path.join(data_dir, 'VIRAT/Public Dataset/VIRAT Video Dataset Release 2.0/'))
virat.aerial = new(dir = os.path.join(virat.dir, 'VIRAT Aerial Dataset'))
virat.ground = new(dir = os.path.join(virat.dir, 'VIRAT Ground Dataset'))
virat.ground.video = new(dir = os.path.join(virat.ground.dir, 'videos_original'))
28 changes: 22 additions & 6 deletions srgan/train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#! /usr/bin/python
# -*- coding: utf8 -*-
#%%

import os
import time
Expand All @@ -11,6 +12,12 @@
from model import get_G, get_D
from config import config

import sys
sys.path.append('..')
import videodataset
from config2 import *
import glob2

###====================== HYPER-PARAMETERS ===========================###
## Adam
batch_size = config.TRAIN.batch_size # use 8 if your GPU memory is small, and change [4, 4] in tl.vis.save_images to [2, 4]
Expand All @@ -34,13 +41,13 @@

def get_train_data():
# load dataset
train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))#[0:20]
# train_hr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.hr_img_path, regx='.*.png', printable=False))#[0:20]
# train_lr_img_list = sorted(tl.files.load_file_list(path=config.TRAIN.lr_img_path, regx='.*.png', printable=False))
# valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))
# valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))

## If your machine have enough memory, please pre-load the entire train set.
train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
# train_hr_imgs = tl.vis.read_images(train_hr_img_list, path=config.TRAIN.hr_img_path, n_threads=32)
# for im in train_hr_imgs:
# print(im.shape)
# valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)
Expand All @@ -51,9 +58,15 @@ def get_train_data():
# print(im.shape)

# dataset API and augmentation
def generator_train():
for img in train_hr_imgs:
yield img
# def generator_train():
# for img in train_hr_imgs:
# yield img

videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
generator = videodataset.FrameGenerator(videoPaths)



def _map_fn_train(img):
hr_patch = tf.image.random_crop(img, [384, 384, 3])
hr_patch = hr_patch / (255. / 2.)
Expand All @@ -63,7 +76,10 @@ def _map_fn_train(img):
return lr_patch, hr_patch


train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32))
train_ds = tf.data.Dataset.from_generator(generator.call, output_types=(tf.float32))
# train_ds = tf.data.Dataset.from_generator(generator_train, output_types=(tf.float32))
# print(next(iter(train_ds)).numpy())
# return
train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
# train_ds = train_ds.repeat(n_epoch_init + n_epoch)
train_ds = train_ds.shuffle(shuffle_buffer_size)
Expand Down
11 changes: 6 additions & 5 deletions videodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@ def call(self):
idx_frame = np.random.choice(self.totalFrames[i_vid], self.totalFrames[i_vid], replace=True)
for i_frame in idx_frame:
vid.set(cv2.CAP_PROP_POS_FRAMES, i_frame) # set video to this frame
yield {
'video_index': i_vid,
'video_path': self.videoPaths[i_vid],
'frame': vid.read()[1]
}
# yield {
# 'video_index': i_vid,
# 'video_path': self.videoPaths[i_vid],
# 'frame': vid.read()[1]
# }
yield vid.read()[1]

if __name__ == "__main__":
videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
Expand Down

0 comments on commit 3571732

Please sign in to comment.