目录

1. 概要

2. 二分类情况下的混淆矩阵

3. 多分类情况下的混淆矩阵

4. 混淆矩阵的图视化

4.1 sklearn. confusion_matrix() and plot_confusion_matrix()

4.2 seaborn heatmap()

5. Next


1. 概要

        在机器学习领域,混淆矩阵(confusion matrix)是一种评判模型结果指标的可视化工具,属于模型评估的一部分,多用于判断分类器(Classifier)的优劣。特别用于监督学习,在无监督学习一般叫做匹配矩阵。

        简而言之,混淆矩阵就是分别统计分类模型归错类,归对类的观测值个数,然后把结果放在一个表里展示出来。这个表就是混淆矩阵。

        混淆矩阵以行代表实际类别结果(或 Ground Truth),以列表示实际分类预测结果,其中每一个元素(i,j)所存储的值表示实际类别为type(i)而被分类器识别为type(j)的个数。

        混淆矩阵是机器学习领域中评判模型结果的指标,属于模型评估的一部分,多用于判断分类器(Classifier)的优劣。

2. 二分类情况下的混淆矩阵

        二分类的混淆矩阵例如下图所示:

图 1 二分类混淆矩阵 

  1. TP(True Positive):将正(Positive)类预测为正类,真实为0,预测也为0(此处假定0表示正类,1表示负类)
  2. FN(False Negative):将正类预测为负类,真实为0,预测为1
  3. FP(False Positive):将负(Negative)类预测为正类, 真实为1,预测为0
  4. TN(True Negative):将负类预测为负类,真实为1,预测也为1

        其中,用#()表示对应该种情况的发生个数。

        当然,在二分类问题中所谓的“Positive(正)”和“Negative(负)”并不像其字面含义那样具有倾向性,即并不是“正”有积极向上或者更好的含义。将类别1标为正类、类别2标为负类,或者将类别2标为正类、类别1标为负类,并没有实质性的差异。

        比如说,在“猫-vs-狗”的2分类问题中,无论把猫还是狗标识为“正”类都是可以的。

        但是在有些情况下,人们的确会习惯于按照日常习惯的倾向来进行正、负的标识,但是也仍然不是“正”有积极向上或者更好的含义。比如说,说在癌症检测问题中,两个类别分别为‘阳性’和‘阴性’,将‘阳性’标识为‘正’类而‘阴性’标识为‘负’类就是一个自然的选择,你当然不能说‘阳性’代表更好更积极向上。另外一个例子,在‘cat-or-not’分类中两个类别分别是‘是猫’和‘非猫’,那么将‘是猫’标识为‘正’类而‘非猫’标识为‘负’类就是一个自然的选择,虽然从逻辑上来说并没有任何必然性。

        当然,‘正’、‘负’类别的标识方法以及可能由字面意义带来的歧义其实只是二分类问题特有的,在多分类问题中就不存在这一问题了。

3. 多分类情况下的混淆矩阵

        多分类的混淆矩阵例如下图所示:

图 2 多分类情况下的混淆矩阵

        每一行之和表示该类别的真实样本数量#Ck_gt,每一列之和表示被预测为该类别的样本数量#Ck_pred.

        对角线上的元素标识正确分类的结果,非对角线的元素都标识错误分类的结果。

4. 混淆矩阵的可视化

4.1 sklearn. confusion_matrix() and plot_confusion_matrix()

        sklearn.metrics包中提供了confusion_matrix() 方法用于根据预测结果以及标签真值)(Ground Truth)生成混淆矩阵。而另一个方法plot_confusion_matrix()则用于直接绘制图示化的混淆矩阵。

        

        以下代码中先利用make_classification()创建了一个二分类的玩具数据集,然后实例化了一个支持向量机分类器,对齐进行训练和预测。然后调用plot_confusion_matrix()绘制该分类应用于该数据集的的测试集时的混淆矩阵。     

# Example 1: Using sklearn plot_confusion_matrix 
# Ref: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.plot_confusion_matrix.html
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
        X, y, random_state=0)
clf = SVC(random_state=0)
clf.fit(X_train, y_train)

plot_confusion_matrix(clf, X_test, y_test)  
plt.show()    

        运行后的效果如下所示:

 图3 plot_confusion_matrix()绘制混淆矩阵示例

4.2 seaborn heatmap()

        Ref: https://www.stackvidhya.com/plot-confusion-matrix-in-python-and-why/

        利用seaborn库中的heatmap绘制功能画出来的图会更漂亮一些。

        Seaborn heatmap()方法的调用参数如下所示(data为必须的参数,其余为可选参数用于控制图示效果选项。更多的参数请参考Seaborn heatmap()文档):

  • data – A rectangular dataset that can be coerced into a 2d array. Here, you can pass the confusion matrix you already have
  • annot=True – To write the data value in the cell of the printed matrix. By default, this is False.
  • cmap=Blues – This is to denote the matplotlib color map names. 

        heatmap()方法返回matplotlib axes,可以存储于一个变量,以便于后面进一步修改图示效果选项,比如说,设置titlex-axis and y-axis labels and tick labels for x-axis and y-axis. 注意,也可以在heatmap()的参数列表中用ax参数来指定用于存储matplotlib axes的变量,如以下例所示。

  • Title – Used to label the complete image. Use the set_title() method to set the title.
  • Axes-labels – Used to name the x axis or y axis. Use the set_xlabel() to set the x-axis label and set_ylabel() to set the y-axis label.
  • Tick labels – Used to denote the datapoints on the axes. You can pass the tick labels in an array, and it must be in ascending order. Because the confusion matrix contains the values in the ascending order format. Use the xaxis.set_ticklabels() to set the tick labels for x-axis and yaxis.set_ticklabels() to set the tick labels for y-axis.

        最后需要调用plot.show() 方法以显示该图.

# Example2: Using seaborn heatmap
# Ref: https://www.stackvidhya.com/plot-confusion-matrix-in-python-and-why/
import seaborn as sns
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

print('Example2: Using seaborn heatmap for confusion matrix visualization')
sns.set()
f,ax = plt.subplots()
# y_true = [0,0,1,2,2,0,2,0,1]
y_pred = clf.predict(X_test)
C2 = confusion_matrix(y_test,y_pred,labels=[0,1])
# print C2
print(C2)
sns.heatmap(C2,annot=True,ax=ax) #plot heatmap
# ax.plot(C2)

ax.set_title('Seaborn Confusion Matrix with labels\n\n');
ax.set_xlabel('predict') 
ax.set_ylabel('true') #
plt.show()

         运行后的效果如下图所示:

 图4 基于seaborn heatmap()绘制混淆矩阵示例

        seaborn的热图绘制用起来稍微麻烦一点,但是效果是不是要更好一丢丢^-^

5. Next

        基于混淆矩阵可以计算很多分类器性能评估指标,如accuracy,precision,等等。欲知后事如何,且听下回分解。

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐