"""
Trains the VAE model and saves it inside models for later use
"""
import string
from collections import Counter

import numpy as np
import pandas as pd
import nltk
# nltk.download('punkt')  # word_tokenize needs the 'punkt' tokenizer data (download once)
from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split

from lstm_vae import create_lstm_vae, inference
from pre_processing import preProcessing


def get_text_data(num_samples, data_path, dataset):
    # The 1000 most common English words, one per line
    thousandwords = [line.rstrip('\n') for line in open('data/1-1000.txt')]
    print('thousandwords', thousandwords)

    # Vectorize the data
    input_texts = []
    input_texts_original = []
    input_words = set(["\t"])
    all_input_words = []
    lines = []

    df = pd.read_csv(data_path, encoding='utf-8')
    if dataset == "polarity":
        X = df['tweet'].values
        y = df['class'].values
    elif dataset == "hate":
        # Remove the offensive comments, keeping only neutral and hate speech,
        # and convert the class value from 2 to 1 for simplicity
        df = df[df['class'] != 1]
        X = df['tweet'].values
        y = df['class'].apply(lambda x: 1 if x == 2 else 0).values

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42, stratify=y, test_size=0.25)
    new_X_test = preProcessing(X_test)

    for line in new_X_test:
        input_texts_original.append(line)
        # Lowercase and remove punctuation
        lines.append(line.lower().translate(str.maketrans('', '', string.punctuation)))
    print(lines)

    for line in lines[: min(num_samples, len(lines) - 1)]:
        input_text = word_tokenize(line)
        input_text.append("<end>")
        input_texts.append(input_text)
        for word in input_text:
            if word not in input_words:
                input_words.add(word)
        for word in input_text:  # Used to count the words and keep the most frequent ones
            all_input_words.append(word)

    words_to_keep = 4999
    # Keep the 4,999 most frequent words of the corpus
    most_common_words = [word for word, word_count in Counter(all_input_words).most_common(words_to_keep)]
    most_common_words.append('\t')
    for word in thousandwords:  # Add the 1000 most common English words
        most_common_words.append(word)
    print(most_common_words)

    input_texts_cleaned = [[word for word in text if word in most_common_words] for text in input_texts]
    final_input_words = sorted(list(set(most_common_words)))
    num_encoder_tokens = len(final_input_words)
    max_encoder_seq_length = max([len(txt) for txt in input_texts_cleaned]) + 1

    print("input_texts_cleaned", input_texts_cleaned)
    print(most_common_words)
    print(final_input_words)
    print("Number of samples:", len(input_texts_cleaned))
    print("Number of unique input tokens:", num_encoder_tokens)
    print("Max sequence length for inputs:", max_encoder_seq_length)

    input_token_index = dict([(char, i) for i, char in enumerate(final_input_words)])
    reverse_input_char_index = dict((i, char) for char, i in input_token_index.items())

    encoder_input_data = np.zeros((len(input_texts_cleaned), max_encoder_seq_length, num_encoder_tokens),
                                  dtype="float32")
    decoder_input_data = np.zeros((len(input_texts_cleaned), max_encoder_seq_length, num_encoder_tokens),
                                  dtype="float32")
    for i, input_text_cleaned in enumerate(input_texts_cleaned):
        decoder_input_data[i, 0, input_token_index["\t"]] = 1.0
        for t, char in enumerate(input_text_cleaned):
            encoder_input_data[i, t, input_token_index[char]] = 1.0
            decoder_input_data[i, t + 1, input_token_index[char]] = 1.0

    print('.......')
    for i in range(10):
        print(input_texts[i])
        print(input_texts_cleaned[i])
        print('')

    return max_encoder_seq_length, num_encoder_tokens, final_input_words, input_token_index, reverse_input_char_index, \
        encoder_input_data, decoder_input_data, input_texts_original, X_test, y_test, new_X_test


def decode(s):
    return inference.decode_sequence(s, gen, stepper, input_dim, char2id, id2char, max_encoder_seq_length)
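# Note: decode() uses gen, stepper, input_dim, char2id, id2char and
# max_encoder_seq_length as module-level names; they are only assigned in the
# __main__ block below, so decode() is usable only after the models are built.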
if __name__ == "__main__":
# Insert 'hate' or 'polarity' as dataset
dataset_name = 'hate'
res = get_text_data(num_samples=20000, data_path='data/' + dataset_name + '_tweets.csv', dataset=dataset_name)
max_encoder_seq_length, num_enc_tokens, characters, char2id, id2char, \
encoder_input_data, decoder_input_data, input_texts_original, X_original, y_original, X_original_processed = res
print(encoder_input_data.shape, "Creating model...")
input_dim = encoder_input_data.shape[-1]
batch_size = 1
latent_dim = 500
intermediate_dim = 256
if dataset_name == 'hate':
epochs = 200
elif dataset_name == 'polarity':
epochs = 250
vae, enc, gen, stepper, vae_loss = create_lstm_vae(input_dim,
batch_size=batch_size,
intermediate_dim=intermediate_dim,
latent_dim=latent_dim)
print("Training VAE model...")
vae.fit([encoder_input_data, decoder_input_data], encoder_input_data, epochs=epochs, verbose=1)
vae.save('models/' + dataset_name + '_vae_model.h5', overwrite=True)
enc.save('models/' + dataset_name + '_enc_model.h5', overwrite=True)
gen.save('models/' + dataset_name + '_gen_model.h5', overwrite=True)
stepper.save('models/' + dataset_name + '_stepper_model.h5', overwrite=True)
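
    # Illustrative usage sketch (not part of the original run; hedged, since the exact
    # shape and return value of enc.predict depend on how create_lstm_vae builds the
    # encoder). Assuming enc maps a one-hot sequence to the latent vector expected by
    # inference.decode_sequence, a training sample could be round-tripped like this:
    #
    #     z = enc.predict(encoder_input_data[0:1])
    #     print(decode(z))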