getEnsembleAccuracy.py
import os
import numpy as np
import argparse
import math

parser = argparse.ArgumentParser()
parser.add_argument('data_path', metavar='DATA',
                    help='path to data')
parser.add_argument('-s', '--soft', action='store_true', default=False,
                    help='set for soft bagging, otherwise hard bagging')
parser.add_argument('-n', '--n_models', default=16, type=int,
                    help='# of models (default=16)')
parser.add_argument('-p', '--pkl_file', default='filelist', type=str,
                    help='path to pickle file')
parser.add_argument('-nd', '--n_data_points', default=10000, type=int,
                    help='# of problems (default=10,000)')
parser.add_argument('--nway', default=51, type=int,
                    help='# of classes per problem (default=51)')
parser.add_argument('--kquery', default=1, type=int,
                    help='# of queries per class (default=1)')
args = parser.parse_args()
n_models = args.n_models
n_data_points = args.n_data_points
pkl_file = args.pkl_file
data_path = args.data_path
nway = args.nway
kquery = args.kquery


def myEntropy(count_labels, base=None):
    """Computes entropy of label distribution."""
    count_labels = count_labels[count_labels != 0]
    tot_labels = sum(count_labels)
    if tot_labels <= 1:
        return 0
    probs = count_labels / tot_labels
    n_classes = np.count_nonzero(probs)
    if n_classes <= 1:
        return 0
    ent = 0.
    # Compute entropy
    base = math.e if base is None else base
    for i in probs:
        ent -= i * math.log(i, base)
    return ent
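
# Example (illustrative, not part of the original script): for counts [3, 1]
# the probabilities are [0.75, 0.25], so myEntropy(np.array([3, 1]), base=2)
# returns about 0.811 bits.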


def test():
    # Majority-vote threshold for hard bagging.
    threshold = math.ceil(n_models / 2)
    # Ground-truth query labels: the first (nway - 1) * kquery queries are
    # class 0, the remaining ones are class 1.
    labels_query = [0] * ((nway - 1) * kquery)
    labels_query.extend([1] * ((nway - 1) * kquery))

    # Load each ensemble member's query predictions.
    query_preds_list = []
    for i in range(n_models):
        preds_file = os.path.join(data_path, 'WHOLE',
                                  'queryPreds_' + pkl_file + '_model' + str(i) + '.npy')
        query_preds_list.append(np.load(preds_file))

    if args.soft:
        # Soft bagging: average the class probabilities over models, then take
        # the argmax of the averaged distribution.
        query_preds = np.mean(np.array(query_preds_list), axis=0)
        query_predictions = np.argmax(query_preds, axis=-1)
        # entropy_probs = 0
        # for i in range(n_data_points):
        #     for j in range(2 * kquery):
        #         entropy_probs += entropy(query_probs[i][j])
        # print('Average ensemble entropy:', entropy_probs / (n_data_points * 2 * (nway - 1) * kquery))
    else:
        # Hard bagging: each model votes with its argmax; predict class 1 when
        # the positive votes reach the majority threshold.
        query_preds_list = np.argmax(np.array(query_preds_list), axis=-1)
        query_votes = np.count_nonzero(query_preds_list, axis=0)
        query_predictions = (query_votes >= threshold)

    # Per-problem accuracy over the 2 * (nway - 1) * kquery queries.
    correct_predictions = np.sum(query_predictions == labels_query, axis=1)
    test_accs = correct_predictions / (2 * (nway - 1) * kquery)
    print('Mean test accuracy:', np.mean(test_accs))
    stds = np.std(test_accs)
    # 95% confidence interval, reported in percentage points.
    ci95 = 1.96 * stds * 100 / np.sqrt(n_data_points)
    print('stds:', stds)
    print('ci95:', ci95)


def main():
    test()


if __name__ == "__main__":
    main()
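
# Usage sketch (assumed invocation; '/path/to/data' is a placeholder and the
# directory layout is inferred from the loader in test()):
#   python getEnsembleAccuracy.py /path/to/data -s -n 16 -p filelist --nway 51 --kquery 1
# This expects /path/to/data/WHOLE/queryPreds_filelist_model{0..15}.npy and
# prints the mean soft-bagging accuracy, its standard deviation, and a 95%
# confidence interval (in percentage points).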