# eval.py
# Forked from qipeng/gcn-over-pruned-trees.
"""
Run evaluation with saved models
Original Authors: Wenxuan Zhou, Yuhao Zhang
Enhanced By: Jonathan Yellin
Status: prototype
"""
import random
import argparse
import csv
import json
from tqdm import tqdm
import torch
from data.loader import DataLoader
from model.trainer import GCNTrainer
from utils import torch_utils, scorer, constant, helper
from utils.vocab import Vocab
from utils.ucca_embedding import UccaEmbedding
def _str_to_bool(value):
    """Turn a command-line token into a real boolean.

    argparse hands the raw string to the ``type`` callable, and
    ``bool('False')`` is ``True`` because every non-empty string is truthy.
    Explicit negative spellings (and blank input) map to ``False`` here; any
    other non-empty value stays ``True``, as it was with ``type=bool``.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    return bool(text)


# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument('model_dir', type=str, help='Directory of the model.')
parser.add_argument('--model', type=str, default='best_model.pt', help='Name of the model file.')
# NOTE(review): --data_dir is accepted, but the visible script builds the data
# path from the saved config (opt['data_dir']) -- confirm whether that is intended.
parser.add_argument('--data_dir', type=str, default='dataset/tacred')
parser.add_argument('--dataset', type=str, default='test', help="Evaluate on dev or test.")
parser.add_argument('--seed', type=int, default=1234)
# The default is already a bool, so the converter only runs on explicit values.
parser.add_argument('--cuda', type=_str_to_bool, default=torch.cuda.is_available(),
                    help='Whether to use CUDA (true/false).')
parser.add_argument('--cpu', action='store_true')
parser.add_argument('--trace_file_for_misses', type=str, help='When provided misses will be outputed to file')
# Parse the command line (argparse exits with a usage message on bad input).
args = parser.parse_args()

# Validate the optional trace path up front so a bad path is reported
# immediately instead of after the whole evaluation has run.
if args.trace_file_for_misses is not None:
    if not helper.is_path_exists_or_creatable(args.trace_file_for_misses):
        print(f'"{args.trace_file_for_misses}" is an invalid path. Please supply correct "trace_file_for_misses". Exiting.')
        # The bare `exit` helper is injected by the `site` module for interactive
        # use and is not guaranteed to exist (e.g. `python -S`); raising
        # SystemExit terminates with the same status code everywhere.
        raise SystemExit(1)

# Seed the torch and Python random generators from --seed.
torch.manual_seed(args.seed)
random.seed(args.seed)

# --cpu overrides --cuda; the CUDA generator is seeded only when CUDA is kept.
if args.cpu:
    args.cuda = False
elif args.cuda:
    torch.cuda.manual_seed(args.seed)
# Read the configuration that was stored together with the model file.
model_file = '/'.join((args.model_dir, args.model))
print(f"Loading model from {model_file}")
opt = torch_utils.load_config(model_file)

# Binary mode: every label except the configured one is collapsed onto id 0.
# This rewrites constant.LABEL_TO_ID in place, so any later reader of that
# mapping sees the remapped ids.
if opt['binary_classification'] is not None:
    for label in constant.LABEL_TO_ID:
        if label == opt['binary_classification']:
            continue
        constant.LABEL_TO_ID[label] = 0

# Build the trainer from the config, then load the saved model into it.
trainer = GCNTrainer(opt)
trainer.load(model_file)
# Vocabulary: read from vocab.pkl inside the model directory.
vocab_file = '/'.join((args.model_dir, 'vocab.pkl'))
vocab = Vocab(vocab_file, load=True)
assert opt['vocab_size'] == vocab.size, "Vocab size must match that in the saved model."

# Optional UCCA embedding: built only when the config asks for a positive
# embedding dimension; otherwise the loader receives None.
if opt['ucca_embedding_dim'] > 0:
    embedding_file = '/'.join((opt['ucca_embedding_dir'], opt['ucca_embedding_file']))
    index_file = '/'.join((opt['ucca_embedding_dir'], opt['ucca_embedding_index_file']))
    ucca_embedding = UccaEmbedding(opt['ucca_embedding_dim'], index_file, embedding_file)
else:
    ucca_embedding = None
# Load the evaluation split named by --dataset from the data directory
# recorded in the saved config.
data_file = opt['data_dir'] + '/{}.json'.format(args.dataset)
print("Loading data from {} with batch size {}...".format(data_file, opt['batch_size']))

# Read the JSON as UTF-8 explicitly: without `encoding`, open() falls back to
# the locale's preferred encoding, which varies by platform. The trace file
# this script writes is UTF-8 as well.
with open(data_file, encoding='utf-8') as infile:
    data_input = json.load(infile)

# Batch the parsed examples in evaluation mode.
batch = DataLoader(data_input, opt['batch_size'], opt, vocab, evaluation=True, ucca_embedding=ucca_embedding)
print("{} batches created for test".format(len(batch.data)))

helper.print_config(opt)
label2id = constant.LABEL_TO_ID

# Invert label -> id. When --binary_classification is active many labels share
# id 0, so the inversion leaves whichever of them came last at key 0; pin that
# key back to 'no_relation'.
id2label = {label_id: label for label, label_id in label2id.items()}
id2label[0] = 'no_relation'

# Run the model over every batch, accumulating predicted ids, probabilities
# and example ids.
predictions = []
all_probs = []
all_ids = []
batch_iter = tqdm(batch)
for minibatch in batch_iter:
    batch_preds, batch_probs, _, batch_ids = trainer.predict(minibatch)
    predictions.extend(batch_preds)
    all_probs.extend(batch_probs)
    all_ids.extend(batch_ids)

# Map predicted ids back to label strings and score them against the gold labels.
predictions = [id2label[pred_id] for pred_id in predictions]
precision, recall, f1 = scorer.score(batch.gold(), predictions, verbose=True)
print("{} set evaluate result: {:.2f}\t{:.2f}\t{:.2f}".format(args.dataset, precision, recall, f1))
# Optional miss report: one CSV row per example whose prediction differs from
# its gold label.
if args.trace_file_for_misses is not None:
    print(f'Preparing miss information and writing it to "{args.trace_file_for_misses}"')
    # newline='' lets the csv module control line endings itself.
    with open(args.trace_file_for_misses, 'w', encoding='utf-8', newline='') as trace_out:
        csv_writer = csv.writer(trace_out)
        csv_writer.writerow(['id', 'gold', 'predicted'])
        csv_writer.writerows(
            [example_id, gold, prediction]
            for gold, prediction, example_id in zip(batch.gold(), predictions, all_ids)
            if gold != prediction
        )

print("Evaluation ended.")