使用pytorch训练你自己的图像分类模型(包括模型训练、推理预测、误差分析)
待完成
相关代码:https://github.com/xxcheng0708/Pytorch_Image_Classifier_Template
pytorch单精度、半精度、混合精度、单卡、多卡(DP / DDP)、FSDP、DeepSpeed模型训练
相关代码:pytorch-model-train-template
使用pytorch框架搭建一个图像分类模型通常包含以下步骤:
1、数据加载DataSet,DataLoader,数据转换transforms
2、构建模型、模型训练
3、模型误差分析
下面依次来看一下上述几个步骤的实现方法:
一、数据加载、数据增强
a)、有时候torchvision.transform中提供的数据转换方法不能满足项目需要,需要自定义数据转换方法进行数据增强,以下InvertTransform类实现了__init__和__call__方法对图像的像素值进行翻转:
b)、数据加载可以参考pytorch数据集加载之DataSet和DataLoader。
c)、使用torchvision.transforms.Compose组合多种数据转换方法进行数据转换:
其中,train_transform用于训练集,val_transform用于验证集,训练集和验证集要进行相同的数据转换操作。并且在pytorch中提供的transform数据转换方法有的处理目标是Image对象,有的处理目标是tensor,并且经过处理后的数据维度变为[N,C,H,W]。
d)、使用Dataset和DataLoader进行数据加载
二、模型构建
在pytorch中提供了常用的分类网络的创建接口以及预训练权重,一般情况下直接使用预训练权重来初始化backbone网络,修改网络的输出层来适配自己的数据集,仅训练网络输出层或者输出层及其之前几层就可以。模型构建方法实例如下:
在上述代码中,创建了MobileNet_V2网络模型,并且同时下载了预训练模型参数,同时修改了网络中的classifier模块,将输出层的维度修改为自己的数据集类别数量,同时将classifier模块的参数和模型的其他参数进行区分,模型除classifier 模块参数之外的参数使用预训练模型参数进行初始化,classifier模块参数使用随机初始化。然后在优化器中对参数进行分组训练,classifier模块的参数需要重点训练,使用较大的学习率,其他模块的参数是预训练模型参数,并且处于浅层网络中,参数仅需要稍微修改就可以,使用较小的学习率。
三、模型训练
模型训练的主要流程是,从DataLoader中分批加载数据送入模型,将模型预测结果与真实结果使用定义的loss计算损失,基于计算的损失进行梯度反向传播进行参数优化。
四、模型评估,模型准确性评估
五、学习曲线
根据迭代训练过程中保存训练集、验证集的准确率、损失值、学习率绘制学习曲线,来判断模型的学习情况,是否出现过拟合、欠拟合等。
六、误差分析
吴恩达大佬不止一次强调过误差分析的重要性,在分类模型训练过程中进行误差分析,可以清晰的看到误差来源,哪些样本容易被误识别,这些样例有什么规律,从而进行数据调整提升模型性能。在最近吴恩达大佬的一次讲座中,再次强调了误差分析的重要性,讲座中提出当你的模型性能遇到瓶颈时,是以模型为中心调整模型呢?还是以数据为中心调整数据呢?大佬的结论是调整模型几乎不会带来性能的提升,而调整数据能带来模型性能的大幅提升。下面就看一下图像分类模型的误差分析示例代码:
该方法根据模型预测结果以及样本的真实标签,计算样本的正确与否,对于预测错误的样本,保存样本的真实标签,预测标签,从而能够清晰的看到那些被分类错误的样本都被误分类成了什么。
就写到这吧!有疑问欢迎随时交流。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)