cls_test.py
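"""Evaluate scored predictions: for every example, keep the candidate with the
lowest loss, compute accuracy, and log the metric with fitlog."""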
import argparse
import json

import fitlog
from tqdm import tqdm

from Channel_LM_Prompting.util import get_label_from_template


def get_labels_cls(data):
    """For each classification example, keep the candidate with the lowest
    loss and map its template string back to a label. Relies on the
    module-level `args` parsed in __main__."""
    post_data = {}
    for d in tqdm(data):
        # Examples are keyed by their input text ('question' or 'sentence').
        idx = d['question'] if 'question' in d else d['sentence']
        if idx in post_data:
            # A lower loss means a better-scoring candidate; keep it.
            if d['loss'] < post_data[idx]['loss']:
                post_data[idx]['loss'] = d['loss']
                assert d['true_label'] == post_data[idx]['true_label']
                if 'sentence' in d:
                    assert d['sentence'] == post_data[idx]['sentence']
                else:
                    assert d['question'] == post_data[idx]['sentence']
                post_data[idx]['test_label'] = get_label_from_template(args.dataset, d['test_label'])
        else:
            post_data[idx] = {'loss': d['loss'],
                              'true_label': d['true_label'],
                              'test_label': get_label_from_template(args.dataset, d['test_label']),
                              'sentence': d['sentence'] if 'sentence' in d else d['question']}
    true_labels, test_labels = [], []
    for v in post_data.values():
        true_labels.append(v['true_label'])
        test_labels.append(v['test_label'])
    return true_labels, test_labels


def get_labels_multi_choice(data):
    """For each multiple-choice question, keep the candidate answer with the
    lowest loss; test_label is used as-is (no template mapping)."""
    post_data = {}
    for d in tqdm(data):
        idx = d['question']
        if idx in post_data:
            if d['loss'] < post_data[idx]['loss']:
                post_data[idx]['loss'] = d['loss']
                assert d['true_label'] == post_data[idx]['true_label']
                assert d['question'] == post_data[idx]['question']
                post_data[idx]['test_label'] = d['test_label']
        else:
            post_data[idx] = {'loss': d['loss'],
                              'true_label': d['true_label'],
                              'test_label': d['test_label'],
                              'question': d['question']}
    true_labels, test_labels = [], []
    for v in post_data.values():
        true_labels.append(v['true_label'])
        test_labels.append(v['test_label'])
    return true_labels, test_labels
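
# Input record shape consumed by the two selectors above, inferred from the
# field accesses (the values below are illustrative, not from the source):
#   {"sentence" or "question": "<input text>", "loss": 1.23,
#    "true_label": "<gold label>", "test_label": "<scored candidate>"}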


def cal_acc(true_labels, test_labels):
    """Fraction of positions where the predicted label matches the gold one."""
    assert len(true_labels) == len(test_labels)
    correct = 0
    for i in range(len(true_labels)):
        if true_labels[i] == test_labels[i]:
            correct += 1
    return round(correct / len(true_labels), 4)
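
# Quick sanity check: cal_acc(['pos', 'neg'], ['pos', 'pos']) returns 0.5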


def test():
    with open(args.fp) as f:
        data = json.load(f)
    # Multiple-choice datasets score each answer option; everything else is
    # treated as template-based classification.
    if args.dataset in ['commonsense_qa', 'cs_explan', 'cosmos_qa', 'social_i_qa', 'piqa', 'race', 'cs_valid',
                        'hellaswag', 'openbookqa', 'arc_easy', 'copa', 'balanced_copa']:
        true_labels, test_labels = get_labels_multi_choice(data)
    else:
        true_labels, test_labels = get_labels_cls(data)
    acc_result = cal_acc(true_labels, test_labels)
    print(acc_result)
    fitlog.add_best_metric({args.split: {'acc': acc_result}})


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--split', type=str, default="test")
    parser.add_argument('--fp', type=str, help='path to the JSON file with scored predictions')
    parser.add_argument('--exp_name', type=str)
    parser.add_argument('--method', type=str)
    parser.add_argument('--plm', type=str)
    parser.add_argument('--iter_scored_num', type=str)
    parser.add_argument('--iter_num', type=str)
    parser.add_argument('--epoch_num', type=str)
    parser.add_argument('--prompt_num', type=str)
    parser.add_argument('--alpha', type=str)
    parser.add_argument('--beilv', type=str)  # 'beilv' is pinyin for 倍率 (multiplier/ratio)
    args = parser.parse_args()
    fitlog.set_log_dir("upr_fitlog/metric_logs/")  # set the directory where logs are stored
    fitlog.add_hyper(args)  # record the ArgumentParser arguments with fitlog
    test()
    fitlog.finish()  # finish the logging
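
# Example invocation (the dataset name and file path below are illustrative,
# not taken from the repository):
#   python cls_test.py --dataset sst-2 --split test --fp output/scored.json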