深度学习--数据处理dataloader介绍及代码分析
参考博客DataLoader是深度学习中重要的数据处理工具之一,旨在有效加载、处理和管理大规模数据集,用于训练和测试机器学习和深度学习模型。DataLoader是一个用于批量加载数据的工具,它可以将数据集分成多个小批量,并逐个加载,以适应模型训练的需要。DataLoader主要用于两个关键任务:数据加载和批次处理DataLoader可以从不同来源加载数据,如硬盘上的文件、数据库、网络等。它能够自动
dataloader
概述
参考博客
DataLoader
是深度学习中重要的数据处理工具之一,旨在有效加载、处理和管理大规模数据集,用于训练和测试机器学习和深度学习模型。
DataLoader
是一个用于批量加载数据的工具,它可以将数据集分成多个小批量(mini-batch)
,并逐个加载,以适应模型训练的需要。
DataLoader
主要用于两个关键任务:数据加载和批次处理
- 数据加载:
DataLoader
可以从不同来源加载数据,如硬盘上的文件、数据库、网络等。它能够自动将数据集划分为小批次,从而减小内存需求,确保数据的高效加载。 - 数据批次处理:每个批次由多个样本组成,可以并行地进行数据预处理和数据增强。这有助于提高模型训练的效率,同时确保每个批次的数据都经过适当的处理。
collate_fn
collate_fn 是一个自定义函数,用于在 PyTorch 的 DataLoader 中定义如何将单个样本组合成一个批次(batch)。具体来说,collate_fn 函数会在每次从 DataLoader 中取出一个批次的数据时被调用,用于对数据进行整理和转换。
主要作用
collate_fn
:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成期望的数据格式。
将一个批次的数据样本整理成适合模型输入的格式,特别是将数据转换为 PyTorch 张量(Tensor),以便于后续的模型训练和推理。
- 自定义数据堆叠:将单个样本组合成一个批次,处理数据的不同形状或类型。
- 数据转换:在批次数据组成之前进行必要的转换操作,例如数据类型转换、数据增强等。
在代码中的使用
在本代码中,unet_dataset_collate 函数就是一个 collate_fn 函数。它的作用是将一个批次的数据样本(图像、PNG 数据和分割标签)整理成适合模型输入的格式。具体步骤包括将数据从列表转换为 NumPy 数组,再转换为 PyTorch 张量。
代码详解
# DataLoader中collate_fn使用
def unet_dataset_collate(batch):
images = []
pngs = []
seg_labels = []
for img, png, labels in batch:
images.append(img)
pngs.append(png)
seg_labels.append(labels)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
return images, pngs, seg_labels
这段代码定义了一个名为 unet_dataset_collate
的函数,用于在 PyTorch 的 DataLoader 中自定义批处理方式。函数将一个批次的数据样本(batch)转换为适合模型输入的格式。
代码解释
__init__函数
在 DataLoader 中,init 函数的主要作用是初始化数据集对象,并为后续的数据加载和处理做好准备。
UnetDataset 类的 init 函数在 DataLoader 中的作用包括:
- 数据集初始化:通过传入的参数(如 annotation_lines、input_shape 等)初始化数据集对象,使其包含所有必要的信息。
- 数据预处理:在初始化过程中,可以对数据进行预处理,如归一化、裁剪等,以便后续的模型训练。
- 数据分割:将数据集分割成训练集和验证集(通过 train 参数),以便在训练过程中进行模型评估。
- 路径管理:通过 dataset_path 参数指定数据集的存储路径,方便数据的加载和管理。
# UnetDataset 类的初始化方法,接受五个参数:annotation_lines、input_shape、num_classes、train 和 dataset_path。
def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
# super() 函数用于调用父类的初始化方法。在这里,它调用了 UnetDataset 类的父类的 __init__ 方法,确保父类的初始化逻辑也被执行。这对于继承自其他类的类非常重要。
super(UnetDataset, self).__init__()
# self 代表类的实例。self.annotation_lines 将传入的 annotation_lines 参数赋值给实例属性 annotation_lines
self.annotation_lines = annotation_lines
self.length = len(annotation_lines)
self.input_shape = input_shape
self.num_classes = num_classes
self.train = train
self.dataset_path = dataset_path
解释 super 和 self
- super
super()
函数用于调用父类的方法。在多重继承的情况下,它确保正确调用父类的方法,避免重复调用。这里,它调用了 UnetDataset 类的父类的 init 方法。 - self
self
是类的实例的引用。它用于访问类的属性和方法。在类的方法中,self 必须作为第一个参数传递,以便方法能够访问实例的属性和其他方法。
collate_fn
# DataLoader中collate_fn使用
# 函数定义:net_dataset_collate(batch):定义了一个函数,接收一个批次的数据样本batch。
def unet_dataset_collate(batch):
# 初始化列表:
# images = []:用于存储所有图像数据。
# pngs = []:用于存储所有 PNG 格式的数据。
# seg_labels = []:用于存储所有分割标签数据
images = []
pngs = []
seg_labels = []
# 遍历批次数据:
# 遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。
# images.append(img):将图像数据添加到 images 列表中。
# pngs.append(png):将 PNG 数据添加到 pngs 列表中。
# seg_labels.append(labels):将分割标签数据添加到 seg_labels 列表中。
for img, png, labels in batch:
images.append(img)
pngs.append(png)
seg_labels.append(labels)
#转换数据类型:
# 将 images 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
# 将 pngs 列表转换为 NumPy 数组,再转换为 PyTorch 的 LongTensor 类型。
# 将 seg_labels 列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
# 返回结果:
# 返回处理后的图像数据、PNG 数据和分割标签数据。
return images, pngs, seg_labels
详细说明
-
函数定义:
unet_dataset_collate(batch)
:定义了一个函数,接收一个批次的数据样本batch
。
-
初始化列表:
images = []
:用于存储所有图像数据。pngs = []
:用于存储所有 PNG 格式的数据。seg_labels = []
:用于存储所有分割标签数据。
-
遍历批次数据:
for img, png, labels in batch:
:遍历批次中的每个样本,假设每个样本包含图像、PNG 数据和分割标签。images.append(img)
:将图像数据添加到images
列表中。pngs.append(png)
:将 PNG 数据添加到pngs
列表中。seg_labels.append(labels)
:将分割标签数据添加到seg_labels
列表中。
-
转换数据类型:
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
:将images
列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。pngs = torch.from_numpy(np.array(pngs)).long()
:将pngs
列表转换为 NumPy 数组,再转换为 PyTorch 的 LongTensor 类型。seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
:将seg_labels
列表转换为 NumPy 数组,再转换为 PyTorch 的 FloatTensor 类型。
-
返回结果:
return images, pngs, seg_labels
:返回处理后的图像数据、PNG 数据和分割标签数据。
完整代码
import os
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class UnetDataset(Dataset):
def __init__(self, annotation_lines, input_shape, num_classes, train, dataset_path):
super(UnetDataset, self).__init__()
self.annotation_lines = annotation_lines
self.length = len(annotation_lines)
self.input_shape = input_shape
self.num_classes = num_classes
self.train = train
self.dataset_path = dataset_path
def __len__(self):
return self.length
def __getitem__(self, index):
annotation_line = self.annotation_lines[index]
name = annotation_line.split()[0]
#-------------------------------#
# 从文件中读取图像
#-------------------------------#
jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "JPEGImages"), name + ".jpg"))
png = Image.open(os.path.join(os.path.join(self.dataset_path, "SegmentationClass"), name + ".png"))
#-------------------------------#
# 数据增强
#-------------------------------#
jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
png = np.array(png)
png[png >= self.num_classes] = self.num_classes
#-------------------------------------------------------#
# 转化成one_hot的形式
# 在这里需要+1是因为voc数据集有些标签具有白边部分
# 我们需要将白边部分进行忽略,+1的目的是方便忽略。
#-------------------------------------------------------#
seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]
seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
return jpg, png, seg_labels
def rand(self, a=0, b=1):
return np.random.rand() * (b - a) + a
def get_random_data(self, image, label, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.3, random=True):
image = cvtColor(image)
label = Image.fromarray(np.array(label))
#------------------------------#
# 获得图像的高宽与目标高宽
#------------------------------#
iw, ih = image.size
h, w = input_shape
if not random:
iw, ih = image.size
scale = min(w/iw, h/ih)
nw = int(iw*scale)
nh = int(ih*scale)
image = image.resize((nw,nh), Image.BICUBIC)
new_image = Image.new('RGB', [w, h], (128,128,128))
new_image.paste(image, ((w-nw)//2, (h-nh)//2))
label = label.resize((nw,nh), Image.NEAREST)
new_label = Image.new('L', [w, h], (0))
new_label.paste(label, ((w-nw)//2, (h-nh)//2))
return new_image, new_label
#------------------------------------------#
# 对图像进行缩放并且进行长和宽的扭曲
#------------------------------------------#
new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
scale = self.rand(0.25, 2)
if new_ar < 1:
nh = int(scale*h)
nw = int(nh*new_ar)
else:
nw = int(scale*w)
nh = int(nw/new_ar)
image = image.resize((nw,nh), Image.BICUBIC)
label = label.resize((nw,nh), Image.NEAREST)
#------------------------------------------#
# 翻转图像
#------------------------------------------#
flip = self.rand()<.5
if flip:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
label = label.transpose(Image.FLIP_LEFT_RIGHT)
#------------------------------------------#
# 将图像多余的部分加上灰条
#------------------------------------------#
dx = int(self.rand(0, w-nw))
dy = int(self.rand(0, h-nh))
new_image = Image.new('RGB', (w,h), (128,128,128))
new_label = Image.new('L', (w,h), (0))
new_image.paste(image, (dx, dy))
new_label.paste(label, (dx, dy))
image = new_image
label = new_label
image_data = np.array(image, np.uint8)
#---------------------------------#
# 对图像进行色域变换
# 计算色域变换的参数
#---------------------------------#
r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
#---------------------------------#
# 将图像转到HSV上
#---------------------------------#
hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
dtype = image_data.dtype
#---------------------------------#
# 应用变换
#---------------------------------#
x = np.arange(0, 256, dtype=r.dtype)
lut_hue = ((x * r[0]) % 180).astype(dtype)
lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
return image_data, label
# DataLoader中collate_fn使用
def unet_dataset_collate(batch):
images = []
pngs = []
seg_labels = []
for img, png, labels in batch:
images.append(img)
pngs.append(png)
seg_labels.append(labels)
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
pngs = torch.from_numpy(np.array(pngs)).long()
seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
return images, pngs, seg_labels
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)