get_xy.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=invalid-name, no-member, too-many-locals
import importlib
import os
import numpy as np
import tensorflow as tf
import texar as tx
import pickle
import json
from utils import *
from text2num import text2num, NumberException
# from replace_numbers import replace_numbers
"""""
with open(os.path.join('data2text', 'res', 'train.idx.json'), 'r') as idx_f:
sent_idx = json.load(idx_f)
sent_to_idx = dict(map(lambda pair: (' ', pair[1]), sent_idx))
"""

def get_align(text00, text01, text02, text1):
    """Build a 0/1 alignment matrix between structured-data records and sentence tokens."""
    # print('=========text1 is :{}'.format(text1))
    text00, text01, text02, text1 = map(
        strip_special_tokens_of_list,
        (text00, text01, text02, text1))
    sd_texts, sent_texts = pack_sd(DataItem(text00, text01, text02)), text1
    # print('=========sent_texts is :{}'.format(sent_texts))
    sent = ' '.join(sent_texts)
    # print('=========sent is :{}'.format(sent))
    idxs = []
    # For each record value, find the first sentence position that matches it
    # exactly (-1 if it never appears in the sentence).
    for entry in text00:
        idxs.append(text1.tolist().index(entry) if entry in text1 else -1)
    assert len(idxs) == len(sd_texts), \
        "\nidxs = {}\nsd_texts = {}\nsent = {}".format(idxs, sd_texts, sent)
    # One row per record: 1 at the matched sentence position, all zeros otherwise.
    align = [
        [int(j == idx)
         for j in range(len(sent_texts))]
        for idx in idxs]
    return np.array(align)
batch_get_align = batchize(get_align)
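# Illustrative sketch of get_align's output, using hypothetical tokens (not taken
# from the real data): if the record values were text00 = ['Lakers', '102'] and the
# sentence tokens were ['Lakers', 'scored', '102', 'points'], the result would be
#     [[1, 0, 0, 0],
#      [0, 0, 1, 0]]
# i.e. a (num_records x num_sentence_tokens) 0/1 matrix in which row i marks the
# first sentence position equal to record value i, and stays all-zero when the
# value never appears in the sentence.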

def print_align(sd_text0, sd_text1, sd_text2, sent_text, align):
    """Pretty-print the structured-data fields and, per sentence token, its alignment row."""
    sd_text = [sd_text0, sd_text1, sd_text2]
    for text, name in zip(sd_text, sd_fields):
        print('{:>20}'.format(name) + ' '.join(map('{:>18}'.format, text)))
    for j, sent_token in enumerate(sent_text):
        print('{:>20}'.format(sent_token) + ' '.join(map(
            lambda x: '{:18}'.format(x) if x != 0 else ' ' * 18,
            align[:, j])))
batch_print_align = batchize(print_align)

def main():
    # data batch
    datasets = {mode: tx.data.MultiAlignedData(hparams)
                for mode, hparams in config_data.datas.items()}
    data_iterator = tx.data.FeedableDataIterator(datasets)
    data_batch = data_iterator.get_next()

    def _get_align(sess, mode):
        print('in _get_align')
        data_iterator.restart_dataset(sess, mode)
        feed_dict = {
            tx.global_mode(): tf.estimator.ModeKeys.EVAL,
            data_iterator.handle: data_iterator.get_handle(sess, mode),
        }
        with open('align.pkl', 'wb') as out_file:
            while True:
                try:
                    batch = sess.run(data_batch, feed_dict)
                    # Gather the structured-data and sentence token arrays from the batch.
                    sd_texts, sent_texts = (
                        [batch['{}{}_text'.format(field, ref_strs[1])]
                         for field in fields]
                        for fields in (sd_fields, sent_fields))
                    aligns = batch_get_align(*(sd_texts + sent_texts))
                    sd_texts, sent_texts = (
                        [batch_strip_special_tokens_of_list(texts)
                         for texts, field in zip(all_texts, fields)]
                        for all_texts, fields in zip(
                            (sd_texts, sent_texts), (sd_fields, sent_fields)))
                    if FLAGS.verbose:
                        batch_print_align(*(sd_texts + sent_texts + [aligns]))
                    # One pickle.dump call per example in the batch.
                    for align in aligns:
                        pickle.dump(align, out_file)
                except tf.errors.OutOfRangeError:
                    break
        print('end _get_align')

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())
        _get_align(sess, 'train')
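# A minimal sketch (not part of the original script) of reading align.pkl back:
# each example's matrix was written with a separate pickle.dump call, so load in a
# loop until EOFError:
#
#     with open('align.pkl', 'rb') as in_file:
#         aligns = []
#         try:
#             while True:
#                 aligns.append(pickle.load(in_file))
#         except EOFError:
#             pass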

if __name__ == '__main__':
    flags = tf.flags
    flags.DEFINE_string("config_data", "config_data_nba_stable",
                        "The data config.")
    flags.DEFINE_boolean("verbose", False, "verbose.")
    FLAGS = flags.FLAGS
    config_data = importlib.import_module(FLAGS.config_data)
    main()
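# Example invocation (assuming config_data_nba_stable.py and the data files its
# hparams point to are available on the import path; adjust --config_data otherwise):
#
#     python get_xy.py --config_data config_data_nba_stable --verbose
#
# This iterates once over the 'train' split and appends one alignment matrix per
# example to align.pkl.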