data.py
import os
import numpy as np
from PIL import Image
import cPickle
from nltk.tokenize import RegexpTokenizer
import torch
import torch.utils.data as data
from torch.utils.serialization import load_lua
import torchvision.transforms as transforms


def split_sentence_into_words(sentence):
    # Lowercase the sentence and keep only \w+ tokens, dropping punctuation,
    # e.g. 'A small red bird.' -> ['a', 'small', 'red', 'bird'].
    tokenizer = RegexpTokenizer(r'\w+')
    return tokenizer.tokenize(sentence.lower())


class ReedICML2016(data.Dataset):
    def __init__(self, img_root, caption_root, classes_filename,
                 word_embedding, max_word_length, img_transform=None):
        super(ReedICML2016, self).__init__()
        # Captions are stored as 1-based indices into this character set
        # (the encoding used by Reed et al., ICML 2016).
        self.alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{} "
        self.max_word_length = max_word_length
        self.img_transform = img_transform
        if self.img_transform is None:
            self.img_transform = transforms.ToTensor()
        self.data = self._load_dataset(img_root, caption_root, classes_filename, word_embedding)

    def _load_dataset(self, img_root, caption_root, classes_filename, word_embedding):
        output = []
        with open(os.path.join(caption_root, classes_filename)) as f:
            lines = f.readlines()
        for line in lines:
            cls = line.replace('\n', '')
            filenames = os.listdir(os.path.join(caption_root, cls))
            for filename in filenames:
                # FashionGAN captions are pickled; the other datasets ship
                # as Torch7 files read with load_lua.
                if caption_root == '/disk2/datasets/FashionGAN_txt':
                    with open(os.path.join(caption_root, cls, filename), 'rb') as g:
                        datum = cPickle.load(g)
                else:
                    datum = load_lua(os.path.join(caption_root, cls, filename))
                raw_desc = datum['char'].numpy()
                desc, len_desc = self._get_word_vectors(raw_desc, word_embedding)
                output.append({
                    'img': os.path.join(img_root, datum['img']),
                    'desc': desc,
                    'len_desc': len_desc
                })
        return output

    def _get_word_vectors(self, desc, word_embedding):
        output = []
        len_desc = []
        for i in range(desc.shape[1]):
            # Each column of `desc` encodes one caption as character indices.
            words = self._nums2chars(desc[:, i])
            words = split_sentence_into_words(words)
            word_vecs = torch.FloatTensor(
                [word_embedding.get_word_vector(w).astype(np.float64) for w in words])
            # zero-pad every caption to max_word_length rows
            if len(words) < self.max_word_length:
                word_vecs = torch.cat((
                    word_vecs,
                    torch.zeros(self.max_word_length - len(words), word_vecs.size(1))
                ))
            output.append(word_vecs)
            len_desc.append(len(words))
        return torch.stack(output), len_desc

    def _nums2chars(self, nums):
        # Indices are 1-based, hence the -1 offset into the alphabet.
        chars = ''
        for num in nums:
            chars += self.alphabet[num - 1]
        return chars

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        datum = self.data[index]
        img = Image.open(datum['img'])
        img = self.img_transform(img)
        # replicate grayscale images across 3 channels
        if img.size(0) == 1:
            img = img.repeat(3, 1, 1)
        desc = datum['desc']
        len_desc = datum['len_desc']
        # randomly select one sentence among this image's captions
        selected = np.random.choice(desc.size(0))
        desc = desc[selected, ...]
        len_desc = len_desc[selected]
        return img, desc, len_desc
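

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The paths, the class
# list filename, and the `embedding` object are assumptions: any object with a
# get_word_vector(word) -> numpy array method (e.g. a loaded fastText model)
# matches the interface used in _get_word_vectors above.
if __name__ == '__main__':
    import fasttext  # assumed embedding backend

    embedding = fasttext.load_model('wiki.en.bin')  # hypothetical model path
    dataset = ReedICML2016(img_root='images/',
                           caption_root='cub_icml/',
                           classes_filename='trainclasses.txt',
                           word_embedding=embedding,
                           max_word_length=50,
                           img_transform=transforms.Compose([
                               transforms.Scale(74),  # Resize in newer torchvision
                               transforms.CenterCrop(64),
                               transforms.ToTensor(),
                           ]))
    loader = data.DataLoader(dataset, batch_size=16, shuffle=True)
    # One batch: img (B, 3, 64, 64), desc (B, max_word_length, embed_dim),
    # len_desc (B,) holding true caption lengths before zero padding.
    img, desc, len_desc = next(iter(loader))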