-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
132 lines (114 loc) · 4.19 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from torch.utils.data import DataLoader, Dataset
import os
import torch
import torchvision
from torchvision import transforms
import numbers
import numpy as np
import mxnet as mx
import threading
import sys
import queue as Queue
class BackgroundGenerator(threading.Thread):
    """Drain an iterable in a daemon thread and hand its items to the consumer.

    Items are buffered in a bounded queue so the producer can run ahead of the
    consumer by at most ``max_prefetch`` items. The object is itself an
    iterator: ``next()`` blocks until the worker has produced an item.

    An exception raised in the worker thread is re-raised in the consumer on
    the ``next()`` call that reaches the end of the buffered items; after that
    (or after normal exhaustion) every further ``next()`` raises
    ``StopIteration`` without blocking.
    """

    # Dedicated end-of-stream marker so the wrapped iterable may yield ``None``.
    _END = object()

    def __init__(self, generator, local_rank, max_prefetch=6):
        """Start the worker thread immediately.

        Args:
            generator: iterable drained by the worker thread.
            local_rank: CUDA device index made current inside the worker
                thread (skipped when CUDA is unavailable).
            max_prefetch: capacity of the internal queue.
        """
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        # Both flags must exist before start(): run() may finish at once.
        self._worker_error = None
        self._exhausted = False
        self.daemon = True
        self.start()

    def run(self):
        """Worker body: push every item, then always push the end marker."""
        try:
            # The current CUDA device is per-thread state, so it has to be set
            # here rather than in the constructor.
            if torch.cuda.is_available():
                torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            # Stored before the end marker is queued, so the consumer sees it
            # as soon as it dequeues the marker.
            self._worker_error = exc
        finally:
            # Without this the consumer would wait forever on queue.get().
            self.queue.put(self._END)

    def next(self):
        """Return the next buffered item.

        Raises:
            StopIteration: when the wrapped iterable is exhausted.
            Exception: whatever the worker thread raised while producing.
        """
        if self._exhausted:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._exhausted = True
            if self._worker_error is not None:
                error, self._worker_error = self._worker_error, None
                raise error
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
class DataLoaderX(DataLoader):
    """DataLoader that prefetches one batch ahead onto a CUDA device.

    Batches are produced by a ``BackgroundGenerator`` thread and copied to
    ``local_rank`` on a dedicated CUDA stream while the caller is still
    working on the previous batch.
    """

    def __init__(self, local_rank, **kwargs):
        """
        Args:
            local_rank: CUDA device index batches are copied to.
            **kwargs: forwarded unchanged to ``DataLoader.__init__``.
        """
        super(DataLoaderX, self).__init__(**kwargs)
        # Side stream used for the host-to-device copies.
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        """Start a new pass and stage its first batch; returns ``self``."""
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next batch and launch its transfer on the side stream.

        Sets ``self.batch`` to ``None`` when the underlying iterator is
        exhausted. A tensor batch stays a tensor; any other iterable batch
        (list or tuple) becomes a new list of transferred elements, so an
        immutable batch container no longer breaks the transfer.
        """
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            if isinstance(self.batch, torch.Tensor):
                self.batch = self.batch.to(
                    device=self.local_rank, non_blocking=True)
            else:
                self.batch = [
                    item.to(device=self.local_rank, non_blocking=True)
                    for item in self.batch
                ]

    def __next__(self):
        """Return the staged batch and begin staging the following one.

        Raises:
            StopIteration: when no batch is staged.
        """
        # The copy was issued on self.stream; make the caller's stream wait
        # for it before the tensors are handed out.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        # NOTE(review): the batch tensors are allocated on self.stream but
        # consumed on the current stream; consider Tensor.record_stream() on
        # them here - confirm whether it is needed for this workload.
        self.preload()
        return batch
# Default sample transform: ToTensor followed by a per-channel Normalize with
# mean 0.5 and std 0.5 on each of the three channels.
TFS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
class MXFaceDataset(Dataset):
    """Dataset backed by an MXNet indexed RecordIO pair in ``root_dir``.

    ``train.rec`` / ``train.idx`` are opened once in the constructor. Record 0
    is read as a header: when its ``flag`` is positive, sample indices run
    from 1 up to (but excluding) ``int(header.label[0])``; otherwise every key
    of the record file is used as a sample index.
    """

    def __init__(self, root_dir, local_rank, transform=TFS):
        """
        Args:
            root_dir: directory containing ``train.rec`` and ``train.idx``.
            local_rank: stored on the instance; not used by this class itself.
            transform: callable applied to each decoded image, or ``None``.
        """
        super(MXFaceDataset, self).__init__()
        self.transform = transform
        self.root_dir = root_dir
        self.local_rank = local_rank

        rec_path = os.path.join(root_dir, 'train.rec')
        idx_path = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(idx_path, rec_path, 'r')

        meta, _ = mx.recordio.unpack(self.imgrec.read_idx(0))
        if meta.flag > 0:
            print('header0 label', meta.label)
            first, second = int(meta.label[0]), int(meta.label[1])
            self.header0 = (first, second)
            self.imgidx = np.array(range(1, first))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))
        print("Number of Samples:{}". format(len(self.imgidx)))

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the ``index``-th entry of ``imgidx``.

        The label is converted to a ``torch.long`` tensor; when the stored
        label is not a plain number its first element is used.
        """
        record = self.imgrec.read_idx(self.imgidx[index])
        meta, encoded = mx.recordio.unpack(record)

        raw_label = meta.label
        if isinstance(raw_label, numbers.Number):
            class_id = raw_label
        else:
            class_id = raw_label[0]
        label = torch.tensor(class_id, dtype=torch.long)

        sample = mx.image.imdecode(encoded).asnumpy()
        if self.transform is None:
            return sample, label
        return self.transform(sample), label

    def __len__(self):
        """Number of sample indices collected in the constructor."""
        return len(self.imgidx)
if __name__ == "__main__":
    # Manual smoke test: iterate the dataset once through DataLoaderX and
    # print every batch together with the device its labels end up on.
    #
    # Dataset roots used during development:
    #   /root/xy/face/faces_emore
    #   /root/face_datasets/webface/
    #
    # One device index drives the dataset, the loader and the current CUDA
    # device, so batches staged by the loader are already where .cuda()
    # expects them and no cross-device copy takes place.
    local_rank = 0
    trainset = MXFaceDataset(root_dir='/root/xy/face/faces_emore',
                             local_rank=local_rank)
    torch.cuda.set_device(local_rank)
    trainloader = DataLoaderX(
        local_rank=local_rank,
        dataset=trainset,
        batch_size=128,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    )
    print(len(trainset))
    for step, (img, label) in enumerate(trainloader):
        img = img.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)
        print(img, label, label.device)
        # Drop the references so the batch memory can be released before the
        # next iteration.
        del img, label