data_loader.py
import random
import numpy as np
from tqdm import tqdm_notebook
from collections import defaultdict
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer

from create_dataset import MOSI, MOSEI, PAD, UNK, UR_FUNNY

# The BERT tokenizer is loaded from a local copy of bert-base-uncased; point this at your own path or at 'bert-base-uncased'.
bert_tokenizer = BertTokenizer.from_pretrained('/home/cjl/code/cjl/MMIM/bert-base-uncased/', do_lower_case=True)
class MSADataset(Dataset):
    def __init__(self, config):
        self.config = config

        ## Fetch dataset
        if "mosi" in str(config.data_dir).lower():
            dataset = MOSI(config)
        elif "mosei" in str(config.data_dir).lower():
            dataset = MOSEI(config)
        elif "ur_funny" in str(config.data_dir).lower():
            dataset = UR_FUNNY(config)
        else:
            print("Dataset not defined correctly")
            exit()

        self.data, self.word2id, _ = dataset.get_data(config.mode)
        self.len = len(self.data)

        config.word2id = self.word2id
        # config.pretrained_emb = self.pretrained_emb

    @property
    def tva_dim(self):
        # (text, visual, acoustic) feature dimensions; text is fixed to BERT's 768
        t_dim = 768
        return t_dim, self.data[0][0][1].shape[1], self.data[0][0][2].shape[1]

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
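
# Each item returned by MSADataset (as indexed in collate_fn below; the exact layout is
# produced by create_dataset.py) looks roughly like:
#   sample[0]: (word_ids, visual_feats, acoustic_feats, raw_words[, visual_len, acoustic_len])
#              the last two entries are present only for unaligned data
#   sample[1]: label array (for MOSEI, 7 columns whose first column is the sentiment score)
#   sample[2]: segment id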
def get_loader(hp, config, shuffle=True):
    """Load DataLoader of given DialogDataset"""

    dataset = MSADataset(config)
    print(config.mode)
    config.data_len = len(dataset)
    config.tva_dim = dataset.tva_dim

    if config.mode == 'train':
        hp.n_train = len(dataset)
    elif config.mode == 'valid':
        hp.n_valid = len(dataset)
    elif config.mode == 'test':
        hp.n_test = len(dataset)

    def collate_fn(batch):
        '''
        Collate functions assume batch = [Dataset[i] for i in index_set]
        '''
        # for later use we sort the batch in descending order of text length
        batch = sorted(batch, key=lambda x: len(x[0][3]), reverse=True)

        v_lens = []
        a_lens = []
        labels = []
        ids = []

        for sample in batch:
            if len(sample[0]) > 4:  # unaligned case
                v_lens.append(torch.IntTensor([sample[0][4]]))
                a_lens.append(torch.IntTensor([sample[0][5]]))
            else:  # aligned case
                v_lens.append(torch.IntTensor([len(sample[0][3])]))
                a_lens.append(torch.IntTensor([len(sample[0][3])]))
            labels.append(torch.from_numpy(sample[1]))
            ids.append(sample[2])

        vlens = torch.cat(v_lens)
        alens = torch.cat(a_lens)
        labels = torch.cat(labels, dim=0)

        # MOSEI sentiment labels are located in the first column of the sentiment matrix
        if labels.size(1) == 7:
            labels = labels[:, 0][:, None]
        # Local padding helper (original note: "Rewrite this"). Unlike torch's pad_sequence it
        # takes a target_len argument, but as written target_len does not change the padded
        # length; max_len below is always taken from the sequences themselves.
        def pad_sequence(sequences, target_len=-1, batch_first=False, padding_value=0.0):
            if target_len < 0:
                max_size = sequences[0].size()
                trailing_dims = max_size[1:]
            else:
                max_size = target_len
                trailing_dims = sequences[0].size()[1:]
            max_len = max([s.size(0) for s in sequences])
            if batch_first:
                out_dims = (len(sequences), max_len) + trailing_dims
            else:
                out_dims = (max_len, len(sequences)) + trailing_dims
            out_tensor = sequences[0].new_full(out_dims, padding_value)
            for i, tensor in enumerate(sequences):
                length = tensor.size(0)
                # use index notation to prevent duplicate references to the tensor
                if batch_first:
                    out_tensor[i, :length, ...] = tensor
                else:
                    out_tensor[:length, i, ...] = tensor
            return out_tensor
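
        # Illustrative example (not in the original): pad_sequence([torch.ones(2, 3), torch.ones(4, 3)])
        # returns a (4, 2, 3) time-major tensor with the shorter sequence zero-padded at the end.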
        # masks mark padded positions with 1 and real frames with 0
        v_masks = pad_sequence([torch.zeros(torch.FloatTensor(sample[0][1]).size(0)) for sample in batch], target_len=vlens.max().item(), padding_value=1)
        a_masks = pad_sequence([torch.zeros(torch.FloatTensor(sample[0][2]).size(0)) for sample in batch], target_len=alens.max().item(), padding_value=1)  # pad to the acoustic max length, not the visual one
        sentences = pad_sequence([torch.LongTensor(sample[0][0]) for sample in batch], padding_value=PAD)
        visual = pad_sequence([torch.FloatTensor(sample[0][1]) for sample in batch], target_len=vlens.max().item())
        acoustic = pad_sequence([torch.FloatTensor(sample[0][2]) for sample in batch], target_len=alens.max().item())
        ## BERT-based features input prep
        SENT_LEN = min(sentences.size(0), 50)

        # Create bert indices using tokenizer
        bert_details = []
        for sample in batch:
            text = " ".join(sample[0][3])
            encoded_bert_sent = bert_tokenizer.encode_plus(
                text, max_length=SENT_LEN, add_special_tokens=True, truncation=True, padding='max_length')
            bert_details.append(encoded_bert_sent)

        # Bert things are batch_first
        bert_sentences = torch.LongTensor([sample["input_ids"] for sample in bert_details])
        bert_sentence_types = torch.LongTensor([sample["token_type_ids"] for sample in bert_details])
        bert_sentence_att_mask = torch.LongTensor([sample["attention_mask"] for sample in bert_details])

        # lengths are useful later in using RNNs
        lengths = torch.LongTensor([len(sample[0][0]) for sample in batch])
        # replace zero visual lengths with 1 to avoid zero-length sequences downstream
        if (vlens <= 0).sum() > 0:
            vlens[np.where(vlens == 0)] = 1

        return sentences, visual, vlens, acoustic, alens, labels, lengths, bert_sentences, bert_sentence_types, bert_sentence_att_mask, ids, v_masks, a_masks
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=config.batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn)

    return data_loader
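

# Minimal usage sketch (not part of the original module). It assumes `hp` and `config`
# objects exposing the attributes read above (data_dir, mode, batch_size, ...), which are
# normally supplied by the project's config; the SimpleNamespace stand-ins and the
# data_dir path below are illustrative placeholders only.
if __name__ == "__main__":
    from types import SimpleNamespace

    hp = SimpleNamespace()
    config = SimpleNamespace(data_dir="datasets/MOSI", mode="train", batch_size=32)

    train_loader = get_loader(hp, config, shuffle=True)
    for batch in train_loader:
        (sentences, visual, vlens, acoustic, alens, labels, lengths,
         bert_sent, bert_sent_type, bert_sent_mask, ids, v_masks, a_masks) = batch
        print(sentences.shape, visual.shape, acoustic.shape, labels.shape)
        break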