前言:系列文章的前面两篇文章已经很明确的说明了如何使用DataSet类和DataLoader类,而且第二篇文章中详细介绍了DataLoader类中的几个重要的常用的参数,如sampler参数、collate_fn参数,但是在数据与处理的过程中,还会遇到数据增强、数据裁剪等各种操作,当然这些操作我们可以预先自己来实现,但是pytorch提供了强大的处理工具来对图像进行预处理,这也是本文的重点,详细介绍 torchvision中的transform操作。系列文章前面两篇为:

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

(第二篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

一、transform模块简介

专门负责图像预处理、实现图像增强的模块,在使用的过程中是与DataSet结合起来使用的,第一篇文章中已经说明了,我们在重写__getitem__的时候,会将数据增强的的代码也放在里面。

看一个简单的例子,本次的例子依然是以第一篇文章中的例子而言的:

class LaneDataSet(Dataset):
    def __init__(self, dataset, transform):
        pass # 省略了
    def __len__(self):
        return len(self._gt_img_list)

    def __getitem__(self, idx):       
        # 前面的部分省略了

        # 可选参数 transformations,裁剪成(256,512)
        if self.transform:
            img = self.transform(img)
            label_binary_img = self.transform(label_binary_img)
            label_instance_img = self.transform(label_instance_img)

        img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) #(3,720,1280) 这里都没有问题
        return (img, label_binary_img, label_instance_img)

然后再定义DataSet的时候会传递进去transform参数,如下:

from torchvision import transforms
from torchvision.transforms import Resize

dataset = LaneDataSet(train_dataset_file,transform=transforms.Compose([Resize((512, 256))]))
   
dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

上面的预处理只是实现了一种变换Resize,我也可以同时实现多种变换:

# 将多个transform组织在一起
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
    transforms.RandomSizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
])

dataset=MyDataSet('一系列参数',transform=transform)

于是我们总结出“transform+DataSet”的一般使用步骤:

(1)第一步:定义自己的DataSet,并重写__getitem__,在里面实现关键的transform操作

import torch
from torch.utils.data import Dataset, DataLoader

class LaneDataSet(Dataset):
    def __init__(self, dataset, transform):
        pass
    def __getitem__(self, idx):
        # 这里是作为选择性参数使用,如果有这个参数,则进行数据增强,当然我也可以不添加这个参数,所有的都默认使用transform操作
        if self.transform:
            img = self.transform(img)  # 这其实是一个函数调用的形式
          
        return img,label

(2)第二步:将多个数据增强方式组合起来合成一个transforms,通过Compose类来实现,注意这个类的返回值哦!

从上面的第一部中可以看出,self.transform(img)实际上是一个函数调用的形式,但是这个transform实际上是Compose类的一个实例,所以在Compose中应该实现了__call__方法,我们查看一下定义:

class Compose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

    def __repr__(self):
      

从这里可知道,Compose类的确实现了__call__方法,而且可以看出,传递给构造函数的transforms参数应该是一个集合(一般就用列表就可以了),因为后面在__call__里面会遍历这个集合中的每一个元素,这个集合中的每一个元素实际上就是数据增强的类,这一切就都合理了。所以一般Compose类的使用方法为:

transform=transforms.Compose([   # 将需要的操作全部放在一个列表里面
    transforms.RandomSizedCrop(224),  
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
])

(3)构造DataSet的对象,将组合起来的transform传递进去。这样就会对每一个batch_size的图像都进行相关的数据增强操作了。

到这一步其实就简单了,只需要在构造DataSet对象的时候传递进去一个参数即可,一般格式如下所示:

dataset = LaneDataSet("参数列表", transform=transform)

二、transform类的简单实用

为了演示整个过程的使用流程,本文依然以tusimple数据集作为例子,也就是前面第一篇中的例子,此处添加几个数据增强的操作,代码如下:

第一步:构建DataSet类,并实现__len__和__getitem__

class LaneDataSet(Dataset):
    def __init__(self, dataset, transform):
        '''
        param:
            detaset: 实际上就是tusimple数据集的三个文本文件train.txt、val.txt、test.txt三者的文件路径
            transform: 决定是否进行变换,它其实是一个函数或者是几个函数的组合
        构造三个列表,存储每一张图片的文件路径          
        '''
        self._gt_img_list = []
        self._gt_label_binary_list = []
        self._gt_label_instance_list = []
        self.transform = transform

        with open(dataset, 'r') as file:  # 打开其实是那个 training下面的那个train.txt 文件
            for _info in file:
                info_tmp = _info.strip(' ').split()

                self._gt_img_list.append(info_tmp[0])
                self._gt_label_binary_list.append(info_tmp[1])
                self._gt_label_instance_list.append(info_tmp[2])

        assert len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list)
    
    def __len__(self):
        return len(self._gt_img_list)

    def __getitem__(self, idx):
        assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \
               == len(self._gt_img_list)

        # 读取所有图片
        img = cv2.imread(self._gt_img_list[idx], cv2.IMREAD_COLOR) #真实图片 (720,1280,3)
        label_instance_img = cv2.imread(self._gt_label_instance_list[idx], cv2.IMREAD_UNCHANGED) # instance图片 (720,1280)
        label_binary_img = cv2.imread(self._gt_label_binary_list[idx], cv2.IMREAD_GRAYSCALE) #binary图片 (720,1280)

        # 需要注意的是,这里经过变换之后,将numpy数组转化成了 pillow的Image对象
        if self.transform:
            img = self.transform(img)
            label_binary_img = self.transform(label_binary_img)
            label_instance_img = self.transform(label_instance_img)

        return (img, label_binary_img, label_instance_img)  # 这三个都是 Image 对象

第二步:实现“三步走”:

train_dataset_file = 'H:/tusimple_dataset/training/train.txt'

transform=transforms.Compose([transforms.ToPILImage(),             # 将ndarray转化成 pillow的Image格式
                              transforms.Resize((256,512)),        # 裁减至(256,512)
                              transforms.RandomRotation((30,150)), # 随机旋转30至150度
                              transforms.RandomHorizontalFlip(0.6),# 水平翻转
                              transforms.RandomVerticalFlip(0.4),  # 垂直翻转
                              transforms.ToTensor()])              #将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],而且会将[w,h,c]转化成pytorch需要的[c,w,h]格式

# 第一步:构造dataset对象
dataset = LaneDataSet(train_dataset_file, transform=transform)

# 第二步:构造dataloader对象
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# 第三步:迭代 dataloader
t=enumerate(iter(dataloader))
for batch_idx, batch in t:
        
    # 注意 ,这三个数据都是 FloatTensor
    image_data = batch[0]     # (8,3,256,512) ,之所以通道在前,是因为应用了transforms.ToTensor()
    binary_label = batch[1]   # [8,1,256,512]  
    instance_label = batch[2] # (8,1,256,512)  
    
    print(np.shape(image_data),np.shape(binary_label),np.shape(instance_label),sep="    ")

第三步:查看数据增强之后的照片

image_data=image_data.reshape((8,256,512,3)) # (8,256,512,3) 
binary_label=binary_label.squeeze(1)         # (8,256,512) 
instance_label=instance_label.squeeze(1)     # (8,256,512) 

for i in range(8):
    plt.figure()
    plt.subplot(1,3,1)
    plt.imshow(image_data[i])

    plt.subplot(1,3,2)
    plt.imshow(binary_label[i],cmap="gray")

    plt.subplot(1,3,3)
    plt.imshow(instance_label[i],cmap="gray")
    
    plt.show()

运行效果如下:

上面的第一幅图,也就是原始图像好像有点问题,具体原因还不太清楚,后面的binary_label和instance_label是没有问题的。

 注意事项:

(1)transfroms中的数据增强操作针对的是pillow的Image图像格式,而我们很多时候在使用opencv读进去的又是ndarray格式,所以需要第一步先将ndarray转化成Image格式,即:transforms.ToPILImage().

(2)但是我们后需要的数据又是需要ndarray格式或者是tensor格式,故而有需要将Image转换回来,即:

transforms.ToTensor()。

三、transforms中的图像变换操作大全

__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
           "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
           "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
           "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]

关于每一种变换方法的具体实用技巧,可以查阅torch的官方文档,也可以参考下面几篇博文。

PyTorch 学习笔记(三):transforms的二十二个方法

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐