-
Notifications
You must be signed in to change notification settings - Fork 117
/
test.py
114 lines (102 loc) · 4.58 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
__author__ = '[email protected]'
"""
标记文件
"""
import codecs
import yaml
import pickle
import tensorflow as tf
from load_data import load_vocs, init_data
from model import SequenceLabelingModel
def main():
    """Tag the test file with a restored sequence-labelling model.

    Reads ``./config.yml``, builds the embedding/convolution parameters it
    describes, loads the vocabularies and the test data, restores the model
    checkpoint from ``model_params.path_model``, predicts a label sequence
    for every sentence in ``data_params.path_test`` and writes each input
    line followed by a tab and its label to ``data_params.path_result``.

    Raises:
        ValueError: if ``data_params.sep`` is neither ``'table'`` nor
            ``'space'``.
    """
    # Load the configuration file. safe_load only builds plain Python
    # objects; yaml.load without an explicit Loader is deprecated and is
    # rejected outright by PyYAML >= 6.
    with open('./config.yml') as file_config:
        config = yaml.safe_load(file_config)

    model_params = config['model_params']
    data_params = config['data_params']
    embed_params = model_params['embed_params']

    feature_names = model_params['feature_names']
    use_char_feature = model_params['use_char_feature']

    # Embedding shapes, dropout rates and (optional) pre-trained weights,
    # keyed by feature name.
    feature_weight_shape_dict = dict()
    feature_weight_dropout_dict = dict()
    feature_init_weight_dict = dict()
    for feature_name in feature_names:
        feature_params = embed_params[feature_name]
        feature_weight_shape_dict[feature_name] = feature_params['shape']
        feature_weight_dropout_dict[feature_name] = \
            feature_params['dropout_rate']
        path_pre_train = feature_params['path']
        if path_pre_train:
            # SECURITY: unpickling executes arbitrary code; the configured
            # path must only ever point at a trusted file.
            with open(path_pre_train, 'rb') as file_r:
                feature_init_weight_dict[feature_name] = pickle.load(file_r)

    # Char embedding shape and convolution filter settings are only read
    # when the char feature is enabled.
    if use_char_feature:
        feature_weight_shape_dict['char'] = embed_params['char']['shape']
        conv_filter_len_list = model_params['conv_filter_len_list']
        conv_filter_size_list = model_params['conv_filter_size_list']
    else:
        conv_filter_len_list = None
        conv_filter_size_list = None

    # Load the vocabularies. Order matters: optional char voc first, then
    # one voc per feature, and the label voc last (read back as vocs[-1]).
    voc_params = data_params['voc_params']
    path_vocs = []
    if use_char_feature:
        path_vocs.append(voc_params['char']['path'])
    for feature_name in feature_names:
        path_vocs.append(voc_params[feature_name]['path'])
    path_vocs.append(voc_params['label']['path'])
    vocs = load_vocs(path_vocs)

    # Load the test data. An explicit check replaces the former assert so
    # that validation still happens when Python runs with -O.
    sep_str = data_params['sep']
    if sep_str not in ('table', 'space'):
        raise ValueError(
            "data_params.sep must be 'table' or 'space', got %r" % (sep_str,))
    sep = '\t' if sep_str == 'table' else ' '
    max_len = model_params['sequence_length']
    word_len = model_params['word_length']
    data_dict = init_data(
        path=data_params['path_test'], feature_names=feature_names, sep=sep,
        vocs=vocs, max_len=max_len, model='test',
        use_char_feature=use_char_feature, word_len=word_len)

    # Build the model graph, then restore the trained weights into it.
    model = SequenceLabelingModel(
        sequence_length=model_params['sequence_length'],
        nb_classes=model_params['nb_classes'],
        nb_hidden=model_params['bilstm_params']['num_units'],
        num_layers=model_params['bilstm_params']['num_layers'],
        feature_weight_shape_dict=feature_weight_shape_dict,
        feature_init_weight_dict=feature_init_weight_dict,
        feature_weight_dropout_dict=feature_weight_dropout_dict,
        dropout_rate=model_params['dropout_rate'],
        nb_epoch=model_params['nb_epoch'], feature_names=feature_names,
        batch_size=model_params['batch_size'],
        train_max_patience=model_params['max_patience'],
        use_crf=model_params['use_crf'],
        l2_rate=model_params['l2_rate'],
        rnn_unit=model_params['rnn_unit'],
        learning_rate=model_params['learning_rate'],
        use_char_feature=use_char_feature,
        conv_filter_size_list=conv_filter_size_list,
        conv_filter_len_list=conv_filter_len_list,
        word_length=word_len,
        path_model=model_params['path_model'])
    saver = tf.train.Saver()
    saver.restore(model.sess, model_params['path_model'])

    # Tag the test data.
    result_sequences = model.predict(data_dict)

    # Invert the label vocabulary (key -> value becomes value -> key) so
    # predicted values can be turned back into label strings.
    label_voc = {vocs[-1][key]: key for key in vocs[-1]}

    # Sentences in the test file are separated by blank lines.
    with codecs.open(data_params['path_test'], 'r', encoding='utf-8') as file_r:
        sentences = file_r.read().strip().split('\n\n')

    # Write the results; the context manager closes the file even if a
    # lookup below raises part-way through.
    with codecs.open(
            data_params['path_result'], 'w', encoding='utf-8') as file_result:
        for i, sentence in enumerate(sentences):
            predicted = result_sequences[i]
            for j, item in enumerate(sentence.split('\n')):
                if j < len(predicted):
                    file_result.write(
                        '%s\t%s\n' % (item, label_voc[predicted[j]]))
                else:
                    # No prediction for this line (presumably the sentence
                    # is longer than the model's sequence length -- confirm
                    # against init_data/predict); fall back to 'O'.
                    file_result.write('%s\tO\n' % item)
            file_result.write('\n')
# Script entry point: run the tagging pipeline only when this file is
# executed directly, not when it is imported as a module.
if '__main__' == __name__:
    main()