-
Notifications
You must be signed in to change notification settings - Fork 7
/
main.py
56 lines (47 loc) · 1.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import tqdm
import os
from train_kenlm import arpa_to_lmdb, data_produce
# Make sure the output directory exists before anything writes into it.
# exist_ok=True removes the race between "check" and "create" that the
# exists()/makedirs() pair had, and raises if the path exists but is not a
# directory instead of silently continuing.
os.makedirs('./result_files', exist_ok=True)

# Path to the kenlm `lmplz` executable (used as the command in main()'s
# training step).
lmplz = 'train_kenlm/kenlm/build/bin/lmplz'
# Text corpus handed to lmplz via --text in main()'s training step.
data = './result_files/data_cuted.txt'
# ARPA file lmplz writes via --arpa in main()'s training step.
arpa = './result_files/log.arpa'
def main():
    """Run the model-building pipeline.

    Every step below is currently disabled (commented out), so calling this
    function does nothing and returns None. Uncomment the steps you need;
    they are listed in the order they are meant to run. The docstring also
    serves as the function body, which keeps the definition syntactically
    valid while all steps are disabled.
    """
    # Step 1: generate the preprocessed corpus from the articles directory.
    # data_produce.gen_data_txt(process_num=6, mem_limit_gb=10)

    # Step 2: call kenlm from the command line to train the ARPA model.
    # os.system('{} -o 3 --verbose_header --text {} --arpa {} --prune 0 30 50'.format(lmplz, data, arpa))

    # Step 3: build the final usable model: an LMDB used to look up word
    # transition probabilities (base-10 logarithms).
    # arpa_to_lmdb.gen_emission_and_database()
def test():
    """Time one candidate lookup, then measure the hit rate on the test set.

    First loads the model data with ``dag.load_data()``, converts a sample
    pinyin string with ``utility.get_pinyin_str`` and prints how long
    ``dag.get_candidates_from`` takes on it, followed by every returned
    candidate's path and score.

    Then iterates ``res.test.smallData`` (presumably a mapping of pinyin
    input to expected text -- verify against ``res/test.py``), counts how
    often the top candidate's joined path equals the expected value, and
    prints the hit rate as a percentage. A mismatching sample is printed on
    every 100th item that has at least one candidate.
    """
    # Project imports stay local so importing this module does not pull
    # them in unless the check is actually run.
    import utility
    from dag import dag_v2 as dag
    from datetime import datetime

    # dag.Database_Type = dag.kLMDB
    dag.load_data()

    pys = utility.get_pinyin_str("he'li'ji'qun'zhong'man'yi'de'fang'an")
    start = datetime.now()
    candidates = dag.get_candidates_from(pys, path_num=10)
    end = datetime.now()
    # total_seconds() covers the whole interval; timedelta.microseconds is
    # only the sub-second component and drops whole seconds.
    print('Running time:{}ms'.format((end - start).total_seconds() * 1000))
    for item in candidates:
        print('/'.join(item.path), item.score)

    import res.test
    test_data = res.test.smallData
    total = len(test_data)
    hit = 0
    # The context manager closes the progress bar even if a lookup raises.
    with tqdm.tqdm(total=total) as pbar:
        for done, (py, value) in enumerate(test_data.items(), start=1):
            pbar.update()
            r = dag.get_candidates_from(py, path_num=10)
            has_result = len(r) > 0
            rstr = ''.join(r[0].path) if has_result else 'None'
            if rstr == value:
                hit += 1
            elif done % 100 == 0 and has_result:
                # Sample a mismatch every 100 items to keep the output short.
                print("test:{}, result:{}, should:{}".format(py, '/'.join(r[0].path), value))

    # Guard against an empty test set instead of raising ZeroDivisionError.
    hit_rate = hit / total * 100 if total else 0.0
    print('命中率:{}%'.format(hit_rate))
# Script entry point: only runs when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()
    # test()  # call this instead to run the timing and hit-rate check