1 分类数据集准备

期待的分类数据集样式如下,注意,验证集需要知道图片类别。

data
├── train
│   ├── class_name_1
│   │   ├── 1.jpg
│   │   └── 2.jpg
│   ├── class_name_2
│   │   ├── 1.jpg
│   │   └── 2.jpg
│   └── ...
└── val
    ├── class_name_1
    │   └── 1.jpg
    ├── class_name_2
    │   └── 1.jpg
    └── ...

以花分类数据集(下方给出数据集下载链接)为例,在拿到所有图片后,需要进行数据集格式的调整。

链接:https://pan.baidu.com/s/16T3ycHeID0Y06JTLcBAXRQ 
提取码:ynrj

数据集解压后,里面每个文件夹名就是其内部图片的类别。
在文件夹里面新建split_train_val_data.py
在这里插入图片描述
运行split_train_val_data.py后,可以得到期待的数据集存放格式,同时会生成一个类别索引文件class_indices.json,便于查看数字索引与类别的对应关系。记得将class_indices.json文件复制到mobilenetv2文件夹下,便于后续使用。

在这里插入图片描述

split_train_val_data.py内容如下:

import os
import random
import json
import shutil       # 用于复制图片

def read_split_data(root: str, val_rate: float = 0.2,
                    output_root: str = 'D:/DeepLearning/classification/mobilenetv2/data'):
    """Copy a folder-per-class image dataset into train/ and val/ splits.

    Args:
        root: Directory that contains one sub-directory per class.
        val_rate: Fraction of each class's images copied to the validation
            split; the count is ``int(len(images) * val_rate)``, i.e. rounded
            down.
        output_root: Directory under which ``train/<class>`` and
            ``val/<class>`` are created. Defaults to the location this script
            originally hard-coded, so existing calls behave as before.

    Side effects:
        Writes ``class_indices.json`` (index -> class name) into the current
        working directory and copies every supported image into the split
        directories. The source images are left in place.

    Raises:
        AssertionError: If ``root`` does not exist.
    """
    random.seed(0)  # fixed seed so the split can be reproduced
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One sub-directory per class, e.g.
    #   ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    # Sorted so the class -> index mapping is stable.
    flower_class = sorted(cla for cla in os.listdir(root)
                          if os.path.isdir(os.path.join(root, cla)))

    # Class name -> numeric index, e.g.
    #   {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    class_indices = {name: idx for idx, name in enumerate(flower_class)}
    # The JSON file stores the inverse mapping (index -> class name).
    json_str = json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # os.listdir returns entries in arbitrary order; sort them so that the
        # seeded sampling below picks the same files on every run.
        images = sorted(os.path.join(root, cla, i) for i in os.listdir(cla_path)
                        if os.path.splitext(i)[-1] in supported)

        # Randomly pick the validation samples for this class. A set keeps
        # the membership test in the copy loop O(1) per image.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))

        new_val_path = '{0}/val/{1}'.format(output_root, cla)
        os.makedirs(new_val_path, exist_ok=True)
        new_train_path = '{0}/train/{1}'.format(output_root, cla)
        os.makedirs(new_train_path, exist_ok=True)

        for img_path in images:
            # Sampled images go to the validation split, the rest to training.
            if img_path in val_path:
                shutil.copy(img_path, new_val_path)
            else:
                shutil.copy(img_path, new_train_path)

# Script entry point: split the dataset found at the hard-coded root,
# passing val_rate=0.2 as the validation fraction.
if __name__ =="__main__":
    read_split_data(root="D:/DeepLearning/classification/mobilenetv2/data/flower_photos", val_rate=0.2)

2 获取训练与验证图片路径及标签

有了图片,我们得会获取图片的路径和标签,这样才能读取图片,训练模型。
utils.py中写了一个read_data(root: str)函数,可以获取训练与验证图片路径及类别标签,内容及解析如下:

import os

def _list_images(folder: str, supported):
    """Return the paths of the files directly inside ``folder`` whose extension is in ``supported``."""
    # os.path.splitext splits "name.ext" into ("name", ".ext").
    return [os.path.join(folder, i) for i in os.listdir(folder)
            if os.path.splitext(i)[-1] in supported]


def read_data(root: str):
    """Collect image paths and class labels for the train and val splits.

    Expects ``root + '/train'`` and ``root + '/val'`` to exist, each holding
    one sub-directory per class. The class list is taken from the train
    split, and the same sub-directory is then read from the val split.

    Args:
        root: Dataset directory containing ``train`` and ``val``.

    Returns:
        A tuple ``(train_images_path, train_images_label, val_images_path,
        val_images_label)`` of four lists; each label is the index of the
        image's class in the sorted class list.

    Side effects:
        Writes ``class_indices.json`` (index -> class name) into the current
        working directory and prints the number of images found.

    Raises:
        AssertionError: If the train or val directory does not exist.
    """
    # Imported here because the module header shown for this file only
    # imports ``os``; without it json.dumps below raises NameError.
    import json

    root_train = root + '/train'
    root_val = root + '/val'
    assert os.path.exists(root_train), "dataset root_train: {} does not exist.".format(root_train)
    assert os.path.exists(root_val), "dataset root_val: {} does not exist.".format(root_val)

    # One sub-directory per class under train/, e.g.
    #   ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
    # Sorted so the class -> index mapping is stable.
    flower_class = sorted(cla for cla in os.listdir(root_train)
                          if os.path.isdir(os.path.join(root_train, cla)))

    # Class name -> numeric index, e.g.
    #   {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflowers': 3, 'tulips': 4}
    class_indices = {name: idx for idx, name in enumerate(flower_class)}

    # The JSON file stores the inverse mapping (index -> class name).
    json_str = json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    print("class_indices.json has been written!")

    train_images_path = []      # every training image path
    train_images_label = []     # class index of each training image
    val_images_path = []        # every validation image path
    val_images_label = []       # class index of each validation image
    every_class_train_num = []  # number of training images per class
    every_class_val_num = []    # number of validation images per class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # accepted file extensions

    for cla in flower_class:
        image_class = class_indices[cla]

        images_train = _list_images(os.path.join(root_train, cla), supported)
        every_class_train_num.append(len(images_train))
        train_images_path += images_train
        train_images_label += [image_class] * len(images_train)

        images_val = _list_images(os.path.join(root_val, cla), supported)
        every_class_val_num.append(len(images_val))
        val_images_path += images_val
        val_images_label += [image_class] * len(images_val)

    print("{} images were found in the train dataset.".format(sum(every_class_train_num)))
    print("{} images were found in the val dataset.".format(sum(every_class_val_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    return train_images_path, train_images_label, val_images_path, val_images_label

3 Dataset类与DataLoader类的理解

神经网络需要数据传入才能进行训练等操作,那怎样才能把图片以及标签信息整合成神经网络正规输入的格式呢?

回答: pytorch 数据加载到模型的操作顺序是这样的:
① 创建一个 Dataset 对象
② 创建一个 DataLoader 对象
③ 循环这个 DataLoader 对象,将img, label加载到模型中进行训练

因此,我们需要先了解一些Dataset 和 DataLoader 的基础知识。

代码中经常看到这两行,那Dataset和DataLoader到底是什么玩意?

from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

3.1 Dataset类

Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。

当用户想要加载自定义的数据时,只需要继承这个类,并且覆写其中的两个方法即可:

  1. __len__:实现len(dataset),返回整个数据集的大小。
  2. __getitem__:用来获取一些索引的数据,使dataset[i]返回数据集中第i个样本。

如果不覆写这两个方法,使用时会直接报错(例如基类的 __getitem__ 会抛出 NotImplementedError)。

简单示例:

class MyDataset(Dataset):
    """Skeleton of a custom dataset; the method bodies are placeholders to be filled in."""

    def __init__(self, images_path: list, images_class: list, transform=None):
        # Receives the image paths, their class labels and an optional transform.
        ...

    def __len__(self):
        # Should return the dataset size so that len(dataset) works.
        ...

    def __getitem__(self, index):
        # Should return the sample at `index` so that dataset[index] works.
        ...

本项目中,新建一个my_dataset.py,里面存放自己写的 MyDataSet 类(注意类名大小写,train.py 中按此名称导入),如下:

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Dataset that pairs image file paths with their class labels."""

    def __init__(self, images_path: list, images_class: list, transform=None):
        # Optional callable applied to every opened image in __getitem__.
        self.transform = transform
        self.images_class = images_class
        self.images_path = images_path

    def __len__(self):
        # One sample per image path.
        return len(self.images_path)

    def __getitem__(self, item):
        """Open image number ``item`` and return ``(image, label)``.

        Raises:
            ValueError: If the opened image is not in 'RGB' mode.
        """
        img = Image.open(self.images_path[item])
        # Only colour ('RGB') images are accepted; 'L' would be greyscale.
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is None:
            return img, label
        return self.transform(img), label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of ``(image, label)`` samples into batched tensors for a DataLoader."""
        images, labels = zip(*batch)
        # Stack the per-sample image tensors along a new leading batch axis.
        stacked = torch.stack(images, dim=0)
        return stacked, torch.as_tensor(labels)

3.2 DataLoader类

DataLoader将自定义的Dataset根据batch size大小、是否shuffle等封装成一个Batch Size大小的Tensor,用于后面的训练。

  • dataloader本质上是一个可迭代对象(iterable),本身不是迭代器,因此不能直接对它调用next();
  • 使用 iter(dataloader) 返回的是一个迭代器,然后可以使用next访问;
  • 一般使用for inputs, labels in dataloaders进行可迭代对象的访问;

DataLoader参数介绍:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=None,    # <function default_collate>
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

部分关键参数含义:

  • batch_size:每个batch的大小
  • shuffle:在每个epoch开始的时候,是否对数据进行重新排序
  • num_workers:加载数据的时候使用几个子进程,0意味着所有的数据都会被load进主进程。(默认为0)
  • collate_fn:如何将取出的多个样本合并成一个batch,可以自己定义函数来准确地实现想要的功能
  • drop_last:告诉如何处理数据集长度除以batch_size 余下的数据。True就抛弃,否则保留

3.3 Dataset与DataLoader综合使用简单示例

最朴实的情况:

dataset = MyDataset()
dataloader = DataLoader(dataset)
num_epoches = 100
for epoch in range(num_epoches):
    for img, label in dataloader:
        ....

4 MobileNetV2介绍

MobileNetV2网络结构及代码介绍,欢迎查看我的另一篇文章:MobileNetV2网络结构详解并获取网络计算量与参数量

5 训练总体流程

使用在Imagenet上的预训练权重进行迁移训练,总体流程可参考之前写的文章:Imagenet上的模型预训练权重用到CIFAR10上,这里也给出train.py的代码:

import os
import math
import argparse

import torch
import torch.optim as optim
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import mobilenet_v2 as create_model
from my_dataset import MyDataSet
from utils import read_data, train_one_epoch, evaluate


def main(args):
    """Train the model on the dataset under ``args.data_path``.

    Reads these attributes from ``args``: ``data_path``, ``batch_size``,
    ``num_classes``, ``weights``, ``freeze_layers``, ``lr``, ``lrf`` and
    ``epochs``.

    Side effects:
        Creates ``./output`` and saves the model's state_dict there after
        every epoch as ``model-<epoch>.pth``; prints progress to stdout.

    Raises:
        FileNotFoundError: If ``args.weights`` is non-empty but the file does
            not exist.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    print(args)

    # Checkpoint directory; exist_ok avoids a separate existence check.
    os.makedirs("./output", exist_ok=True)

    # Image paths and labels for both splits, all returned as lists.
    train_images_path, train_images_label, val_images_path, val_images_label = read_data(args.data_path)

    img_size = {"v2": 224}
    num_model = "v2"

    # Pre-processing pipelines for the train and val splits.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model]),
                                     transforms.RandomHorizontalFlip(),     # random horizontal flip
                                     transforms.ToTensor(),     # to tensor, values 0~255 become 0~1
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),  # mean std
        "val": transforms.Compose([transforms.Resize(img_size[num_model]),
                                   transforms.CenterCrop(img_size[num_model]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to 1 so that min() never has to compare None with an int.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    # Load the pre-trained weights when a path is given.
    if args.weights != "":
        if not os.path.exists(args.weights):
            raise FileNotFoundError("not found weights file: {}".format(args.weights))
        # -----------------------------------------------------------------------------------------#
        # The pre-trained weights come from ImageNet (1000 classes) while the flower dataset has
        # only 5, so the classifier weights are not loaded. The following message is therefore
        # expected during start-up and can be ignored:
        #   _IncompatibleKeys(missing_keys=['classifier.1.weight', 'classifier.1.bias'], unexpected_keys=[])
        # -----------------------------------------------------------------------------------------#
        weights_dict = torch.load(args.weights, map_location=device)
        # Build the model's state_dict once instead of once per checkpoint key,
        # and skip keys the model does not have rather than raising KeyError.
        model_state = model.state_dict()
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if k in model_state and model_state[k].numel() == v.numel()}

        print(model.load_state_dict(load_weights_dict, strict=False))

    # Optionally freeze part of the network.
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # Intended to keep only the last conv layer and the classifier trainable.
            # NOTE(review): confirm the model really names that conv block
            # "features.top"; if not, only the classifier stays trainable.
            if ("features.top" not in name) and ("classifier" not in name):
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        torch.save(model.state_dict(), "./output/model-{}.pth".format(epoch))


def str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``0``.

    ``type=bool`` in argparse calls ``bool()`` on the raw string, so every
    non-empty value, including ``"False"``, would be parsed as True.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)
    # dataset root directory
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="./data")
    # download model weights
    # https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
    parser.add_argument('--weights', type=str, default='./pretrained/mobilenet_v2-b0353104.pth',
                        help='initial weights path')
    # Parsed explicitly so that "--freeze-layers False" really disables freezing.
    parser.add_argument('--freeze-layers', type=str2bool, default=False)

    opt = parser.parse_args()

    main(opt)

输出:

(python36) cylab@amax:/data/wyx/mobilenetv2_train_onnx$ CUDA_VISIBLE_DEVICES=1 python3 train.py
Namespace(batch_size=16, data_path='/data/wyx/datasets/flower_data/', epochs=30, freeze_layers=False, lr=0.01, 
lrf=0.01, num_classes=5, weights='./pretrained/mobilenet_v2-b0353104.pth')
class_indices.json has been written!
2939 images were found in the train dataset.
731 images were found in the val dataset.
2939 images for training.
731 images for validation.
Using 8 dataloader workers every process
_IncompatibleKeys(missing_keys=['classifier.1.weight', 'classifier.1.bias'], unexpected_keys=[])
[epoch 0] mean loss 1.565: 100%|████████████████████████████████████████████████| 184/184 [00:25<00:00,  7.35it/s]
100%|█████████████████████████████████████████████████████████████████| 46/46 [00:03<00:00, 15.04it/s]
[epoch 0] accuracy: 0.607

[epoch 1] mean loss 0.355: 100%|████████████████████████████████████████████████| 184/184 [00:35<00:00,  5.21it/s]
...

6 推理一张图片

训练完成后,从网上找了一张tulip图片进行测试,代码在predict.py中,其内容如下:
注意: 读取class_indices.json文件时,一定要使用新生成的这个class_indices.json文件,不同系统上的排序可能不同。

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnet_b0 as create_model


def main(img_path="./data/tulip.jpg",
         json_path='./class_indices.json',
         model_weight_path="./output/model-29.pth"):
    """Classify one image with the trained model and show the result.

    Args:
        img_path: Image to classify.
        json_path: ``class_indices.json`` mapping index -> class name.
        model_weight_path: state_dict file saved by train.py.

    The defaults are the values this script originally hard-coded, so
    calling ``main()`` with no arguments behaves as before.

    Side effects:
        Prints the probability of every class and opens a matplotlib window
        showing the image titled with the top prediction.

    Raises:
        AssertionError: If ``img_path`` or ``json_path`` does not exist.
    """
    # train.py builds and saves a mobilenet_v2 state_dict, so the same
    # architecture has to be instantiated here; the efficientnet_b0 alias
    # imported at module level would not match those weights.
    from model import mobilenet_v2 as create_model

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(224),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    # Normalize above is given three channel means, so force three channels;
    # greyscale or RGBA input would otherwise not match.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model; the class count follows the index file instead of a literal 5
    model = create_model(num_classes=len(class_indict)).to(device)
    # load model weights
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()     # raw model scores
        predict = torch.softmax(output, dim=0)                  # scores -> probabilities
        predict_cla = torch.argmax(predict).item()              # index of the highest probability

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict.tolist()):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], prob))
    plt.show()


# Run the prediction only when this file is executed as a script.
if __name__ == '__main__':
    main()

结果展示如下:
在这里插入图片描述

7 感谢链接

https://www.bilibili.com/video/BV1W7411T7qc/?spm_id_from=333.788
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing
Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐