Skip to content

Commit

Permalink
epochs weird
Browse files Browse the repository at this point in the history
  • Loading branch information
Muhsin Fatih Yorulmaz committed Dec 3, 2019
1 parent cb962e2 commit 33515ec
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 24 deletions.
8 changes: 0 additions & 8 deletions config.py

This file was deleted.

2 changes: 1 addition & 1 deletion srgan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_train_data():
# yield img

videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
generator = videodataset.FrameGenerator(videoPaths)
generator = videodataset.FrameGenerator(videoPaths, iteration_size=12)



Expand Down
30 changes: 15 additions & 15 deletions videodataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,25 +21,25 @@ def _preprocess_image(feature):


class FrameGenerator():
def __init__(self, videoPaths):
def __init__(self, videoPaths, iteration_size):
self.videoPaths = videoPaths
self.videos = [cv2.VideoCapture(path) for path in videoPaths]
self.totalFrames = np.array([vid.get(cv2.CAP_PROP_FRAME_COUNT) for vid in self.videos]).astype(np.int)

self.iteration_size = iteration_size
def call(self):
while True:
# let's just use the first video for testing:
i_vid = 0
vid = self.videos[i_vid]
idx_frame = np.random.choice(self.totalFrames[i_vid], self.totalFrames[i_vid], replace=True)
for i_frame in idx_frame:
vid.set(cv2.CAP_PROP_POS_FRAMES, i_frame) # set video to this frame
# yield {
# 'video_index': i_vid,
# 'video_path': self.videoPaths[i_vid],
# 'frame': vid.read()[1]
# }
yield vid.read()[1]
# while True:
# let's just use the first video for testing:
i_vid = 0
vid = self.videos[i_vid]
idx_frame = np.random.choice(self.totalFrames[i_vid], self.totalFrames[i_vid], replace=True)
for i in range(self.iteration_size):
vid.set(cv2.CAP_PROP_POS_FRAMES, idx_frame[i]) # set video to this frame
# yield {
# 'video_index': i_vid,
# 'video_path': self.videoPaths[i_vid],
# 'frame': vid.read()[1]
# }
yield vid.read()[1]

if __name__ == "__main__":
videoPaths = np.array(glob2.glob(virat.ground.video.dir + '/*.mp4'))
Expand Down

0 comments on commit 33515ec

Please sign in to comment.