-
Notifications
You must be signed in to change notification settings - Fork 8
/
CMC.py
82 lines (72 loc) · 3.14 KB
/
CMC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
# Default styling palettes picked up by the CMC constructor:
# matplotlib colour specs and matplotlib marker symbols, indexed per curve.
default_color = [
    'r', 'g', 'b', 'c',
    'm', 'y', 'orange', 'brown',
]
default_marker = [
    '*', 'o', 's', 'v',
    'X', '*', '.', 'P',
]
class CMC:
    """Plot Cumulative Match Characteristic (CMC) curves for several methods.

    ``cmc_dict`` maps a method name (shown in the legend) to a sequence of
    matching rates; element ``k`` of a sequence is plotted at rank ``k + 1``.
    The y-axis is fixed to [0, 1] and the legend multiplies the first value by
    100, so the rates appear to be expected as fractions -- confirm with callers.
    """

    def __init__(self, cmc_dict, color=default_color, marker=default_marker):
        """Store the curves and the styling palettes.

        Args:
            cmc_dict: mapping of method name -> sequence of matching rates.
            color: matplotlib colour specs, indexed per curve (wraps around).
            marker: matplotlib marker symbols, indexed per curve (wraps around).
        """
        # Copy the palettes so this instance never aliases the shared
        # module-level default lists (or a caller's list).
        self.color = list(color)
        self.marker = list(marker)
        self.cmc_dict = cmc_dict

    def _style(self, index, count):
        """Return the ``(color, marker)`` pair for the curve at *index* of *count*.

        The last curve gets palette entry 0; the other curves get entries
        1, 2, ... in insertion order. Indices wrap, so having more curves than
        palette entries no longer raises IndexError.
        """
        slot = 0 if index == count - 1 else index + 1
        return (self.color[slot % len(self.color)],
                self.marker[slot % len(self.marker)])

    def _draw(self, title, rank, xlabel, ylabel, show_grid):
        """Create a new figure with every CMC curve truncated to *rank*.

        Returns:
            The ``(fig, ax)`` pair; the caller decides whether to show or save it.
        """
        fig, ax = plt.subplots()
        fig.suptitle(title)
        ax.set_ylim(0, 1.0)
        ax.set_xlim(1, rank)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        # Ticks every 5 ranks starting at 0. Set after the x-limits, as in the
        # original plt.* sequence, so the resulting view is unchanged.
        ax.set_xticks(list(range(0, rank + 1, 5)))
        ax.grid(show_grid)

        names = list(self.cmc_dict)
        handles = []
        for index, name in enumerate(names):
            rates = self.cmc_dict[name]
            # A slice already keeps the whole curve when it is shorter than rank.
            curve = rates[:rank]
            ranks = list(range(1, len(curve) + 1))
            color, marker = self._style(index, len(names))
            # Keep the line in a local: the old code stored it in globals()
            # under the method name, which could clobber module names (e.g. plt).
            line = mlines.Line2D(
                ranks, curve, color=color, marker=marker,
                label='{:.1f}% {}'.format(rates[0] * 100, name),
            )
            ax.add_line(line)
            handles.append(line)
        ax.legend(handles=handles)
        return fig, ax

    def plot(self, title, rank=20, xlabel='Rank', ylabel='Matching Rates (%)', show_grid=True):
        """Draw the CMC curves and display them with ``plt.show()``.

        Args:
            title: figure title.
            rank: highest rank plotted; longer curves are truncated.
            xlabel: x-axis label.
            ylabel: y-axis label.
            show_grid: whether to draw the axes grid.
        """
        self._draw(title, rank, xlabel, ylabel, show_grid)
        plt.show()

    def save(self, title, filename,
             rank=20, xlabel='Rank',
             ylabel='Matching Rates (%)', show_grid=True,
             save_path=None, format='png', **kwargs):
        """Draw the CMC curves and write them to ``save_path/filename.format``.

        Uses the same curve styling as :meth:`plot`.

        Args:
            title: figure title.
            filename: output file name without extension.
            rank: highest rank plotted; longer curves are truncated.
            xlabel: x-axis label.
            ylabel: y-axis label.
            show_grid: whether to draw the axes grid.
            save_path: output directory; ``None`` (default) means the current
                working directory at call time.
            format: file format and extension passed to ``Figure.savefig``.
            **kwargs: forwarded to ``Figure.savefig``.
        """
        if save_path is None:
            # Resolve at call time; ``os.getcwd()`` as a signature default was
            # frozen to the directory that was current at import time.
            save_path = os.getcwd()
        fig, _ = self._draw(title, rank, xlabel, ylabel, show_grid)
        try:
            fig.savefig(os.path.join(save_path, filename + '.' + format),
                        format=format,
                        bbox_inches='tight',
                        pad_inches=0, **kwargs)
        finally:
            # Close the figure so repeated saves do not accumulate open figures.
            plt.close(fig)