【MobileNetV2 Mydataset】MobileNetV2训练自定义分类数据集
文章目录:1 分类数据集准备;2 获取训练与验证图片路径及标签;3 Dataset类与DataLoader类的理解(3.1 Dataset类;3.2 DataLoader类;3.3 Dataset与DataLoader综合使用简单示例);4 MobileNetV2介绍;5 训练总体流程;6 推理一张图片;7 感谢链接
文章目录
1 分类数据集准备
期待的分类数据集样式如下,注意,验证集需要知道图片类别。
data
├── train
│   ├── class_name_1
│   │   ├── 1.jpg
│   │   └── 2.jpg
│   ├── class_name_2
│   │   ├── 1.jpg
│   │   └── 2.jpg
│   └── ...
└── val
    ├── class_name_1
    │   └── 1.jpg
    ├── class_name_2
    │   └── 1.jpg
    └── ...
以花分类数据集(下方给出数据集下载链接)为例,在拿到所有图片后,需要进行数据集格式的调整。
链接:https://pan.baidu.com/s/16T3ycHeID0Y06JTLcBAXRQ
提取码:ynrj
数据集解压后,里面每个文件夹名就是其内部图片的类别。
在文件夹里面新建 split_train_val_data.py(脚本中的数据集路径与输出路径需按实际情况修改),运行 split_train_val_data.py 后,即可得到期待的数据集存放格式,同时得到一个类别索引文件 class_indices.json,便于查看数字与类别的对应关系。记得把 class_indices.json 文件复制到 mobilenetv2 文件夹下,便于使用。
split_train_val_data.py 内容如下:
import os
import random
import json
import shutil # 用于复制图片
def read_split_data(root: str, val_rate: float = 0.2,
                    output_root: str = 'D:/DeepLearning/classification/mobilenetv2/data',
                    json_path: str = 'class_indices.json'):
    """Split a one-folder-per-class image dataset into train/val copies.

    Every sub-directory of ``root`` is treated as one class. For each class,
    ``int(n_images * val_rate)`` randomly chosen images are copied to
    ``<output_root>/val/<class>`` and the remaining ones to
    ``<output_root>/train/<class>``. An ``{index: class_name}`` mapping is
    written to ``json_path`` as JSON.

    Args:
        root: directory containing one sub-directory per class.
        val_rate: fraction of each class's images put in the validation split.
        output_root: destination directory that receives ``train/`` and
            ``val/``; defaults to the path the script originally hard-coded.
        json_path: where to write the index-to-class JSON file (a relative
            path resolves against the current working directory).

    Raises:
        AssertionError: if ``root`` does not exist.
    """
    random.seed(0)  # fixed seed so the split is reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One folder per class, sorted so the class -> index mapping is stable,
    # e.g. ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'].
    flower_class = sorted(cla for cla in os.listdir(root)
                          if os.path.isdir(os.path.join(root, cla)))
    # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    class_indices = {cla: idx for idx, cla in enumerate(flower_class)}

    # Write the inverted mapping {index: class_name}, pretty-printed.
    json_str = json.dumps({idx: cla for cla, idx in class_indices.items()}, indent=4)
    with open(json_path, 'w') as json_file:
        json_file.write(json_str)

    supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # os.listdir order is arbitrary and platform-dependent; sorting it is
        # what makes random.sample reproducible under the fixed seed.
        images = sorted(os.path.join(cla_path, name) for name in os.listdir(cla_path)
                        if os.path.splitext(name)[-1] in supported)
        # Sampled validation images; a set keeps the membership test O(1).
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))

        new_val_path = os.path.join(output_root, 'val', cla)
        new_train_path = os.path.join(output_root, 'train', cla)
        os.makedirs(new_val_path, exist_ok=True)
        os.makedirs(new_train_path, exist_ok=True)

        for img_path in images:
            shutil.copy(img_path, new_val_path if img_path in val_path else new_train_path)


if __name__ == "__main__":
    read_split_data(root="D:/DeepLearning/classification/mobilenetv2/data/flower_photos", val_rate=0.2)
2 获取训练与验证图片路径及标签
有了图片,我们得会获取图片的路径和标签,这样才能读取图片,训练模型。
在utils.py
中写了一个read_data(root: str)
函数,可以获取训练与验证图片路径及类别标签,内容及解析如下:
import os
import json


def read_data(root: str, json_path: str = 'class_indices.json'):
    """Collect image paths and integer labels for an existing train/val split.

    Expects the layout ``root/train/<class>/*`` and ``root/val/<class>/*``.
    Class names are the sub-directories of ``root/train``, sorted, so class
    index ``i`` is the i-th name in sorted order. The ``{index: class_name}``
    mapping is written to ``json_path`` as JSON.

    Args:
        root: dataset directory containing ``train`` and ``val``.
        json_path: where to write the index-to-class JSON file (a relative
            path resolves against the current working directory).

    Returns:
        ``(train_images_path, train_images_label, val_images_path,
        val_images_label)`` -- four parallel lists of path strings and int labels.

    Raises:
        AssertionError: if ``root/train`` or ``root/val`` does not exist.
        FileNotFoundError: if a class folder under ``train`` has no
            counterpart under ``val``.
    """
    root_train = os.path.join(root, 'train')
    root_val = os.path.join(root, 'val')
    assert os.path.exists(root_train), "dataset root_train: {} does not exist.".format(root_train)
    assert os.path.exists(root_val), "dataset root_val: {} does not exist.".format(root_val)

    # One folder per class under train/, sorted so the class -> index mapping is
    # stable, e.g. ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips'].
    flower_class = sorted(cla for cla in os.listdir(root_train)
                          if os.path.isdir(os.path.join(root_train, cla)))
    # e.g. {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    class_indices = {cla: idx for idx, cla in enumerate(flower_class)}

    # Write the inverted mapping {index: class_name}, pretty-printed.
    json_str = json.dumps({idx: cla for cla, idx in class_indices.items()}, indent=4)
    with open(json_path, 'w') as json_file:
        json_file.write(json_str)
    print("class_indices.json has been written!")

    supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions

    def list_images(folder):
        # Supported image files in `folder`, sorted so the order does not
        # depend on the platform's os.listdir ordering.
        return sorted(os.path.join(folder, name) for name in os.listdir(folder)
                      if os.path.splitext(name)[-1] in supported)

    train_images_path, train_images_label = [], []
    val_images_path, val_images_label = [], []
    for cla in flower_class:
        image_class = class_indices[cla]  # integer label of this class
        images_train = list_images(os.path.join(root_train, cla))
        images_val = list_images(os.path.join(root_val, cla))
        train_images_path += images_train
        train_images_label += [image_class] * len(images_train)
        val_images_path += images_val
        val_images_label += [image_class] * len(images_val)

    print("{} images were found in the train dataset.".format(len(train_images_path)))
    print("{} images were found in the val dataset.".format(len(val_images_path)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    return train_images_path, train_images_label, val_images_path, val_images_label
3 Dataset类与DataLoader类的理解
神经网络需要数据传入才能进行训练等操作,那怎样才能把图片以及标签信息整合成神经网络正规输入的格式呢?
回答: pytorch 数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练
因此,我们需要先了解一些Dataset 和 DataLoader 的基础知识。
代码中经常看到这两行,那Dataset和DataLoader到底是什么玩意?
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
3.1 Dataset类
Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。
当用户想要加载自定义的数据时,只需要继承这个类,并且覆写其中的两个方法即可:
- __len__:实现 len(dataset),返回整个数据集的大小;
- __getitem__:根据索引获取数据,使 dataset[i] 返回数据集中第 i 个样本。
不覆写这两个方法,调用时会直接抛出错误。
简单示例:
class MyDataset(Dataset):
    """Skeleton of a custom map-style dataset (illustration only).

    A real implementation stores the data in ``__init__``, reports the sample
    count from ``__len__`` and returns one sample from ``__getitem__``.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        """Receive image paths, their labels and an optional transform (not stored yet)."""

    def __len__(self):
        """Return the number of samples (left unimplemented in this skeleton)."""

    def __getitem__(self, index):
        """Return the sample at position ``index`` (left unimplemented in this skeleton)."""
本项目中,新建一个 my_dataset.py,里面存放自己写的 MyDataSet 类,如下:
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Map-style dataset over image files paired with integer class labels.

    ``images_path`` and ``images_class`` are parallel lists: the i-th path is
    labelled with the i-th class index. Images are opened lazily in
    ``__getitem__`` and passed through ``transform`` when one is given.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        # Optional callable applied to each opened image before it is returned.
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        path = self.images_path[item]
        image = Image.open(path)
        # Only 'RGB' images are accepted ('L' = grayscale, 'RGBA', ... are rejected).
        if image.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(path))
        target = self.images_class[item]
        if self.transform is not None:
            image = self.transform(image)
        return image, target

    @staticmethod
    def collate_fn(batch):
        """Merge a list of ``(image_tensor, label)`` pairs into one batch.

        Intended as ``DataLoader(collate_fn=...)``: images are stacked along a
        new leading dimension and labels become a single tensor.
        """
        imgs, targets = zip(*batch)
        return torch.stack(imgs, dim=0), torch.as_tensor(targets)
3.2 DataLoader类
DataLoader将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。
- dataloader 本质上是一个可迭代对象,可以用 iter() 访问,但不能直接对它使用 next();
- 使用 iter(dataloader) 返回的是一个迭代器,然后可以用 next() 访问;
- 一般使用 for inputs, labels in dataloader 的方式遍历访问;
DataLoader参数介绍:
class torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None, # <function default_collate>
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None)
部分关键参数含义:
- batch_size:每个batch的大小
- shuffle:在每个epoch开始的时候,是否对数据进行重新排序
- num_workers:加载数据的时候使用几个子进程,0意味着所有的数据都会被load进主进程。(默认为0)
- collate_fn:如何把多个样本合并(拼接)成一个batch,可以自己定义函数来准确地实现想要的功能
- drop_last:告诉如何处理数据集长度除以batch_size 余下的数据。True就抛弃,否则保留
3.3 Dataset与DataLoader综合使用简单示例
最朴实的情况:
# Minimal usage pattern. images_path / images_class are the parallel lists of
# image paths and integer labels (e.g. as returned by read_data); transform is optional.
dataset = MyDataset(images_path, images_class, transform=None)
dataloader = DataLoader(dataset)
num_epochs = 100
for epoch in range(num_epochs):
    for img, label in dataloader:
        ...  # forward pass, loss computation, backward pass, optimizer step
4 MobileNetV2介绍
MobileNetV2网络结构及代码介绍,欢迎查看我的另一篇文章:MobileNetV2网络结构详解并获取网络计算量与参数量
5 训练总体流程
使用在Imagenet上的预训练权重进行迁移训练,总体流程可参考之前写的文章:Imagenet上的模型预训练权重用到CIFAR10上。这里也给出 train.py 的代码:
import os
import math
import argparse
import torch
import torch.optim as optim
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import mobilenet_v2 as create_model
from my_dataset import MyDataSet
from utils import read_data, train_one_epoch, evaluate
def _load_pretrained_weights(model, weights_path, device):
    """Load the tensors of ``weights_path`` whose name and shape match ``model``.

    Checkpoint entries the model does not have, or whose shape differs (e.g.
    an ImageNet 1000-class classifier when fine-tuning on 5 classes), are
    skipped and keep their fresh initialization. Because ``strict=False`` is
    used, the skipped model keys are reported (and printed) instead of
    raising, e.g.
    _IncompatibleKeys(missing_keys=['classifier.1.weight', 'classifier.1.bias'], unexpected_keys=[])
    """
    weights_dict = torch.load(weights_path, map_location=device)
    model_state = model.state_dict()  # build once, not once per checkpoint key
    load_weights_dict = {k: v for k, v in weights_dict.items()
                         if k in model_state and model_state[k].shape == v.shape}
    print(model.load_state_dict(load_weights_dict, strict=False))


def main(args):
    """Fine-tune the model on the train/val split found under ``args.data_path``.

    Builds train/val DataLoaders from read_data, optionally loads pretrained
    weights and freezes layers. It then trains with SGD and a cosine
    learning-rate schedule for ``args.epochs`` epochs. After each epoch it
    evaluates on the validation set and saves ./output/model-<epoch>.pth.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    print(args)

    os.makedirs("./output", exist_ok=True)

    # Training/validation image paths and integer labels, all as lists.
    train_images_path, train_images_label, val_images_path, val_images_label = read_data(args.data_path)

    img_size = {"v2": 224}
    num_model = "v2"

    # Preprocessing for train (with augmentation) and val (deterministic).
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),
                                     transforms.RandomHorizontalFlip(),  # random horizontal flip
                                     transforms.ToTensor(),  # to tensor, values 0~255 -> 0~1
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),  # mean, std
        "val": transforms.Compose([transforms.Resize(img_size[num_model]),
                                   transforms.CenterCrop(img_size[num_model]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    # os.cpu_count() returns None when the CPU count cannot be determined.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    # Pinned host memory only helps host-to-GPU copies.
    pin_memory = device.type == 'cuda'
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=pin_memory,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=pin_memory,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    # Load pretrained weights if a path is given; the ImageNet classifier
    # (1000 classes) does not match and is skipped by shape.
    if args.weights != "":
        if os.path.exists(args.weights):
            _load_pretrained_weights(model, args.weights, device)
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # Optionally freeze everything except the parameters matched below.
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # NOTE(review): "features.top" looks carried over from an EfficientNet
            # script; whether any MobileNetV2 parameter uses that name depends on
            # model.py (not shown). If none does, only the classifier is trained.
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)

    # Cosine schedule from lr down to lr * lrf, https://arxiv.org/pdf/1812.01187.pdf
    def lf(x):
        return ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_one_epoch(model=model,
                        optimizer=optimizer,
                        data_loader=train_loader,
                        device=device,
                        epoch=epoch)
        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)
        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        torch.save(model.state_dict(), "./output/model-{}.pth".format(epoch))
def str2bool(value):
    """Parse a command-line boolean such as "True"/"false"/"1"/"0".

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--freeze-layers False`` meant True). This parser
    accepts the usual spellings and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "1", "yes", "y", "on"):
        return True
    if text in ("false", "0", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # Dataset root directory
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="./data")

    # download model weights
    # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    parser.add_argument('--weights', type=str, default='./pretrained/mobilenet_v2-b0353104.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=str2bool, default=False)

    opt = parser.parse_args()

    main(opt)
输出:
(python36) cylab@amax:/data/wyx/mobilenetv2_train_onnx$ CUDA_VISIBLE_DEVICES=1 python3 train.py
Namespace(batch_size=16, data_path='/data/wyx/datasets/flower_data/', epochs=30, freeze_layers=False, lr=0.01,
lrf=0.01, num_classes=5, weights='./pretrained/mobilenet_v2-b0353104.pth')
class_indices.json has been written!
2939 images were found in the train dataset.
731 images were found in the val dataset.
2939 images for training.
731 images for validation.
Using 8 dataloader workers every process
_IncompatibleKeys(missing_keys=['classifier.1.weight', 'classifier.1.bias'], unexpected_keys=[])
[epoch 0] mean loss 1.565: 100%|████████████████████████████████████████████████| 184/184 [00:25<00:00, 7.35it/s]
100%|█████████████████████████████████████████████████████████████████| 46/46 [00:03<00:00, 15.04it/s]
[epoch 0] accuracy: 0.607
[epoch 1] mean loss 0.355: 100%|████████████████████████████████████████████████| 184/184 [00:35<00:00, 5.21it/s]
...
6 推理一张图片
训练完成后,从网上找了一张郁金香(tulip)图片进行测试,代码在 predict.py 中,其内容如下:
注意: 读取class_indices.json文件时,一定要使用新生成的这个class_indices.json文件,不同系统上的排序可能不同。
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

# Must match the architecture train.py builds (mobilenet_v2); the checkpoint
# cannot be loaded into a different network such as efficientnet_b0.
from model import mobilenet_v2 as create_model


def main():
    """Classify ./data/tulip.jpg with a trained checkpoint and plot the prediction.

    Reads the index -> class-name mapping from ./class_indices.json, loads
    ./output/model-29.pth and prints the probability of every class.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Same preprocessing as the validation transform used in training.
    data_transform = transforms.Compose(
        [transforms.Resize(224),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "./data/tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    # Images from the web can be RGBA, palette or grayscale; the 3-channel
    # Normalize above needs RGB input.
    img = Image.open(img_path).convert('RGB')
    plt.imshow(img)
    img = data_transform(img)  # [C, H, W]
    img = torch.unsqueeze(img, dim=0)  # add batch dimension -> [N, C, H, W]

    # read class_indict: {"0": class_name, ...}
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model with one output per class listed in class_indices.json
    model = create_model(num_classes=len(class_indict)).to(device)
    # load model weights
    model_weight_path = "./output/model-29.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()  # raw scores (logits)
        predict = torch.softmax(output, dim=0)  # convert scores to probabilities
        predict_cla = torch.argmax(predict).numpy()  # index of the most probable class

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
结果展示如下:
7 感谢链接
https://www.bilibili.com/video/BV1W7411T7qc/?spm_id_from=333.788
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)