Plotting a confusion matrix #14

Open
jayboxyz opened this issue Nov 11, 2019 · 2 comments

Comments

@jayboxyz
Owner

jayboxyz commented Nov 11, 2019

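Note 1:

The confusion-matrix plotting code below may fail with the following message, which typically appears when the labels contain Chinese characters that the default font cannot render:

missing from current font.

Adding the following code resolves it (the SimHei font must be installed on the system):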

plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
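
SimHei is not installed on every machine, so the setting above can silently fall back to a font that still lacks the glyphs. A minimal sketch of a more defensive setup follows. The helper name `pick_cjk_font` and the candidate font list are assumptions made for illustration, not part of the original post; setting `axes.unicode_minus` to `False` is a common companion setting because some CJK fonts lack the Unicode minus sign.

```python
import matplotlib.pyplot as plt
from matplotlib import font_manager


def pick_cjk_font(candidates=("SimHei", "Microsoft YaHei", "Noto Sans CJK SC", "WenQuanYi Micro Hei")):
    """Return the first candidate font that matplotlib can find, or None."""
    installed = {entry.name for entry in font_manager.fontManager.ttflist}
    for name in candidates:
        if name in installed:
            return name
    return None


chosen = pick_cjk_font()
if chosen is not None:
    # Put the CJK-capable font first and keep the existing fallbacks after it.
    plt.rcParams["font.family"] = ["sans-serif"]
    plt.rcParams["font.sans-serif"] = [chosen] + list(plt.rcParams["font.sans-serif"])
# Draw minus signs as ASCII hyphens so they render even if the font lacks U+2212.
plt.rcParams["axes.unicode_minus"] = False
```

If `chosen` is `None`, no font in the candidate list is installed and the labels will still render as empty boxes until a suitable font is added.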

Note 2:

When the figure is saved with plt.savefig as in the code below, the saved image turns out to be completely blank.

import matplotlib.pyplot as plt

""" some plotting code """

plt.show()
plt.savefig("filename.png")

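Cause: plt.savefig() is called after plt.show(). With an interactive backend, plt.show() blocks until the window is closed and the figure is then discarded, so the later plt.savefig() operates on a new, empty figure (blank axes) and that is what gets written to disk.

Fix: call plt.savefig() before plt.show().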
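
The ordering can be sketched as follows. This is only an illustration: the function name `plot_and_save`, the demo data, and the output path are hypothetical, and the Agg backend is selected purely so that the sketch also runs on a machine without a display. Saving through the Figure object (`fig.savefig`) rather than `plt.savefig` also avoids depending on which figure is "current".

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_and_save(path):
    """Draw a small figure, save it to `path`, then show it."""
    fig, ax = plt.subplots()
    ax.plot([0, 1, 2, 3], [0, 1, 4, 9])
    ax.set_title("demo")
    fig.savefig(path, dpi=150, bbox_inches="tight")  # save first, through the Figure object
    plt.show()      # show last; with an interactive backend this blocks until the window closes
    plt.close(fig)  # release the figure once it is no longer needed
    return path


out_path = plot_and_save(os.path.join(tempfile.gettempdir(), "savefig_demo.png"))
```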

@jayboxyz
Owner Author

jayboxyz commented Nov 11, 2019

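1. Plotting a confusion matrix

import torch

# Accumulate one batch of predictions into the confusion matrix.
# Rows are the true class and columns the predicted class, which matches
# scikit-learn's convention and the axis labels of the plots further down.
def confusion_matrix(preds, labels, conf_matrix):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[t, p] += 1
    return conf_matrix

conf_matrix = torch.zeros(10, 10)  # 10 classes
with torch.no_grad():  # evaluation only, no gradients needed
    for data, target in test_loader:
        output = fullModel(data.to(device))
        # move the outputs back to the CPU, where conf_matrix lives
        conf_matrix = confusion_matrix(output.cpu(), target, conf_matrix)

After the loop, conf_matrix holds the confusion matrix counts.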
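
The per-sample Python loop above becomes slow on large test sets. As a sketch of the same counting done in one vectorized step, the NumPy version below encodes each (true, predicted) pair as a single integer and counts the codes with `np.bincount`. The function name `confusion_matrix_np` is hypothetical, and the sketch assumes the labels are integers in `[0, num_classes)` with rows indexed by the true class.

```python
import numpy as np


def confusion_matrix_np(y_true, y_pred, num_classes):
    """Count (true, predicted) pairs; rows = true class, columns = predicted class."""
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)
    # Encode each pair as one integer: true * num_classes + pred.
    codes = y_true * num_classes + y_pred
    counts = np.bincount(codes, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


# Small usage example with three classes.
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
cm = confusion_matrix_np(fact, guess, num_classes=3)
print(cm)
```

With PyTorch tensors, the predictions of all batches can be concatenated, moved to the CPU, and converted with `.numpy()` before calling the function.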
image

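With the confusion matrix values in hand, the next step is visualization, for which we use seaborn. Because conf_matrix is a tensor here, it has to be converted to a NumPy array first.

import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt

# Attack2Index is not defined in this snippet. It is assumed to be a dict,
# defined elsewhere in the project, that maps class names to class indices;
# substitute your own list of class names here.
class_names = list(Attack2Index.keys())
df_cm = pd.DataFrame(conf_matrix.numpy(), index=class_names, columns=class_names)
plt.figure(figsize=(10, 7))
sn.heatmap(df_cm, annot=True, cmap="BuPu")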

image
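
The heatmap code relies on a name-to-index mapping called Attack2Index that the post never defines. The sketch below shows one way such a mapping could look and how to derive row and column labels from it. The class names are invented for illustration and the helper `names_in_index_order` is hypothetical. Sorting by index, instead of trusting the dict's key order, guarantees that label k sits on row and column k.

```python
# Hypothetical mapping; the real class names come from your own dataset.
Attack2Index = {"normal": 0, "dos": 1, "probe": 2}


def names_in_index_order(name_to_index):
    """Return the class names sorted by their index, so position k is class k."""
    return [name for name, _ in sorted(name_to_index.items(), key=lambda kv: kv[1])]


class_names = names_in_index_order(Attack2Index)
print(class_names)  # ['normal', 'dos', 'probe']
```

The resulting list can be passed as both `index` and `columns` when building the DataFrame for the heatmap.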

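A more polished visualization of the confusion matrix:

import itertools
import numpy as np
import matplotlib.pyplot as plt

# Plot the confusion matrix
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    Input
    - cm : the computed confusion matrix (NumPy array; rows = true labels, columns = predicted labels)
    - classes : the class names for the rows and columns
    - normalize : True shows row-normalized fractions, False shows raw counts
    """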
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # call tight_layout last so the axis labels are included in the layout
    plt.tight_layout()

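Test (cnf_matrix and attack_types are not defined in this post; presumably they are the confusion matrix as a NumPy array and the list of class names):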

plot_confusion_matrix(cnf_matrix, classes=attack_types, normalize=True, title='Normalized confusion matrix')

image
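
One caveat with `normalize=True`: the expression `cm / cm.sum(axis=1)[:, np.newaxis]` divides by zero for any class that has no true samples, which fills that row with NaN. A minimal sketch of a safer row normalization is shown below. The helper name `normalize_rows` is hypothetical, and leaving empty rows as zeros is a design choice rather than something the original code does.

```python
import numpy as np


def normalize_rows(cm):
    """Divide each row by its sum; rows that sum to zero stay all zeros."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # `where` skips the division for empty rows, so `out` keeps its zeros there.
    return np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)


counts = [[3, 1, 0],
          [0, 0, 0],   # a class with no true samples
          [1, 1, 2]]
print(normalize_rows(counts))
```

Note also that the `'d'` format used for raw counts only works on integer arrays, so a float matrix (for example one converted from `torch.zeros`) needs `cm.astype(int)` before plotting without normalization.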

2. In addition to the above, see the article "混淆矩阵及绘图" (Confusion matrix and plotting)

image

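from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt


# Predicted labels (the output of predict)
guess = [1, 0, 1, 2, 1, 0, 1, 0, 1, 0]
# Ground-truth labels
fact = [0, 1, 0, 1, 2, 1, 0, 1, 0, 1]
# Class list, sorted so that it lines up with the matrix rows and columns
classes = sorted(set(fact) | set(guess))
# Build the confusion matrix. sklearn expects (y_true, y_pred),
# so rows are the true labels and columns are the predictions.
confusion = confusion_matrix(fact, guess, labels=classes)
# Heat map; cmap picks the colors (plt.cm.gray or the reversed plt.cm.gray_r also work)
plt.imshow(confusion, cmap=plt.cm.Blues)
# Pay attention here:
# ticks are the positions along an axis,
# labels are the text shown at those positions
indices = range(len(confusion))
# First argument: the tick positions (must be an iterable)
# Second argument: the labels to display; label k is drawn at position k
plt.xticks(indices, classes)
plt.yticks(indices, classes)
# Color bar: the strip beside the plot that maps colors to counts
plt.colorbar()
# Axis titles; imshow draws rows along y and columns along x
plt.xlabel('guess')
plt.ylabel('fact')
# Write each count into its cell so the plot is easier to read
for row in range(len(confusion)):
    for col in range(len(confusion[row])):
        plt.text(col, row, confusion[row][col], ha='center', va='center')

# Show the figure
plt.show()

# PS: double-check that the tick labels are `classes`, in the right order.
# If the data is correct but the axis mapping is wrong, all the work is wasted,
# and a single such mistake makes it much harder to convince anyone of the result.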

3. How to draw a good confusion matrix in Python

image

'''Compute and plot a confusion matrix.
labels.txt: one label name per line.
predict.txt: one "predict_label true_label" pair per line.
'''
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# Load the label names.
with open('labels.txt', 'r') as f:
    labels = [line.strip() for line in f if line.strip()]

# Load the predicted and true labels.
y_true = []
y_pred = []
with open('predict.txt', 'r') as f:
    for line in f:
        parts = line.split()
        if len(parts) < 2:
            continue  # skip blank or malformed lines
        y_pred.append(int(parts[0]))
        y_true.append(int(parts[1]))

# Minor ticks sit on the cell borders so that the grid separates the cells.
tick_marks = np.arange(len(labels)) + 0.5

def plot_confusion_matrix(cm, title='Confusion Matrix', cmap=plt.cm.binary):
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    xlocations = np.arange(len(labels))
    plt.xticks(xlocations, labels, rotation=90)
    plt.yticks(xlocations, labels)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

cm = confusion_matrix(y_true, y_pred)
print(cm)
np.set_printoptions(precision=2)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print(cm_normalized)
plt.figure(figsize=(12, 8), dpi=120)
# Set the font size of the tick labels.
# for label in plt.gca().xaxis.get_ticklabels():
#     label.set_fontsize(8)

# Write the normalized value into each cell.
ind_array = np.arange(len(labels))
x, y = np.meshgrid(ind_array, ind_array)

for x_val, y_val in zip(x.flatten(), y.flatten()):
    c = cm_normalized[y_val][x_val]
    if c > 0.01:
        plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=7, va='center', ha='center')

# Offset the grid to the cell borders and hide the tick marks.
plt.gca().set_xticks(tick_marks, minor=True)
plt.gca().set_yticks(tick_marks, minor=True)
plt.gca().xaxis.set_ticks_position('none')
plt.gca().yaxis.set_ticks_position('none')
plt.grid(True, which='minor', linestyle='-')
plt.gcf().subplots_adjust(bottom=0.15)

plot_confusion_matrix(cm_normalized, title='Normalized confusion matrix')
# Show the confusion matrix.
plt.show()

The following is for reference:

1、https://github.com/Tony607/Focal_Loss_Keras/blob/master/src/keras_focal_loss.ipynb
image

@upupbo

upupbo commented Mar 5, 2021

NameError: name 'Attack2Index' is not defined. How can I fix this?
