-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets.py
93 lines (68 loc) · 2.45 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
from torch.utils.data import Dataset
import pickle
import torch
from collections import deque
import os
class ImitationData(Dataset):
    """Class-balanced dataset of (state, action) pairs for imitation learning.

    Expects ``frames_path`` to contain one subdirectory per action id
    (named after the id, e.g. ``<frames_path>0/``), each holding pickled
    state arrays.  Every action contributes the same number of samples:
    ``min(files per action)``, drawn without replacement.
    """

    def __init__(self, frames_path):
        """
        Args:
            frames_path: root directory with one subfolder per action id.
                Must end with a path separator — subfolder paths are built
                by plain string concatenation (kept for compatibility).
        """
        self.frames_path = frames_path
        # Action ids present on disk.  NOTE(review): presumably Atari/gym
        # action ids (0=NOOP, 2/3=movement) — confirm against the recorder.
        self.actions = [0, 2, 3]
        self.action_path = dict()
        self.action_files = dict()
        self.action_idx = dict()
        # Removed two dead `if a == 1: continue` guards from the original:
        # 1 is never in self.actions, so the branches were unreachable.
        for a in self.actions:
            self.action_path[a] = self.frames_path + str(a) + '/'
            self.action_files[a] = [self.action_path[a] + x
                                    for x in os.listdir(self.action_path[a])]
            self.action_idx[a] = len(self.action_files[a])
        # Balance classes: every action contributes the same sample count.
        self.sample_size = min(self.action_idx.values())
        self.datapaths = []
        self.data_actions = []
        for a in self.actions:
            samples = list(np.random.choice(
                self.action_files[a], size=self.sample_size, replace=False))
            self.datapaths.extend(samples)
            self.data_actions.extend([a] * len(samples))
        self.data_actions = np.array(self.data_actions)
        assert len(self.datapaths) == len(self.data_actions)

    def __len__(self):
        """Number of (state, action) samples."""
        return len(self.datapaths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(state, action)``.

        SECURITY NOTE: ``pickle.load`` executes arbitrary code during
        deserialization — only use with trusted, locally generated files.
        """
        with open(self.datapaths[idx], 'rb') as f:
            state = pickle.load(f)
        # States were saved with a leading batch dimension; drop it.
        state = state.squeeze(0)
        return state, self.data_actions[idx]
class ChannelFirst:
    """Callable transform: HWC numpy frame -> channel-leading torch tensor.

    An input of shape (H, W, C) — e.g. (210, 160, 3) — comes out with
    shape (C, W, H); values are untouched, only axes are reordered.
    """

    def __call__(self, frame):
        # Wrap the numpy array, then reorder axes (H, W, C) -> (C, W, H).
        # permute(2, 1, 0) is equivalent to the original transpose(0, 2):
        # output[c, w, h] == frame[h, w, c].
        tensor = torch.from_numpy(frame)
        return tensor.permute(2, 1, 0).contiguous()
class FrameStack:
    """Rolling stack of the most recent ``nframes`` frames.

    Calling the instance pushes a frame and returns every stacked frame
    concatenated along dim 0, with a leading batch dimension.  The first
    call after construction (or ``reset``) replicates the frame
    ``nframes`` times so the output shape is constant from the start.
    """

    def __init__(self, nframes):
        """
        Args:
            nframes: number of most recent frames to retain.
        """
        self.nframes = nframes
        # maxlen makes the deque drop the oldest frame automatically.
        self.frames = deque([], maxlen=self.nframes)

    def _getframes(self):
        """Concatenate the current stack and prepend a batch dimension."""
        assert len(self.frames) == self.nframes
        stacked = torch.cat(tuple(self.frames), dim=0)
        return stacked.unsqueeze(0)  # leading batch dimension

    def __call__(self, frame):
        # Kept as .type(torch.FloatTensor): converts to float32 AND pins
        # the result to CPU, which .float() would not do for CUDA inputs.
        frame = frame.type(torch.FloatTensor)
        # Empty stack (fresh episode): seed with nframes identical copies.
        repeats = self.nframes if not self.frames else 1
        for _ in range(repeats):
            self.frames.append(frame)
        return self._getframes()

    def reset(self):
        """Empty the stack (call at episode boundaries)."""
        self.frames.clear()