-
Notifications
You must be signed in to change notification settings - Fork 5
/
eval_cvec.py
executable file
·82 lines (71 loc) · 3.22 KB
/
eval_cvec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#!/usr/bin/env python
import argparse
import os
import random
import subprocess
import sys
import tempfile

from read_write import read_word_vectors
from read_write import gzopen
from qvec_scripts.qvec_cca import ComputeCCA, GetVocab, ReadOracleMatrix, ReadVectorMatrix
def get_qvec_gold_filename():
    """Return the file name of the QVEC gold (oracle) data file.

    The name is relative; callers join it onto the evaluation-data directory.
    """
    gold_basename = 'semantic_classes'
    return gold_basename
def get_relevant_word_types(eval_data_filename):
    """Collect the word types listed in the first column of a tab-separated file.

    Args:
        eval_data_filename: path to a text file whose lines are tab-separated;
            the first field of each non-blank line is taken as a word type.

    Returns:
        A set of the distinct first-column strings (case preserved).
    """
    relevant_word_types = set()
    with open(eval_data_filename) as eval_data_file:
        for line in eval_data_file:
            stripped = line.strip()
            # Blank lines (e.g. a trailing newline at EOF) carry no word type;
            # splitting them would contribute a spurious '' entry.
            if not stripped:
                continue
            relevant_word_types.add(stripped.split('\t')[0])
    return relevant_word_types
def get_relevant_embeddings_filename(eval_data_filename, embeddings_filename):
    """Copy only the embeddings needed for evaluation into a new temporary file.

    A line of the embeddings file is kept when its first space-separated token
    is one of the word types listed in the gold file.

    Args:
        eval_data_filename: path to the tab-separated gold file.
        embeddings_filename: path to the embeddings file; opened via ``gzopen``.

    Returns:
        Path of the newly created file inside the ``temp`` directory next to
        this script. The caller is responsible for deleting it.
    """
    # Resolve ``temp`` relative to this script, not the current working
    # directory, so the directory we create is the one we write into.
    temp_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'temp')
    if not os.path.isdir(temp_dir):
        os.mkdir(temp_dir)
    relevant_word_types = get_relevant_word_types(eval_data_filename)
    # mkstemp creates a uniquely named file atomically, unlike a random
    # 6-digit name which can collide with a concurrent or earlier run.
    fd, relevant_embeddings_filename = tempfile.mkstemp(dir=temp_dir)
    try:
        with os.fdopen(fd, 'w') as relevant_embeddings_file:
            with gzopen(embeddings_filename) as all_embeddings_file:
                for line in all_embeddings_file:
                    if line.split(' ')[0] in relevant_word_types:
                        relevant_embeddings_file.write(line)
    except BaseException:
        # Don't leave a half-written temp file behind on failure.
        os.remove(relevant_embeddings_filename)
        raise
    return relevant_embeddings_filename
def compute_coverage(semantic_classes_file, word_vecs):
    """Return the fraction of gold-file words that have an embedding.

    Args:
        semantic_classes_file: path to a tab-separated file; the lowercased
            first field of each non-blank line is the word looked up.
        word_vecs: container supporting ``in`` membership tests by word
            (presumably the mapping returned by ``read_word_vectors``).

    Returns:
        ``1 - missing / total`` as a float, where ``total`` counts non-blank
        lines and ``missing`` those whose word is not in ``word_vecs``.

    Raises:
        ValueError: if the file contains no non-blank lines.
    """
    total_size = 0
    not_found = 0
    # ``with`` guarantees the file handle is closed.
    with open(semantic_classes_file) as gold_file:
        for line in gold_file:
            stripped = line.strip()
            # Blank lines are not words; counting them would deflate coverage.
            if not stripped:
                continue
            word = stripped.lower().split('\t')[0]
            total_size += 1
            if word not in word_vecs:
                not_found += 1
    # Explicit check instead of ``assert`` (which ``python -O`` strips).
    if total_size == 0:
        raise ValueError('no words found in {}'.format(semantic_classes_file))
    return 1.0 - (not_found * 1.0 / total_size)
def qvec_cca_wrapper(in_oracle, in_vectors):
    """Compute the QVEC-CCA score of a word-vector file against an oracle file.

    Both matrices are restricted to the words that appear in both files.

    Args:
        in_oracle: path to the oracle (gold) file.
        in_vectors: path to the word-vector file.

    Returns:
        The value returned by ``ComputeCCA`` for the two restricted matrices.
    """
    oracle_paths = [in_oracle]
    oracle_words = GetVocab(oracle_paths, vocab_union=True)
    vector_words = GetVocab([in_vectors])
    # Shared vocabulary: only words known to both sides can be compared.
    shared_vocab = set(vector_words).intersection(oracle_words)
    oracle_mat = ReadOracleMatrix(oracle_paths, shared_vocab)
    vector_mat = ReadVectorMatrix(in_vectors, shared_vocab)
    return ComputeCCA(vector_mat, oracle_mat)
def evaluate(eval_data_dir, embeddings_filename):
    """Evaluate an embeddings file with QVEC-CCA against the gold data.

    Args:
        eval_data_dir: directory containing the gold file named by
            ``get_qvec_gold_filename()``.
        embeddings_filename: path to the embeddings file (read through
            ``gzopen``; per the CLI help a ``.gz`` suffix marks gzip input).

    Returns:
        A ``(score, coverage)`` tuple: the QVEC-CCA score and the fraction of
        gold words that have an embedding.
    """
    eval_data_filename = os.path.join(eval_data_dir, get_qvec_gold_filename())
    relevant_embeddings_filename = get_relevant_embeddings_filename(eval_data_filename, embeddings_filename)
    try:
        word_vecs = read_word_vectors(relevant_embeddings_filename)
        coverage = compute_coverage(eval_data_filename, word_vecs)
        score = qvec_cca_wrapper(eval_data_filename, relevant_embeddings_filename)
    finally:
        # Always delete the filtered temp file, even if scoring fails.
        os.remove(relevant_embeddings_filename)
    return (score, coverage)
def main(argv=None):
    """Command-line entry point: parse arguments, evaluate, print the result.

    Args:
        argv: full argument vector including the program name, as in
            ``sys.argv``. When None, argparse reads ``sys.argv`` itself.
    """
    # parse/validate arguments
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-eval-data", required=True, help="Path to a directory which contains all data files needed to setup the evaluation script.")
    argparser.add_argument("-embeddings-file", required=True, help="Path to the embeddings file (lowercased, UTF8-encoded, space-delimited, optional: suffix .gz indicate the file is gzip compressed.)")
    # argv[0] is the program name; argparse expects only the arguments.
    args = argparser.parse_args(None if argv is None else argv[1:])
    # evaluate
    score, coverage = evaluate(args.eval_data, args.embeddings_file)
    # report (parenthesized print works on both Python 2 and Python 3)
    print('score={}, coverage={}'.format(score, coverage))


if __name__ == '__main__':
    main(sys.argv)