-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDataSet.py
94 lines (71 loc) · 2.57 KB
/
DataSet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from os.path import join
from PIL import Image
from torchvision.transforms import transforms
from torch.nn.utils.rnn import pad_sequence
import consts
from transformers import BertModel, BertConfig, BertTokenizer, AutoTokenizer
from transformers import DistilBertTokenizer, DistilBertModel
# Image preprocessing pipeline applied to every sample before training.
_resize = transforms.Resize(consts.IMG_SIZE)        # scale image to IMG_SIZE (short side if int)
_center_crop = transforms.CenterCrop(consts.IMG_SIZE)  # cut the central IMG_SIZE square
_to_tensor = transforms.ToTensor()                  # PIL -> float tensor in [0, 1] (divides by 255)
img_transform = transforms.Compose([_resize, _center_crop, _to_tensor])
# Pre-load every image tensor once at import time so DataSet.__getitem__
# can look tensors up by guid.  img_arr[0] is a placeholder so the list is
# indexable directly by guid (guids start at 1).
img_arr = [None]
for i in range(1, 5130):  # guids appear to run 1..5129 -- TODO confirm upper bound
    img_path = join(consts.DATA_FOLD, str(i) + ".jpg")
    try:
        # FIX: use the context manager so the underlying file handle is
        # closed.  PIL opens files lazily; the original leaked one open
        # handle per image.  ToTensor inside the `with` forces the pixel
        # data to load while the file is still open.
        with Image.open(img_path) as img:
            img_arr.append(img_transform(img))
    except IOError:
        # Missing/unreadable image: append None to keep guid alignment.
        img_arr.append(None)
class DataSet(torch.utils.data.Dataset):
    """Multimodal (image + text) dataset.

    Reads (guid, tag) pairs from train.txt or test_without_label.txt and
    lazily builds/caches per-sample items; image tensors come from the
    module-level ``img_arr`` pre-load, indexed by guid.
    """

    def __init__(self, train=True):
        """Load the guid/tag index from disk.

        Args:
            train: if True read train.txt, otherwise test_without_label.txt
                (whose tag column is a placeholder).
        """
        super(DataSet, self).__init__()
        self.train = train
        self.dataset = []
        path = join(consts.FOLD, "train.txt" if train else "test_without_label.txt")
        with open(path, "r") as f:
            lines = f.readlines()
        # First line is the CSV header ("guid,tag"); skip it.
        for line in lines[1:]:
            guid, tag = line.strip().split(",")
            self.dataset.append((int(guid), tag))
        # Cache of already-built items, keyed by dataset index.
        self.item = dict()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return a dict with keys "img", "txt", "tag", "guid" for sample idx."""
        if idx not in self.item:
            guid, tag = self.dataset[idx]
            txt_path = join(consts.DATA_FOLD, str(guid) + ".txt")
            with open(txt_path, "rb") as f:
                # BUG FIX: the original used str(f.read()), which embeds the
                # bytes repr ("b'...'" wrapper plus \xNN escapes) into the
                # text handed to the tokenizer.  Decode instead;
                # errors="ignore" because some raw files may not be valid
                # UTF-8 (presumably why the file is opened in binary mode).
                txt = f.read().decode("utf-8", errors="ignore")
            self.item[idx] = {
                "img": img_arr[guid],  # None if the image failed to load
                "txt": txt,
                "tag": tag,
                "guid": guid,
            }
        return self.item[idx]
# Shared tokenizer used by collate_fn; consts.MODEL_NAME selects the
# pretrained vocabulary (downloaded/cached by Hugging Face).
tokenizer = AutoTokenizer.from_pretrained(consts.MODEL_NAME)
def collate_fn(batch):
    """Collate dataset items into batched tensors for the DataLoader.

    Args:
        batch: list of dicts with keys "img", "txt", "tag" (as produced by
            DataSet.__getitem__).

    Returns:
        (imgs, txts, tags): padded image tensor, padded token-id tensor,
        and a long tensor of integer labels (via consts.tag2int).
    """
    imgs = []
    txts = []
    tags = []
    for item in batch:
        imgs.append(item["img"])
        txt = item["txt"]
        # Character-level pre-truncation for very long texts: keep the head
        # and the tail.  BUG FIX: the original kept 128 + 392 = 520 chars,
        # exceeding its own 510 threshold and — for texts of length 511-519
        # — overlapping the two slices so middle text was duplicated.
        # Keep 128 + 382 = 510 so the slices never overlap.
        if len(txt) > 510:
            txt = txt[0:128] + txt[-382:]
        txts.append(txt)
        tags.append(consts.tag2int[item["tag"]])
    imgs = pad_sequence(imgs, batch_first=True)
    # truncation/max_length guard against texts whose character count still
    # tokenizes to more than the model's 512-token limit.
    txts = torch.tensor(
        tokenizer(txts, padding=True, truncation=True, max_length=512).input_ids
    )
    tags = torch.tensor(tags).long()
    return imgs, txts, tags