forked from wiibrew/pytorch-yolo2
-
Notifications
You must be signed in to change notification settings - Fork 8
/
dataset.py
92 lines (78 loc) · 3.05 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/python
# encoding: utf-8
import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from utils import read_truths_args, read_truths
from image import *
class listDataset(Dataset):
    """Map-style detection dataset driven by a text file of image paths.

    Each non-blank line of the list file is the path of one image.  The
    matching label file is located by textual substitution on that path
    (``images``/``JPEGImages`` -> ``labels``, ``.jpg``/``.png`` -> ``.txt``).

    In training mode samples are produced by ``load_data_detection`` (from
    the project's ``image`` module) with fixed augmentation parameters, and
    the target resolution is re-drawn periodically according to a schedule
    keyed on ``seen``.  In evaluation mode the image is opened with PIL,
    optionally resized to ``shape``, and the labels are returned as a flat
    tensor of ``50 * 5`` floats, zero-padded or truncated as needed.

    Args:
        root: Path of the text file listing one image path per line.
        shape: Optional size tuple handed to ``PIL.Image.resize`` in
            evaluation mode and to ``load_data_detection`` in training mode.
        shuffle: Shuffle the list once at construction time.
        transform: Optional callable applied to the image.
        target_transform: Optional callable applied to the label tensor.
        train: Select the augmented training path instead of the plain
            evaluation path.
        seen: Number of samples already processed; drives the multi-scale
            schedule.
        batch_size: Stored on the instance; not otherwise used by this class.
        num_workers: Amount added to ``seen`` per fetched sample (presumably
            because every DataLoader worker holds its own copy of the dataset
            and only sees 1/num_workers of the samples -- confirm with the
            training script).
    """

    # Evaluation labels are padded/truncated to _MAX_BOXES rows of
    # _BOX_FIELDS values (presumably class id + box coordinates per row --
    # confirm against utils.read_truths_args).
    _MAX_BOXES = 50
    _BOX_FIELDS = 5

    # A new training resolution is drawn whenever ``index`` is a multiple of
    # this value; the schedule thresholds below are expressed in the same unit.
    _SCHEDULE_BATCH = 64

    # Side lengths are multiples of this many pixels.
    _CELL = 32

    # (threshold in units of _SCHEDULE_BATCH samples, random spread, base cells).
    # While ``seen`` is below the threshold the side length is
    # (randint(0, spread) + base) * _CELL.
    _SCALE_SCHEDULE = (
        (4000, 0, 13),
        (8000, 3, 13),
        (12000, 5, 12),
        (16000, 7, 11),
    )
    # Spread/base used once every threshold above has been passed.
    _SCALE_FINAL = (9, 10)

    def __init__(self, root, shape=None, shuffle=True, transform=None, target_transform=None, train=False, seen=0, batch_size=64, num_workers=4):
        super(listDataset, self).__init__()
        with open(root, 'r') as list_file:
            # Drop blank lines (e.g. a trailing empty line) so they are not
            # counted as samples and later opened as an empty image path.
            self.lines = [line for line in list_file if line.strip()]

        if shuffle:
            random.shuffle(self.lines)

        self.nSamples = len(self.lines)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.num_workers = num_workers

    def __len__(self):
        """Return the number of image paths in the list."""
        return self.nSamples

    def __getitem__(self, index):
        """Return the ``(img, label)`` pair for the image at ``index``.

        Raises:
            IndexError: If ``index`` is past the end of the list.
        """
        # Raise IndexError (not AssertionError) so the check survives
        # ``python -O`` and plain ``for sample in dataset`` iteration stops
        # cleanly at the end of the list.
        if index >= len(self):
            raise IndexError('index range error')
        imgpath = self.lines[index].rstrip()

        if self.train:
            if index % self._SCHEDULE_BATCH == 0:
                self.shape = self._pick_train_shape()
            # Fixed augmentation parameters forwarded to load_data_detection.
            jitter = 0.2
            hue = 0.1
            saturation = 1.5
            exposure = 1.5
            img, label = load_data_detection(imgpath, self.shape, jitter, hue, saturation, exposure)
            label = torch.from_numpy(label)
        else:
            img, label = self._load_eval_sample(imgpath)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        # Advance by at least one so the schedule still progresses when the
        # data is loaded in the main process (num_workers == 0).
        self.seen = self.seen + max(self.num_workers, 1)
        return (img, label)

    def _pick_train_shape(self):
        """Return the square training size for the current ``seen`` count."""
        spread, base = self._SCALE_FINAL
        for threshold, stage_spread, stage_base in self._SCALE_SCHEDULE:
            if self.seen < threshold * self._SCHEDULE_BATCH:
                spread, base = stage_spread, stage_base
                break
        # The first stage is a fixed size; skip the RNG call there so the
        # random stream is consumed only when a size is actually sampled.
        offset = random.randint(0, spread) if spread else 0
        width = (offset + base) * self._CELL
        return (width, width)

    def _load_eval_sample(self, imgpath):
        """Load one image and its fixed-length label tensor without augmentation."""
        img = Image.open(imgpath).convert('RGB')
        if self.shape:
            img = img.resize(self.shape)

        # NOTE: plain substring replacement -- every occurrence in the path
        # is rewritten, not only the directory name or the file extension.
        labpath = imgpath.replace('images', 'labels').replace('JPEGImages', 'labels').replace('.jpg', '.txt').replace('.png', '.txt')
        label = torch.zeros(self._MAX_BOXES * self._BOX_FIELDS)
        try:
            # The second argument is presumably a minimum relative box size
            # (8 pixels over the image width) -- confirm in utils.
            truths = torch.from_numpy(read_truths_args(labpath, 8.0 / img.width).astype('float32'))
        except Exception:
            # Best effort: a missing or unreadable label file yields a single
            # all-zero row instead of aborting the whole evaluation run.
            truths = torch.zeros(1, 5)

        flat = truths.view(-1)
        count = flat.numel()
        capacity = label.numel()
        if count > capacity:
            label = flat[0:capacity]
        elif count > 0:
            label[0:count] = flat
        return img, label