【Pytorch】进阶学习:深入解析 sklearn.metrics 中的 confusion_matrix(混淆矩阵)
在这里插入图片描述

🌈 个人主页:高斯小哥
🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化Python基础【高质量合集】PyTorch零基础入门教程👈 希望得到您的订阅和支持~
💡 创作高质量博文(平均质量分92+),分享更多关于深度学习、PyTorch、Python领域的优质内容!(希望得到您的关注~)


📚 一、混淆矩阵简介

  在机器学习和数据科学领域,混淆矩阵(Confusion Matrix)是一种常用的性能度量工具,尤其在分类问题中。它提供了分类模型性能的可视化表示,帮助我们深入理解模型的分类效果。混淆矩阵以矩阵的形式展示了真实类别与模型预测类别之间的关系

  混淆矩阵的每一行代表实际类别,每一列代表预测类别。通过混淆矩阵,我们可以清晰地看到每个类别的真正例(True Positive, TP)、假正例(False Positive, FP)、真反例(True Negative, TN)和假反例(False Negative, FN)的数量。

💻 二、使用sklearn.metrics中的confusion_matrix

  在Python的机器学习库sklearn中,confusion_matrix函数是一个非常方便的工具,用于计算混淆矩阵。下面我们将通过一个简单的例子来展示如何使用它。

  • 首先,我们需要导入必要的库,并创建一个简单的分类问题:

    # 导入所需的库  
    from sklearn import datasets           # 导入sklearn库中的datasets模块,用于加载数据集  
    from sklearn.model_selection import train_test_split  # 导入train_test_split函数,用于划分数据集为训练集和测试集  
    from sklearn.linear_model import LogisticRegression  # 导入逻辑回归模型类  
    from sklearn.metrics import confusion_matrix  # 导入混淆矩阵计算函数  
    import seaborn as sns  # 导入seaborn库,用于绘制热图  
    import matplotlib.pyplot as plt  # 导入matplotlib库,用于绘图  
      
    # 加载鸢尾花数据集  
    iris = datasets.load_iris()  # 加载鸢尾花数据集  
    X = iris.data  # 获取特征数据  
    y = iris.target  # 获取目标标签  
      
    # 划分训练集和测试集  
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)  
    # 使用train_test_split函数将数据集划分为训练集和测试集,测试集占比20%,random_state设置随机种子以保证结果可复现  
      
    # 初始化逻辑回归模型  
    model = LogisticRegression()  # 创建逻辑回归模型对象  
      
    # 训练模型  
    model.fit(X_train, y_train)  # 使用训练数据拟合模型  
      
    # 预测测试集  
    y_pred = model.predict(X_test)  # 使用训练好的模型对测试集进行预测  
    
  • 接下来,我们使用confusion_matrix函数来计算混淆矩阵:

    # 计算混淆矩阵  
    cm = confusion_matrix(y_test, y_pred)  # 计算真实标签和预测标签之间的混淆矩阵  
      
    # 输出混淆矩阵  
    print("Confusion Matrix:")  
    print(cm)  # 打印混淆矩阵  
    

输出将是一个二维数组,表示每个类别的真正例、假正例、真反例和假反例的数量。

📊 三、可视化混淆矩阵

  为了更好地理解混淆矩阵,我们通常会将其可视化。使用Seaborn库可以很方便地绘制混淆矩阵的热图:

  • 代码示例:

    # 设置类别标签  
    class_names = iris.target_names  # 获取数据集中的类别名称  
      
    # 绘制混淆矩阵热图  
    plt.figure(figsize=(10, 7))  # 创建一个指定大小的画布  
    sns.heatmap(cm, annot=True, xticklabels=class_names, yticklabels=class_names, cmap='Blues', fmt="d")  
    # 使用seaborn库中的heatmap函数绘制混淆矩阵的热图  
    # annot=True表示在热图中显示数值,xticklabels和yticklabels分别设置x轴和y轴的标签,cmap设置颜色映射,fmt设置数值格式  
      
    plt.xlabel('Predicted')  # 设置x轴标签为"Predicted"  
    plt.ylabel('True')  # 设置y轴标签为"True"  
    plt.show()  # 显示图形
    

这将生成一个热图,其中颜色深浅表示对应类别的实例数量。通过热图,我们可以直观地看到模型对每个类别的分类效果。

  • 输出如下:

在这里插入图片描述

问题:我想让上图的字体变大一点,应该怎么办啊?这图看得好难受呀😣😣😣,最好变成下图这样👇👇👇

在这里插入图片描述
  如果您也有类似的需求,博主强烈推荐您访问博客文章【matplotlib】一文解决 plt.xticks 调整字体大小,您会找到相应的解决方案。

🎯 四、混淆矩阵的性能指标

  混淆矩阵不仅展示了分类的详细情况,还可以从中提取出多个性能指标,如准确率(Accuracy)、精确率(Precision)、召回率(Recall)和F1分数(F1 Score)。这些指标可以帮助我们更全面地评估模型的性能。

  • 例如,我们可以使用sklearn.metrics中的函数来计算这些指标:

    # 导入所需的sklearn库中的评估指标函数
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    
    # 使用accuracy_score函数计算模型在测试集上的准确率
    # 准确率是所有预测正确的样本数占总样本数的比例
    accuracy = accuracy_score(y_test, y_pred)
    print(f"Accuracy: {accuracy:.2f}")  # 打印准确率的值,保留两位小数
    
    # 计算每个类别的精确率
    # 精确率是指预测为正例的样本中真正为正例的比例
    # average=None表示返回每个类别的精确率,而不是平均值
    precisions = precision_score(y_test, y_pred, average=None)
    print(f"Precision: {precisions}")  # 打印每个类别的精确率
    
    # 计算每个类别的召回率
    # 召回率是指真正为正例的样本中被预测为正例的比例
    # average=None同样表示返回每个类别的召回率,而不是平均值
    recalls = recall_score(y_test, y_pred, average=None)
    print(f"Recall: {recalls}")  # 打印每个类别的召回率
    
    # 计算每个类别的F1分数
    # F1分数是精确率和召回率的调和平均值,用于综合评估模型的性能
    # average=None表示返回每个类别的F1分数,而不是平均值
    f1s = f1_score(y_test, y_pred, average=None)
    print(f"F1 Score: {f1s}")  # 打印每个类别的F1分数
    

    代码解释:

    1. accuracy_score:用于计算模型预测的准确率,即正确预测的样本数与总样本数的比例。

    2. precision_score:用于计算模型的精确率,精确率关注模型预测为正例的样本中,有多少是真正的正例。在多分类问题中,如果average=None,则对每个类别分别计算精确率,返回一个数组。

    3. recall_score:用于计算模型的召回率,召回率关注所有真正的正例样本中,有多少被模型正确预测为正例。在多分类问题中,average=None意味着对每个类别分别计算召回率。

    4. f1_score:用于计算模型的F1分数,F1分数是精确率和召回率的调和平均值,能够综合评估模型在精确率和召回率上的性能。当average=None时,对每个类别分别计算F1分数。

通过打印每个类别的精确率、召回率和F1分数,我们可以更详细地了解模型在不同类别上的性能表现,有助于分析模型在哪些类别上表现较好,哪些类别上表现较差,进而进行针对性的优化。

🔍 五、混淆矩阵的局限性及改进方法

  虽然混淆矩阵是一个强大的工具,但它也有一些局限性。例如,当数据集中的类别分布不均衡时,准确率可能不是一个很好的指标。此外,混淆矩阵本身只提供了分类结果的统计信息,而没有告诉我们为什么模型会做出错误的预测

为了改进模型的性能,我们可以采取一些措施:

  1. 重采样技术:当数据集中类别分布不均衡时,可以使用过采样(oversampling)或欠采样(undersampling)技术来平衡类别分布,从而提高模型的性能评估准确性。

  2. 阈值调整:在某些情况下,模型输出的概率值可能需要调整,以改变分类的阈值。这可以通过绘制精确率-召回率曲线(PR曲线)或接收者操作特征曲线(ROC曲线)来确定最佳阈值。

  3. 使用集成方法:集成学习(如随机森林、梯度提升机等)通常能够提供更稳定和准确的预测结果,因为它们结合了多个模型的输出。

  4. 特征工程和选择:通过选择更具代表性的特征或进行特征转换,可以提高模型的分类性能。

📚 六、混淆矩阵的扩展应用

  混淆矩阵不仅限于基本的分类问题。在更复杂的任务中,如多标签分类、多类分类以及序列标注任务中,混淆矩阵的变种或扩展形式也可以用来评估模型的性能。

🚀 七、总结与展望

  混淆矩阵是机器学习领域中的一个重要工具,它提供了对分类模型性能的深入洞察。通过计算和可视化混淆矩阵,我们可以更好地理解模型的分类效果,并据此进行模型优化和改进。

  随着机器学习技术的不断发展,未来我们可能会看到更多关于混淆矩阵的扩展和变种,以适应更复杂的分类任务和性能评估需求。因此,掌握混淆矩阵的基本原理和应用方法,对于机器学习从业者来说是非常有益的。

  通过本文的学习,相信你已经对sklearn.metrics中的confusion_matrix有了更深入的了解。希望你在实际的项目中能够灵活运用这一工具,提升模型的分类性能,并不断探索新的应用场景和扩展方法。

  记住,混淆矩阵只是评估模型性能的一个方面,还需要结合其他指标和工具来全面评估模型的优劣。在实际应用中,不断探索和尝试,才能找到最适合你任务的模型和评估方法。

  最后,希望你在机器学习的道路上越走越远,不断取得新的进步和成就!🚀

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐