Skip to content

Commit

Permalink
aerial footage superresolution works + fixed blur issue
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhsin Fatih Yorulmaz committed Dec 18, 2019
1 parent 9a2f337 commit 9804a86
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 18 deletions.
1 change: 1 addition & 0 deletions config2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
virat = new(dir = os.path.join(data_dir, 'VIRAT/Public Dataset/VIRAT Video Dataset Release 2.0/'))
virat.aerial = new(dir = os.path.join(virat.dir, 'VIRAT Aerial Dataset'))
virat.ground = new(dir = os.path.join(virat.dir, 'VIRAT Ground Dataset'))
virat.aerial.video = new(dir = os.path.join(virat.aerial.dir, 'videos'))
virat.ground.video = new(dir = os.path.join(virat.ground.dir, 'videos_original'))
14 changes: 7 additions & 7 deletions srgan/job_script.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
CONTAINER=container/sandboxdir
sing="singularity exec --nv $CONTAINER "
(\
export CUDA_VISIBLE_DEVICES=1
JOBNAME=input48_interleaved_long_blurfixed
$sing python3 -u 'train.py' --inputsize 48 --exp $JOBNAME >> outputs/$JOBNAME.log 2>&1
) &
# (\
# export CUDA_VISIBLE_DEVICES=0
# JOBNAME=input48_interleaved_no_input_upsample
# $sing python3 -u 'train.py' --inputsize 48 --exp $JOBNAME >> outputs/$JOBNAME.log 2>&1
# JOBNAME=input96_interleaved
# $sing python3 -u 'train.py' --inputsize 96 --exp $JOBNAME >> outputs/$JOBNAME.log 2>&1
# ) &
(\
export CUDA_VISIBLE_DEVICES=0
JOBNAME=input96_interleaved
$sing python3 -u 'train.py' --inputsize 96 --exp $JOBNAME >> outputs/$JOBNAME.log 2>&1
) &

wait
31 changes: 22 additions & 9 deletions srgan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ def _map_fn_downsample_centercrop(img):
lr_patch = tf.image.resize(lr_patch, size=[140, 140]) # re-upsample if it was lower than this
return lr_patch, hr_patch

def get_train_data():
videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
def get_train_data(aerial=False):
ds_source = virat.aerial if aerial else virat.ground
videoPaths = np.array(glob2.glob(ds_source.video.dir + '/*.mp4') + glob2.glob(ds_source.video.dir + '/*.mpg'))
# videoPaths = np.array([ds_source.video.dir + '/09152008flight2tape3_9.mpg'])
generator = videodataset.FrameGeneratorInterleaved(videoPaths, iteration_size)

train_ds = tf.data.Dataset.from_generator(generator.call, output_types=(tf.float32))
Expand All @@ -128,10 +130,16 @@ def get_train_data():
train_ds = train_ds.batch(batch_size)
# value = train_ds.make_one_shot_iterator().get_next()

test_generator = videodataset.FrameGeneratorInterleaved(videoPaths, iteration_size, isTest=True)
if aerial:
test_generator = videodataset.FrameGenerator_sequential(videoPaths)
else:
test_generator = videodataset.FrameGeneratorInterleaved(videoPaths, iteration_size, isTest=True)
test_ds = tf.data.Dataset.from_generator(generator.call, output_types=(tf.float32))
test_ds = test_ds.map(_map_fn_preprocess, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.map(_map_fn_downsample, num_parallel_calls=AUTOTUNE)
if aerial:
test_ds = test_ds.map(lambda img: (img,img), num_parallel_calls=AUTOTUNE) # lowres = highres
else:
test_ds = test_ds.map(_map_fn_downsample, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.prefetch(AUTOTUNE)
test_ds = test_ds.batch(batch_size)

Expand All @@ -141,7 +149,10 @@ def get_train_data():
sample_ds = path_ds.map(_map_fn_path2img, num_parallel_calls=AUTOTUNE)
sample_ds = sample_ds.map(_map_fn_preprocess, num_parallel_calls=AUTOTUNE)

return train_ds, test_ds, sample_ds
if aerial:
return None, test_ds, None
else:
return train_ds, test_ds, sample_ds

def train():
size = [1080, 1920]
Expand Down Expand Up @@ -263,7 +274,7 @@ def train():

def __evaluate(ds, eval_out_path, filenames=None):
G = get_G([1, None, None, 3])
G.load_weights(os.path.join(checkpoint_dir, 'g_epoch_540.h5'))
G.load_weights(os.path.join(checkpoint_dir, 'g.h5'))
G.eval()
sample_folders = ['lr', 'hr', 'gen', 'bicubic', 'combined']
for sample_folder in sample_folders:
Expand Down Expand Up @@ -297,12 +308,12 @@ def __evaluate(ds, eval_out_path, filenames=None):
out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
tl.vis.save_image(out_bicu, os.path.join(eval_out_path, 'bicubic', filename))
# tl.vis.save_images(np.array([valid_lr_img[0], np.array(out_bicu), out[0]]), [1,3], os.path.join(eval_out_path, 'combined', f'valid_bicu_{i}.jpg'))
def other():
def other(aerial=False):

if 1:
train_ds, test_ds, sample_ds = get_train_data()
train_ds, test_ds, sample_ds = get_train_data(aerial=True)
test_ds = test_ds.unbatch()
candidate_dir = os.path.join(save_dir, 'handPickCandidates_jpg_new')
candidate_dir = os.path.join(save_dir, 'aerial')
tl.files.exists_or_mkdir(candidate_dir)
test_ds = test_ds.take(1000)
__evaluate(test_ds, candidate_dir)
Expand Down Expand Up @@ -361,5 +372,7 @@ def evaluate():
evaluate()
elif tl.global_flag['mode'] == 'other':
other()
elif tl.global_flag['mode'] == 'aerial':
other(aerial=True)
else:
raise Exception("Unknow --mode")
28 changes: 26 additions & 2 deletions videodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def call(self):
# 'frame': vid.read()[1]
# }
img = vid.read()[1]
img = cv2.blur(img,(5,5))
# img = cv2.blur(img,(5,5))
yield img
if not self.isTest:
self.i_vid = (self.i_vid + 1) % (len(self.ind_vid))
Expand Down Expand Up @@ -94,14 +94,38 @@ def call(self):
vid.set(cv2.CAP_PROP_POS_FRAMES, self.idx_frame[i_vid][i_frame]) # set video to this frame

img = vid.read()[1]
img = cv2.blur(img,(5,5))
# img = cv2.blur(img,(5,5))
if self.isTest:
self.i_frame[i_vid] = ((self.i_frame[i_vid] + 15 + 1) % 15) - 15 # last 15 images are reserved for test
else:
self.i_frame[i_vid] = (self.i_frame[i_vid] + 1) % (len(self.i_frame) - 20) # last 15 images are reserved for test
self.i_vid =(self.i_vid + 1) % (len(self.ind_vid))
yield img

class FrameGenerator_sequential():
    """Yield the frames of one video at a time, in order.

    Each run of :meth:`call` walks the currently selected video from its
    first frame onwards. Unless ``isTest`` is set, the selection then moves
    on to the next video (wrapping around after the last one), so repeated
    runs cycle through every entry of ``videoPaths``.
    """

    def __init__(self, videoPaths, isTest=False):
        """Open a capture for every path and record its reported frame count.

        Args:
            videoPaths: sequence of video file paths to open with OpenCV.
            isTest: when true, ``call`` keeps replaying the same video
                instead of advancing to the next one after each pass.
        """
        self.videoPaths = videoPaths
        # ``call`` reads this flag, so it has to exist on every instance.
        self.isTest = isTest
        print('getting vids')
        self.videos = [cv2.VideoCapture(path) for path in videoPaths]
        # Builtin ``int`` instead of the ``np.int`` alias, which newer NumPy
        # releases no longer provide.
        self.totalFrames = np.array(
            [vid.get(cv2.CAP_PROP_FRAME_COUNT) for vid in self.videos]
        ).astype(int)
        self.i_vid = 0

    def call(self):
        """Generate the frames of the currently selected video.

        Yields:
            The image returned by ``read()`` for each successfully decoded
            frame. Nothing is yielded when no videos were given.
        """
        if not self.videos:
            return  # nothing to read from

        i_vid = self.i_vid
        vid = self.videos[i_vid]
        n_frames = self.totalFrames[i_vid]
        for i in range(n_frames):
            # Seek explicitly so that a later pass over the same capture
            # starts again from frame 0.
            vid.set(cv2.CAP_PROP_POS_FRAMES, i)
            ok, img = vid.read()
            if not ok:
                # Decoding failed (the reported frame count may overshoot);
                # stop here rather than hand ``None`` to the consumer.
                break
            yield img

        if not self.isTest:
            # Wrap on the number of videos, not on the frame count.
            self.i_vid = (self.i_vid + 1) % len(self.videos)

if __name__ == "__main__":
videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
Expand Down

0 comments on commit 9804a86

Please sign in to comment.