(第三篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
前言:系列文章的前面两篇文章已经很明确的说明了如何使用DataSet类和DataLoader类,而且第二篇文章中详细介绍了DataLoader类中的几个重要的常用的参数,如sampler参数、collate_fn参数,但是在数据与处理的过程中,还会遇到数据增强、数据裁剪等各种操作,当然这些操作我们可以预先自己来实现,但是pytorch提供了强大的处理工具来对图像进行预处理,这也是本文的重点...
前言:系列文章的前面两篇文章已经很明确的说明了如何使用DataSet类和DataLoader类,而且第二篇文章中详细介绍了DataLoader类中的几个重要的常用的参数,如sampler参数、collate_fn参数,但是在数据与处理的过程中,还会遇到数据增强、数据裁剪等各种操作,当然这些操作我们可以预先自己来实现,但是pytorch提供了强大的处理工具来对图像进行预处理,这也是本文的重点,详细介绍 torchvision中的transform操作。系列文章前面两篇为:
(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
(第二篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
一、transform模块简介
专门负责图像预处理、实现图像增强的模块,在使用的过程中是与DataSet结合起来使用的,第一篇文章中已经说明了,我们在重写__getitem__的时候,会将数据增强的的代码也放在里面。
看一个简单的例子,本次的例子依然是以第一篇文章中的例子而言的:
class LaneDataSet(Dataset):
def __init__(self, dataset, transform):
pass # 省略了
def __len__(self):
return len(self._gt_img_list)
def __getitem__(self, idx):
# 前面的部分省略了
# 可选参数 transformations,裁剪成(256,512)
if self.transform:
img = self.transform(img)
label_binary_img = self.transform(label_binary_img)
label_instance_img = self.transform(label_instance_img)
img = img.reshape(img.shape[2], img.shape[0], img.shape[1]) #(3,720,1280) 这里都没有问题
return (img, label_binary_img, label_instance_img)
然后再定义DataSet的时候会传递进去transform参数,如下:
from torchvision import transforms
from torchvision.transforms import Resize
dataset = LaneDataSet(train_dataset_file,transform=transforms.Compose([Resize((512, 256))]))
dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
上面的预处理只是实现了一种变换Resize,我也可以同时实现多种变换:
# 将多个transform组织在一起
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
normalize
])
dataset=MyDataSet('一系列参数',transform=transform)
于是我们总结出“transform+DataSet”的一般使用步骤:
(1)第一步:定义自己的DataSet,并重写__getitem__,在里面实现关键的transform操作:
import torch
from torch.utils.data import Dataset, DataLoader
class LaneDataSet(Dataset):
def __init__(self, dataset, transform):
pass
def __getitem__(self, idx):
# 这里是作为选择性参数使用,如果有这个参数,则进行数据增强,当然我也可以不添加这个参数,所有的都默认使用transform操作
if self.transform:
img = self.transform(img) # 这其实是一个函数调用的形式
return img,label
(2)第二步:将多个数据增强方式组合起来合成一个transforms,通过Compose类来实现,注意这个类的返回值哦!
从上面的第一部中可以看出,self.transform(img)实际上是一个函数调用的形式,但是这个transform实际上是Compose类的一个实例,所以在Compose中应该实现了__call__方法,我们查看一下定义:
class Compose(object):
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, img):
for t in self.transforms:
img = t(img)
return img
def __repr__(self):
从这里可知道,Compose类的确实现了__call__方法,而且可以看出,传递给构造函数的transforms参数应该是一个集合(一般就用列表就可以了),因为后面在__call__里面会遍历这个集合中的每一个元素,这个集合中的每一个元素实际上就是数据增强的类,这一切就都合理了。所以一般Compose类的使用方法为:
transform=transforms.Compose([ # 将需要的操作全部放在一个列表里面
transforms.RandomSizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
])
(3)构造DataSet的对象,将组合起来的transform传递进去。这样就会对每一个batch_size的图像都进行相关的数据增强操作了。
到这一步其实就简单了,只需要在构造DataSet对象的时候传递进去一个参数即可,一般格式如下所示:
dataset = LaneDataSet("参数列表", transform=transform)
二、transform类的简单实用
为了演示整个过程的使用流程,本文依然以tusimple数据集作为例子,也就是前面第一篇中的例子,此处添加几个数据增强的操作,代码如下:
第一步:构建DataSet类,并实现__len__和__getitem__
class LaneDataSet(Dataset):
def __init__(self, dataset, transform):
'''
param:
detaset: 实际上就是tusimple数据集的三个文本文件train.txt、val.txt、test.txt三者的文件路径
transform: 决定是否进行变换,它其实是一个函数或者是几个函数的组合
构造三个列表,存储每一张图片的文件路径
'''
self._gt_img_list = []
self._gt_label_binary_list = []
self._gt_label_instance_list = []
self.transform = transform
with open(dataset, 'r') as file: # 打开其实是那个 training下面的那个train.txt 文件
for _info in file:
info_tmp = _info.strip(' ').split()
self._gt_img_list.append(info_tmp[0])
self._gt_label_binary_list.append(info_tmp[1])
self._gt_label_instance_list.append(info_tmp[2])
assert len(self._gt_img_list) == len(self._gt_label_binary_list) == len(self._gt_label_instance_list)
def __len__(self):
return len(self._gt_img_list)
def __getitem__(self, idx):
assert len(self._gt_label_binary_list) == len(self._gt_label_instance_list) \
== len(self._gt_img_list)
# 读取所有图片
img = cv2.imread(self._gt_img_list[idx], cv2.IMREAD_COLOR) #真实图片 (720,1280,3)
label_instance_img = cv2.imread(self._gt_label_instance_list[idx], cv2.IMREAD_UNCHANGED) # instance图片 (720,1280)
label_binary_img = cv2.imread(self._gt_label_binary_list[idx], cv2.IMREAD_GRAYSCALE) #binary图片 (720,1280)
# 需要注意的是,这里经过变换之后,将numpy数组转化成了 pillow的Image对象
if self.transform:
img = self.transform(img)
label_binary_img = self.transform(label_binary_img)
label_instance_img = self.transform(label_instance_img)
return (img, label_binary_img, label_instance_img) # 这三个都是 Image 对象
第二步:实现“三步走”:
train_dataset_file = 'H:/tusimple_dataset/training/train.txt'
transform=transforms.Compose([transforms.ToPILImage(), # 将ndarray转化成 pillow的Image格式
transforms.Resize((256,512)), # 裁减至(256,512)
transforms.RandomRotation((30,150)), # 随机旋转30至150度
transforms.RandomHorizontalFlip(0.6),# 水平翻转
transforms.RandomVerticalFlip(0.4), # 垂直翻转
transforms.ToTensor()]) #将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],而且会将[w,h,c]转化成pytorch需要的[c,w,h]格式
# 第一步:构造dataset对象
dataset = LaneDataSet(train_dataset_file, transform=transform)
# 第二步:构造dataloader对象
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
# 第三步:迭代 dataloader
t=enumerate(iter(dataloader))
for batch_idx, batch in t:
# 注意 ,这三个数据都是 FloatTensor
image_data = batch[0] # (8,3,256,512) ,之所以通道在前,是因为应用了transforms.ToTensor()
binary_label = batch[1] # [8,1,256,512]
instance_label = batch[2] # (8,1,256,512)
print(np.shape(image_data),np.shape(binary_label),np.shape(instance_label),sep=" ")
第三步:查看数据增强之后的照片
image_data=image_data.reshape((8,256,512,3)) # (8,256,512,3)
binary_label=binary_label.squeeze(1) # (8,256,512)
instance_label=instance_label.squeeze(1) # (8,256,512)
for i in range(8):
plt.figure()
plt.subplot(1,3,1)
plt.imshow(image_data[i])
plt.subplot(1,3,2)
plt.imshow(binary_label[i],cmap="gray")
plt.subplot(1,3,3)
plt.imshow(instance_label[i],cmap="gray")
plt.show()
运行效果如下:
上面的第一幅图,也就是原始图像好像有点问题,具体原因还不太清楚,后面的binary_label和instance_label是没有问题的。
注意事项:
(1)transfroms中的数据增强操作针对的是pillow的Image图像格式,而我们很多时候在使用opencv读进去的又是ndarray格式,所以需要第一步先将ndarray转化成Image格式,即:transforms.ToPILImage().
(2)但是我们后需要的数据又是需要ndarray格式或者是tensor格式,故而有需要将Image转换回来,即:
transforms.ToTensor()。
三、transforms中的图像变换操作大全
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
关于每一种变换方法的具体实用技巧,可以查阅torch的官方文档,也可以参考下面几篇博文。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)