DataLoader详解
torchvision中datasets中所有封装的数据集都是torch.utils.data.Dataset的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。
在深度学习加载模型的时候,会对数据进行处理,今天主要介绍pytorch中Dateset和DataLoader的使用方法。
目录
1.torch.utils.data里面的dataset使用方法
一、基础概念
1.torch.utils.data.datasets-抽象类可以创建数据集,但是抽象类不能实例化,所以需要构建这个抽象类的子类来创建数据集,并且我们还可以定义自己的继承和重写方法。其中最重要的是len和getitem这两个函数,len能够给出数据集的大小,getitem用于查找数据和标签。
2.torch.utils.data.DataLoader是一个迭代器,主要是用于多线程的读取数据,并且可以实现batch和shuffle的读取。
二、Dataset使用方法
1.torch.utils.data里面的dataset使用方法
当我们继承了一个Dataset类之后,我们需要重写里面的len方法,该方法提供了dataset的大小,getitem(),该方法支持从0-len(self)的索引。
from torch.utils.data import Dataset, DataLoader
import torch
class MyDataset(Dataset):
def __init__(self):
self.x = torch.linspace(11, 20, 10)
self.y = torch.linspace(1, 10, 10)
self.len = len(self.x)
def __getitem__(self, index):
return self.x[index], self.y[index]
def __len__(self):
return self.len
mydataset = MyDataset()
train_loader2 = DataLoader(dataset=mydataset,batch_size=5,shuffle=False)
2.torchvision.datasets的使用方法
torchvision
中datasets
中所有封装的数据集都是torch.utils.data.Dataset
的子类,它们都实现了__getitem__和__len__方法。因此,它们都可以用torch.utils.data.DataLoader进行数据加载。
import torchvision
import torch
# 导入FashionMNIST数据集
mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=transforms.ToTensor())
train_data = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_data = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
三、DateLoader详解
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
各个参数的介绍:
1.dataset(Dataset): 传入的数据集
2.batch_size(int, optional): 每个batch有多少个样本
3.shuffle(bool, optional): 在每个epoch开始的时候,对数据进行重新排序
4.sampler(Sampler, optional): 自定义从数据集中取样本的策略,如果指定这个参数,那么 shuffle必须为False5.batch_sampler(Sampler, optional): 与sampler类似,但是一次只返回一个batch的indices(索引),需要注意的是,一旦指定了这个参数,那么batch_size,shuffle,sampler,drop_last 就不能再制定了(互斥——Mutually exclusive)
6.num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所 有的数据都会被load进主进程。(默认为0)
7.collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数8.pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前, 将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
9.drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如 你的batch_size设置为64,而一个epoch只有100个样本,那么训练的时候后面的36个就被 扔掉了…如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。10.timeout(numeric, optional): 如果是正数,表明等待从worker进程中收集一个batch等待的 时间,若超出设定的时间还没有收集到,那就不收集这个内容了。这个numeric应总是 大于等于0。默认为0
11.worker_init_fn (callable, optional): 每个worker初始化函数 If not None, this will be called on each
# 处理数据集,把数据转换成张量,使数据可以输入下面我们搭建的网络
def load_data_fashion_mnist(mnist_train, mnist_test, batch_size):
if sys.platform.startswith('win'):
num_workers = 0
else:
num_workers = 4
train_data = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_data = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
return train_data, test_data
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)