-
Notifications
You must be signed in to change notification settings - Fork 250
/
charades_dataset_full.py
123 lines (96 loc) · 3.58 KB
/
charades_dataset_full.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import torch.utils.data as data_utl
from torch.utils.data.dataloader import default_collate
import numpy as np
import json
import csv
import h5py
import os
import os.path
import cv2
def video_to_tensor(pic):
    """Turn a frames-first video array into a channels-first torch tensor.

    The axes of ``pic`` are reordered from (T x H x W x C) to
    (C x T x H x W) and the result is wrapped with ``torch.from_numpy``,
    so the tensor keeps the dtype of ``pic`` and shares its memory.

    Args:
        pic (numpy.ndarray): Video laid out as (T x H x W x C).

    Returns:
        Tensor: The same video laid out as (C x T x H x W).
    """
    # Channel axis (3) moves to the front; T, H, W keep their relative order.
    axis_order = (3, 0, 1, 2)
    channels_first = pic.transpose(*axis_order)
    return torch.from_numpy(channels_first)
def load_rgb_frames(image_dir, vid, start, num):
    """Load ``num`` consecutive RGB frames of one video as a float32 array.

    Frames are read from ``<image_dir>/<vid>/<vid>-<NNNNNN>.jpg`` where
    ``NNNNNN`` is the frame index zero-padded to six digits.

    Args:
        image_dir (str): Directory containing one sub-directory per video.
        vid (str): Video id; used as sub-directory name and file-name prefix.
        start (int): Index of the first frame to load.
        num (int): Number of frames to load.

    Returns:
        numpy.ndarray: float32 array of the stacked frames (T x H x W x 3)
        with channels in RGB order and values rescaled from [0, 255] to
        [-1, 1]. Frames whose shorter side is below 226 pixels are upscaled
        (aspect ratio preserved) so that the shorter side becomes 226.

    Raises:
        IOError: If a frame file is missing or cannot be decoded.
    """
    frames = []
    for i in range(start, start + num):
        path = os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6) + '.jpg')
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None rather than raising,
        # so check explicitly to get an error that names the offending file.
        if img is None:
            raise IOError('Could not read RGB frame: {}'.format(path))
        # OpenCV decodes to BGR; reverse the channel axis to get RGB.
        img = img[:, :, [2, 1, 0]]
        height, width = img.shape[:2]
        shorter = min(height, width)
        if shorter < 226:
            # Scale factor that brings the shorter side up to 226 pixels.
            d = 226. - shorter
            sc = 1 + d / shorter
            img = cv2.resize(img, dsize=(0, 0), fx=sc, fy=sc)
        img = (img / 255.) * 2 - 1
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)
def load_flow_frames(image_dir, vid, start, num):
    """Load ``num`` consecutive optical-flow frames of one video.

    Each frame is stored as two grayscale images,
    ``<image_dir>/<vid>/<vid>-<NNNNNN>x.jpg`` and ``...y.jpg``, where
    ``NNNNNN`` is the frame index zero-padded to six digits.

    Args:
        image_dir (str): Directory containing one sub-directory per video.
        vid (str): Video id; used as sub-directory name and file-name prefix.
        start (int): Index of the first frame to load.
        num (int): Number of frames to load.

    Returns:
        numpy.ndarray: float32 array of the stacked frames (T x H x W x 2)
        with the x image in channel 0 and the y image in channel 1, values
        rescaled from [0, 255] to [-1, 1]. Frames whose shorter side is
        below 224 pixels are upscaled (aspect ratio preserved) so that the
        shorter side becomes 224.

    Raises:
        IOError: If either image of a frame is missing or cannot be decoded.
    """
    frames = []
    for i in range(start, start + num):
        stem = os.path.join(image_dir, vid, vid + '-' + str(i).zfill(6))
        components = []
        for suffix in ('x.jpg', 'y.jpg'):
            path = stem + suffix
            component = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread reports failure by returning None rather than
            # raising, so check explicitly and name the offending file.
            if component is None:
                raise IOError('Could not read flow frame: {}'.format(path))
            components.append(component)
        imgx, imgy = components

        # The scale is derived from the x image only and applied to both;
        # this assumes the x and y images of a frame have the same size.
        height, width = imgx.shape
        shorter = min(height, width)
        if shorter < 224:
            d = 224. - shorter
            sc = 1 + d / shorter
            imgx = cv2.resize(imgx, dsize=(0, 0), fx=sc, fy=sc)
            imgy = cv2.resize(imgy, dsize=(0, 0), fx=sc, fy=sc)

        imgx = (imgx / 255.) * 2 - 1
        imgy = (imgy / 255.) * 2 - 1
        # Stack as (2, H, W), then move the component axis last: (H, W, 2).
        img = np.asarray([imgx, imgy]).transpose([1, 2, 0])
        frames.append(img)
    return np.asarray(frames, dtype=np.float32)
def make_dataset(split_file, split, root, mode, num_classes=157):
    """Build the list of videos and per-frame multi-label targets for a split.

    ``split_file`` is a JSON object keyed by video id. Each entry is read
    for three fields: ``'subset'`` (compared against ``split``),
    ``'duration'`` and ``'actions'``, a list of
    ``[class_index, start_time, end_time]`` triples whose times are in the
    same unit as ``'duration'``.

    A video is skipped when it belongs to another subset or when
    ``<root>/<vid>`` does not exist. The frame count is the number of
    entries in that directory, halved for ``mode == 'flow'`` because each
    flow frame is stored as two files.

    Args:
        split_file (str): Path to the JSON split/annotation file.
        split (str): Subset name to keep.
        root (str): Directory containing one sub-directory of frames per video.
        mode (str): ``'flow'`` halves the frame count; any other value
            uses the raw file count.
        num_classes (int): Number of rows in each label matrix.

    Returns:
        list: One ``(vid, label, duration, num_frames)`` tuple per kept
        video, where ``label`` is a float32 array of shape
        ``(num_classes, num_frames)`` holding 1 where a frame's timestamp
        lies strictly inside an annotated interval of that class.
    """
    with open(split_file, 'r') as f:
        data = json.load(f)

    dataset = []
    for vid, info in data.items():
        if info['subset'] != split:
            continue
        vid_dir = os.path.join(root, vid)
        if not os.path.exists(vid_dir):
            continue

        num_frames = len(os.listdir(vid_dir))
        if mode == 'flow':
            num_frames = num_frames // 2

        duration = info['duration']
        label = np.zeros((num_classes, num_frames), np.float32)
        if num_frames > 0:
            fps = num_frames / duration
            # Timestamp of every frame, computed once per video instead of
            # once per (annotation, frame) pair.
            times = np.arange(num_frames) / fps
            for ann in info['actions']:
                # Strict inequalities: frames exactly on a boundary stay 0.
                active = (times > ann[1]) & (times < ann[2])
                if active.any():
                    label[ann[0], active] = 1  # binary classification
        dataset.append((vid, label, duration, num_frames))
    return dataset
class Charades(data_utl.Dataset):
    """Dataset that yields whole videos with per-frame multi-label targets.

    The list of videos and their label matrices is built once, at
    construction time, by ``make_dataset``. Frames are loaded lazily in
    ``__getitem__``.
    """

    def __init__(self, split_file, split, root, mode, transforms=None, save_dir='', num=0):
        """Index the videos of one split.

        Args:
            split_file (str): Path to the JSON split/annotation file.
            split (str): Subset name to keep.
            root (str): Directory containing one sub-directory of frames
                per video.
            mode (str): ``'rgb'`` loads RGB frames; any other value loads
                optical-flow frames.
            transforms (callable, optional): Applied to the loaded frame
                array before it is converted to a tensor. ``None`` means
                no transform.
            save_dir (str): Directory checked for ``<vid>.npy``; videos
                that already have such a file are not loaded.
            num (int): Unused; kept so existing callers keep working.
        """
        self.data = make_dataset(split_file, split, root, mode)
        self.split_file = split_file
        self.transforms = transforms
        self.mode = mode
        self.root = root
        self.save_dir = save_dir

    def __getitem__(self, index):
        """Load one video.

        Args:
            index (int): Index into the list built at construction time.

        Returns:
            tuple: ``(video, label, vid)`` where ``video`` is the frame
            tensor produced by ``video_to_tensor``, ``label`` is the
            per-frame label matrix as a tensor, and ``vid`` is the video
            id. If ``<save_dir>/<vid>.npy`` already exists, the
            placeholder ``(0, 0, vid)`` is returned instead and no frames
            are read.
        """
        vid, label, _, nf = self.data[index]

        # Presumably <vid>.npy is written by the caller once the video has
        # been processed; skip the expensive frame loading in that case.
        if os.path.exists(os.path.join(self.save_dir, vid + '.npy')):
            return 0, 0, vid

        # Frame files are numbered from 1, hence start=1.
        if self.mode == 'rgb':
            imgs = load_rgb_frames(self.root, vid, 1, nf)
        else:
            imgs = load_flow_frames(self.root, vid, 1, nf)

        # transforms defaults to None, so only call it when one was given.
        if self.transforms is not None:
            imgs = self.transforms(imgs)

        return video_to_tensor(imgs), torch.from_numpy(label), vid

    def __len__(self):
        """Return the number of videos in the split."""
        return len(self.data)