convert2dadgnn.py
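
"""Convert datasets loaded via data.load_data into the plain-text layout used by
DADGNN under DADGNN/content/data/<name>/: a label list (label.txt), tab-separated
"label<TAB>text" train/test files, and a stemmed, frequency-filtered vocabulary
(<name>-vocab.txt) whose first entry is the UNK token."""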
import os
from collections import Counter

from nltk.stem.porter import PorterStemmer
from tqdm import tqdm

from data import load_data


def convert_data(name):
    dataset = load_data(name)
    train_texts, train_labels = dataset['train']
    test_texts, test_labels = dataset['test']
    label_dict = dataset['label_dict']

    path_to_data = 'DADGNN/content/data'
    # make sure the per-dataset output directory exists before writing
    os.makedirs(f'{path_to_data}/{name}', exist_ok=True)

    # label.txt: one label per line, no trailing newline after the last
    with open(f'{path_to_data}/{name}/label.txt', 'w', encoding="utf-8") as f:
        for label in list(label_dict.keys())[:-1]:
            f.write(str(label) + '\n')
        f.write(str(list(label_dict.keys())[-1]))

    # flatten texts to single lines; replacing '\n' and '\r' separately
    # also covers '\r\n' sequences, so no extra pass is needed
    train_texts = [text.replace('\n', ' ').replace('\r', ' ') for text in train_texts]
    test_texts = [text.replace('\n', ' ').replace('\r', ' ') for text in test_texts]

    # train.txt
    with open(f'{path_to_data}/{name}/{name}-train.txt', 'w', encoding="utf-8") as f:
        for text, label in zip(train_texts[:-1], train_labels[:-1]):
            f.write(f'{label}\t{text}\n')
        f.write(f'{train_labels[-1]}\t{train_texts[-1]}')

    # test.txt
    with open(f'{path_to_data}/{name}/{name}-test.txt', 'w', encoding="utf-8") as f:
        for text, label in zip(test_texts[:-1], test_labels[:-1]):
            f.write(f'{label}\t{text}\n')
        f.write(f'{test_labels[-1]}\t{test_texts[-1]}')

    # region vocab.txt
    print('building vocab')
    words = []
    for text in train_texts + test_texts:
        words += text.split()

    # stem every token with the Porter stemmer
    stemmer = PorterStemmer()
    words = [stemmer.stem(word) for word in tqdm(words)]
    print(words[:10])

    # keep only words that occur at least 5 times; sort for a deterministic vocab
    counter = Counter(words)
    words = sorted({word for word in words if counter[word] >= 5})
    print(f'{len(words)} words found')

    with open(f'{path_to_data}/{name}/{name}-vocab.txt', 'w', encoding="utf-8") as f:
        # the UNK token comes first, then one stemmed word per line
        f.write('UNK\n')
        for word in words[:-1]:
            f.write(word + '\n')
        f.write(words[-1])
    # endregion


if __name__ == '__main__':
    datasets = ['MR', 'TREC', 'SST2', 'R8', 'Twitter', 'SearchSnippets', 'NICE', 'NICE2', 'STOPS', 'STOPS2']
    for dataset in datasets:
        convert_data(dataset)
        print(f'converted {dataset}.')