Updated to work with tensorflow 1.0.1 #2

Open · wants to merge 2 commits into master
12 changes: 6 additions & 6 deletions src/SentenceMatchDataStream.py
@@ -9,7 +9,7 @@ def pad_2d_matrix(in_val, max_length=None, dtype=np.int32):
     if max_length is None: max_length = np.max([len(cur_in_val) for cur_in_val in in_val])
     batch_size = len(in_val)
     out_val = np.zeros((batch_size, max_length), dtype=dtype)
-    for i in xrange(batch_size):
+    for i in iter(range(batch_size)):
         cur_in_val = in_val[i]
         kept_length = len(cur_in_val)
         if kept_length>max_length: kept_length = max_length
@@ -21,10 +21,10 @@ def pad_3d_tensor(in_val, max_length1=None, max_length2=None, dtype=np.int32):
     if max_length2 is None: max_length2 = np.max([np.max([len(val) for val in cur_in_val]) for cur_in_val in in_val])
     batch_size = len(in_val)
     out_val = np.zeros((batch_size, max_length1, max_length2), dtype=dtype)
-    for i in xrange(batch_size):
+    for i in iter(range(batch_size)):
         cur_length1 = max_length1
         if len(in_val[i])<max_length1: cur_length1 = len(in_val[i])
-        for j in xrange(cur_length1):
+        for j in iter(range(cur_length1)):
             cur_in_val = in_val[i][j]
             kept_length = len(cur_in_val)
             if kept_length>max_length2: kept_length = max_length2
@@ -37,9 +37,9 @@ class SentenceMatchDataStream(object):
     def __init__(self, inpath, word_vocab=None, char_vocab=None, POS_vocab=None, NER_vocab=None, label_vocab=None, batch_size=60,
                  isShuffle=False, isLoop=False, isSort=True, max_char_per_word=10, max_sent_length=200):
         instances = []
-        infile = open(inpath, 'rt')
+        infile = open(inpath, 'rt', encoding='utf-8')
         for line in infile:
-            line = line.decode('utf-8').strip()
+            line = line.strip()
             if line.startswith('-'): continue
             items = re.split("\t", line)
             label = items[0]
@@ -113,7 +113,7 @@ def __init__(self, inpath, word_vocab=None, char_vocab=None, POS_vocab=None, NER
             NER_idx_2_batch = None
             if NER_vocab is not None: NER_idx_2_batch = []

-            for i in xrange(batch_start, batch_end):
+            for i in iter(range(batch_start, batch_end)):
                 (label, sentence1, sentence2, label_id, word_idx_1, word_idx_2, char_matrix_idx_1, char_matrix_idx_2,
                  POS_idx_1, POS_idx_2, NER_idx_1, NER_idx_2) = instances[i]
                 label_batch.append(label)
Binary file modified src/SentenceMatchDataStream.pyc
Binary file not shown.
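
Note: in Python 3, `range()` already returns a lazy iterable, so the `iter(range(...))` wrapper added throughout this PR is redundant (though harmless); plain `range(...)` behaves identically in a `for` loop. A minimal sketch of the same padding logic in idiomatic Python 3 — it mirrors `pad_2d_matrix` for illustration and is not a drop-in replacement:

```python
import numpy as np

def pad_2d_matrix(in_val, max_length=None, dtype=np.int32):
    """Pad a list of variable-length index sequences into a dense
    (batch_size, max_length) array, truncating overlong sequences."""
    if max_length is None:
        max_length = max(len(seq) for seq in in_val)
    out_val = np.zeros((len(in_val), max_length), dtype=dtype)
    for i, seq in enumerate(in_val):      # no iter() wrapper needed
        kept = min(len(seq), max_length)  # clip to max_length
        out_val[i, :kept] = seq[:kept]
    return out_val
```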
26 changes: 13 additions & 13 deletions src/SentenceMatchModelGraph.py
@@ -1,5 +1,5 @@
 import tensorflow as tf
-import my_rnn
+from tensorflow.python.ops.rnn import dynamic_rnn
 import match_utils

@@ -106,20 +106,20 @@ def __init__(self, num_classes, word_vocab=None, char_vocab=None, POS_vocab=None
         passage_char_lengths = tf.reshape(self.passage_char_lengths, [-1])
         with tf.variable_scope('char_lstm'):
             # lstm cell
-            char_lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(char_lstm_dim)
+            char_lstm_cell = tf.contrib.rnn.BasicLSTMCell(char_lstm_dim)
             # dropout
-            if is_training: char_lstm_cell = tf.nn.rnn_cell.DropoutWrapper(char_lstm_cell, output_keep_prob=(1 - dropout_rate))
-            char_lstm_cell = tf.nn.rnn_cell.MultiRNNCell([char_lstm_cell])
+            if is_training: char_lstm_cell = tf.contrib.rnn.DropoutWrapper(char_lstm_cell, output_keep_prob=(1 - dropout_rate))
+            char_lstm_cell = tf.contrib.rnn.MultiRNNCell([char_lstm_cell])

             # question_representation
-            question_char_outputs = my_rnn.dynamic_rnn(char_lstm_cell, in_question_char_repres,
+            question_char_outputs = dynamic_rnn(char_lstm_cell, in_question_char_repres,
                     sequence_length=question_char_lengths,dtype=tf.float32)[0] # [batch_size*question_len, q_char_len, char_lstm_dim]
             question_char_outputs = question_char_outputs[:,-1,:]
             question_char_outputs = tf.reshape(question_char_outputs, [batch_size, question_len, char_lstm_dim])

             tf.get_variable_scope().reuse_variables()
             # passage representation
-            passage_char_outputs = my_rnn.dynamic_rnn(char_lstm_cell, in_passage_char_repres,
+            passage_char_outputs = dynamic_rnn(char_lstm_cell, in_passage_char_repres,
                     sequence_length=passage_char_lengths,dtype=tf.float32)[0] # [batch_size*question_len, q_char_len, char_lstm_dim]
             passage_char_outputs = passage_char_outputs[:,-1,:]
             passage_char_outputs = tf.reshape(passage_char_outputs, [batch_size, passage_len, char_lstm_dim])
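
Note: `tensorflow.python.ops.rnn.dynamic_rnn` is the same implementation exposed publicly as `tf.nn.dynamic_rnn` in TF 1.x, so this hunk drops the vendored `my_rnn` copy in favor of the stock op, and moves the cell classes from `tf.nn.rnn_cell` to `tf.contrib.rnn` as TF 1.0 requires. A condensed sketch of what the hunk builds, assuming TF 1.0.x and the diff's variable names — the helper function itself is illustrative, not part of the repo:

```python
import tensorflow as tf

def char_lstm_encode(char_repres, char_lengths, char_lstm_dim, is_training, dropout_rate):
    # Cell classes moved from tf.nn.rnn_cell to tf.contrib.rnn in TF 1.0.
    cell = tf.contrib.rnn.BasicLSTMCell(char_lstm_dim)
    if is_training:
        cell = tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=1.0 - dropout_rate)
    cell = tf.contrib.rnn.MultiRNNCell([cell])  # single layer, kept for parity with the diff
    outputs, _ = tf.nn.dynamic_rnn(cell, char_repres,
                                   sequence_length=char_lengths, dtype=tf.float32)
    # Caution: dynamic_rnn zero-pads outputs past sequence_length, so this
    # slice is a zero vector for sequences shorter than the padded length;
    # the behavior is mirrored from the patched code as-is.
    return outputs[:, -1, :]
```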
@@ -129,15 +129,15 @@ def __init__(self, num_classes, word_vocab=None, char_vocab=None, POS_vocab=None

         input_dim += char_lstm_dim

-        in_question_repres = tf.concat(2, in_question_repres) # [batch_size, question_len, dim]
-        in_passage_repres = tf.concat(2, in_passage_repres) # [batch_size, passage_len, dim]
+        in_question_repres = tf.concat(in_question_repres, 2) # [batch_size, question_len, dim]
+        in_passage_repres = tf.concat(in_passage_repres, 2) # [batch_size, passage_len, dim]

         if is_training:
             in_question_repres = tf.nn.dropout(in_question_repres, (1 - dropout_rate))
             in_passage_repres = tf.nn.dropout(in_passage_repres, (1 - dropout_rate))
         else:
-            in_question_repres = tf.mul(in_question_repres, (1 - dropout_rate))
-            in_passage_repres = tf.mul(in_passage_repres, (1 - dropout_rate))
+            in_question_repres = tf.multiply(in_question_repres, (1 - dropout_rate))
+            in_passage_repres = tf.multiply(in_passage_repres, (1 - dropout_rate))

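
Note: two unrelated TF 1.0 breaking changes meet in this hunk: `tf.concat` swapped its argument order from `(axis, values)` to `(values, axis)`, and `tf.mul` was renamed `tf.multiply`. A tiny sketch of both, with made-up shapes:

```python
import tensorflow as tf

parts = [tf.ones([2, 5, 3]), tf.zeros([2, 5, 4])]
merged = tf.concat(parts, 2)       # TF 1.x: values first, axis second -> [2, 5, 7]
scaled = tf.multiply(merged, 0.9)  # tf.mul is gone; tf.multiply is the TF 1.x name
```

Incidentally, `tf.nn.dropout` already rescales kept activations by `1/keep_prob` at training time (inverted dropout), so the inference-time `tf.multiply(..., 1 - dropout_rate)` in the `else` branch is arguably redundant; the PR only renames the op and leaves that logic untouched.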
@@ -171,7 +171,7 @@ def __init__(self, num_classes, word_vocab=None, char_vocab=None, POS_vocab=None
         if is_training:
             logits = tf.nn.dropout(logits, (1 - dropout_rate))
         else:
-            logits = tf.mul(logits, (1 - dropout_rate))
+            logits = tf.multiply(logits, (1 - dropout_rate))
         logits = tf.matmul(logits, w_1) + b_1

         self.prob = tf.nn.softmax(logits)
@@ -181,11 +181,11 @@ def __init__(self, num_classes, word_vocab=None, char_vocab=None, POS_vocab=None

         gold_matrix = tf.one_hot(self.truth, num_classes, dtype=tf.float32)
         # gold_matrix = tf.one_hot(self.truth, num_classes)
-        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits, gold_matrix))
+        self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=gold_matrix))

         correct = tf.nn.in_top_k(logits, self.truth, 1)
         self.eval_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
-        self.predictions = tf.arg_max(self.prob, 1)
+        self.predictions = tf.argmax(self.prob, 1)

         if optimize_type == 'adadelta':
             clipper = 50
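
Note: since TF 1.0, `tf.nn.softmax_cross_entropy_with_logits` takes `labels` before `logits` and rejects positional use, so both arguments must be passed by keyword; `tf.arg_max` likewise gave way to `tf.argmax`. A self-contained sketch under TF 1.0.x, with dummy values:

```python
import tensorflow as tf

logits = tf.random_normal([4, 3])           # 4 examples, 3 classes (dummy values)
labels = tf.one_hot([0, 2, 1, 2], depth=3)  # gold classes as one-hot rows
# Keywords are required: the argument order changed in TF 1.0.
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
preds = tf.argmax(logits, 1)                # tf.arg_max is the deprecated spelling
```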
Binary file modified src/SentenceMatchModelGraph.pyc
Binary file not shown.
19 changes: 10 additions & 9 deletions src/SentenceMatchTrainer.py
@@ -21,9 +21,9 @@ def collect_vocabs(train_path, with_POS=False, with_NER=False):
     all_NERs = None
     if with_POS: all_POSs = set()
     if with_NER: all_NERs = set()
-    infile = open(train_path, 'rt')
+    infile = open(train_path, 'rt', encoding='utf-8')
     for line in infile:
-        line = line.decode('utf-8').strip()
+        line = line.strip()
         if line.startswith('-'): continue
         items = re.split("\t", line)
         label = items[0]
@@ -51,7 +51,7 @@ def evaluate(dataStream, valid_graph, sess, outpath=None, label_vocab=None, mode
     total_tags = 0.0
     correct_tags = 0.0
     dataStream.reset()
-    for batch_index in xrange(dataStream.get_num_batch()):
+    for batch_index in iter(range(dataStream.get_num_batch())):
         cur_dev_batch = dataStream.get_batch(batch_index)
         (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch,
          char_matrix_idx_1_batch, char_matrix_idx_2_batch, sent1_length_batch, sent2_length_batch,
@@ -89,12 +89,12 @@ def evaluate(dataStream, valid_graph, sess, outpath=None, label_vocab=None, mode
         if outpath is not None:
             if mode =='prediction':
                 predictions = sess.run(valid_graph.get_predictions(), feed_dict=feed_dict)
-                for i in xrange(len(label_batch)):
+                for i in iter(range(len(label_batch))):
                     outline = label_batch[i] + "\t" + label_vocab.getWord(predictions[i]) + "\t" + sent1_batch[i] + "\t" + sent2_batch[i] + "\n"
                     outfile.write(outline.encode('utf-8'))
             else:
                 probs = sess.run(valid_graph.get_prob(), feed_dict=feed_dict)
-                for i in xrange(len(label_batch)):
+                for i in iter(range(len(label_batch))):
                     outfile.write(label_batch[i] + "\t" + output_probs(probs[i], label_vocab) + "\n")

     if outpath is not None: outfile.close()
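
Note: the hunk above still writes `outline.encode('utf-8')`; if `outfile` is opened in text mode under Python 3 (as the `open(..., 'rt', encoding='utf-8')` changes elsewhere in this PR suggest), writing `bytes` to it raises `TypeError`. A minimal sketch of the matching Python 3 output pattern — the file name and rows are hypothetical:

```python
# Python 3 text-mode output: pass an encoding to open() and write str
# objects directly; str.encode() yields bytes, which text-mode files reject.
with open("predictions.tsv", "w", encoding="utf-8") as outfile:  # hypothetical path
    for label, predicted in [("entailment", "entailment")]:      # dummy rows
        outfile.write(label + "\t" + predicted + "\n")
```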
@@ -104,7 +104,7 @@

 def output_probs(probs, label_vocab):
     out_string = ""
-    for i in xrange(probs.size):
+    for i in iter(range(probs.size)):
         out_string += " {}:{}".format(label_vocab.getWord(i), probs[i])
     return out_string.strip()

@@ -232,7 +232,7 @@ def main(_):

     initializer = tf.global_variables_initializer()
     vars_ = {}
-    for var in tf.all_variables():
+    for var in tf.global_variables():
         if "word_embedding" in var.name: continue
         # if not var.name.startswith("Model"): continue
         vars_[var.name.split(":")[0]] = var
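
Note: `tf.all_variables()` was deprecated in favor of `tf.global_variables()` in TF 1.0; both return the GLOBAL_VARIABLES collection. The surrounding loop builds a name-to-variable dict so the saver can skip the word-embedding variable; a compact sketch of that pattern, assuming the same variable naming as the repo:

```python
import tensorflow as tf

# Map "scope/var" names (the part before ":0") to variables, skipping
# the word embedding so a restored checkpoint need not contain it.
vars_ = {var.name.split(":")[0]: var
         for var in tf.global_variables()
         if "word_embedding" not in var.name}
saver = tf.train.Saver(vars_)  # saves/restores only the filtered set
```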
@@ -250,7 +250,7 @@ def main(_):
     max_steps = train_size * FLAGS.max_epochs
     total_loss = 0.0
     start_time = time.time()
-    for step in xrange(max_steps):
+    for step in iter(range(max_steps)):
         # read data
         cur_batch = trainDataStream.nextBatch()
         (label_batch, sent1_batch, sent2_batch, label_id_batch, word_idx_1_batch, word_idx_2_batch,
@@ -303,6 +303,7 @@ def main(_):
                 accuracy = evaluate(devDataStream, valid_graph, sess,char_vocab=char_vocab, POS_vocab=POS_vocab, NER_vocab=NER_vocab)
                 print("Current accuracy is %.2f" % accuracy)
                 if accuracy>best_accuracy:
+                    print('Saving model since it\'s the best so far')
                     best_accuracy = accuracy
                     saver.save(sess, best_path)

@@ -327,7 +328,7 @@ def main(_):
                       with_full_match=(not FLAGS.wo_full_match), with_maxpool_match=(not FLAGS.wo_maxpool_match),
                       with_attentive_match=(not FLAGS.wo_attentive_match), with_max_attentive_match=(not FLAGS.wo_max_attentive_match))
     vars_ = {}
-    for var in tf.all_variables():
+    for var in tf.global_variables():
         if "word_embedding" in var.name: continue
         if not var.name.startswith("Model"): continue
         vars_[var.name.split(":")[0]] = var