PyTorch:Dataset()与Dataloader()的使用详解
目录1、Dataset类的使用2、Dataloader类的使用3、总结Dataset类与Dataloader类是PyTorch官方封装的用于在数据集中提取一个batch的训练用数据的接口,其实我们也可以自定义获取每个batch的方法,但是对于大数据量的数据集,直接用封装好的接口会很大程度上提升效率。一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Da
目录
Dataset类与Dataloader类是PyTorch官方封装的用于在数据集中提取一个batch的训练用数据的接口,其实我们也可以自定义获取每个batch的方法,但是对于大数据量的数据集,直接用封装好的接口会很大程度上提升效率。
一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Dataloader负责在整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用。
1、Dataset类的使用
Dataset用以整理数据集。我们整理数据的目的是为了Dataloader可以方便的从整理后的和数据中获取一个batch的数据来供网络进行训练。
先看一下官方的Dataset的源码:
class Dataset(object):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite
:meth:`__len__`, which is expected to return the size of the dataset by many
:class:`~torch.utils.data.Sampler` implementations and the default options
of :class:`~torch.utils.data.DataLoader`.
.. note::
:class:`~torch.utils.data.DataLoader` by default constructs a index
sampler that yields integral indices. To make it work with a map-style
dataset with non-integral indices/keys, a custom sampler must be provided.
"""
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
# No `def __len__(self)` default?
# See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
很明显,这个类内部什么方法的实现都没有,就是用来让我们继承重写的。当我们继承该类时,必须重写里面的__getitem__(self, index)方法。该方法定义了使用索引值来查找元素的方法,即假如我们定义一个自己的训练数据集实例traindata,如果想使用traindata[index]的方式来获取索引为index的数据,我们就得实现__getitem__方法。这样当我们调用traindata[index]索引数据时,其实就是自动调用__getitem__(self, index)方法来实现的。另外,我们还可以重写__len__(self)方法,用以使用len(traindata)方法来获取我们整个数据集的数量。如果还不清楚,可以细细品一下下面的例子:
class TrainData(Dataset): # 继承Dataset类并重写相关的方法
...
def __getitem__(self, index):
'''编写自己的数据获取方式'''
return [x_data, y_lable]
def __len__(self):
'''编写获取数据集大小的实现方式'''
return length
traindata = TrainData(mydataset) # 定义一个实例
first = traindata[0] # 获取数据集中的第一组数据,会自动调用__getitem__
length = len(traindata) # 获取数据集的数据量的方法,会自动调用__len__
2、Dataloader类的使用
整体上来说,Dataloader类就是从上面封装好的数据中按照给定的方式来一次一次地抽取一个batch的数据来供网络进行训练,其内部使用的是yield生成器机制。Dataloader不用继承重写,我们直接实例化就行。下面我们接着上面的例子来继续了解下Dataloader从数据集中取出一个batch数据的过程:
首先,定义一个Dataloader实例gen_train:
gen_train = Dataloader(traindata, batch_size=4, num_workers=4, pin_memory=True, drop_last=True, collate_fn=my_collate_fn)
关于有关参数的说明(没用到的参数就不解释了):
1、traindata(Dataset): 传入的数据集,按自己定义的Dataset实例名来传入,我这里是traindata
2、batch_size(int, optional): 每个batch有多少个样本
3、num_workers (int, optional): 这个参数决定了有几个进程来处理data loading。0意味着所有的数据都会被load进主进程。(默认为0)
4、pin_memory (bool, optional): 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中.
5、drop_last (bool, optional): 如果设置为True:这个是对最后的未完成的batch来说的,比如你的batch_size设置为4,而一个epoch只有100个样本,那么训练的时候后面的2个因为不满足组成一个batch就被扔掉了。如果为False(默认),那么会继续正常执行,只是最后的batch_size会小一点。
6、collate_fn (callable, optional): 将一个list的sample组成一个mini-batch的函数
可以看到,gen_train从traindata中返回的是一个含有batch_size(4)个数据([x_data, y_label])的mini_batch。
下面我们分析分析这个过程是咋实现的。首先,DataLoader(object)源码中有下面这么一段代码:
。。。。。。
if sampler is None: # give default samplers
if self.dataset_kind == _DatasetKind.Iterable:
# See NOTE [ Custom Samplers and IterableDataset ]
sampler = _InfiniteConstantSampler()
else: # map-style
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
。。。。。。
按照上面的设置,sampler默认是None,我们没有定义要打乱数据(即shuffle为False),则接下来会调用
sampler = SequentialSampler(dataset)
再来看看这个方法是怎么实现的:
class SequentialSampler(Sampler):
r"""Samples elements sequentially, always in the same order.
Arguments:
data_source (Dataset): dataset to sample from
"""
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
主要看__iter__部分,明显的,假设数据集共有n个数据,这是一个返回的sampler就是数据集长度[0,1,2,......,n-1]序号的迭代器。关于怎么迭代,我们回到DataLoader(object)源码中继续往先看,会发现这么几条代码:
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
首先说一下,这个代码就是从上一步的迭代器sampler中取出batch_size个序号,batch_size之前我们设置的是4,所以就是取出4个序号(索引),用以后面从traindata中取出batch_size个数据,来看一下BatchSampler方法的迭代方式的实现,注意这里的yield机制:
class BatchSampler(Sampler):
。。。。。。
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
。。。。。。
所以,到这里我们一个batch_size的数据的索引就已经有了,后面就是调用多线程或单线程机制来取出对应的数据traindata[i]了。回到DataLoader(object)源码中,在往下看,就是下面这段代码了:
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self)
这段代码就是DataLoader的迭代器的实现方式了,具体的单多线程实现就不详细展开了。此时我们已经完成了获取本次迭代所需要的数据的索引值,接下来即使按照索引在traindata中找到相应的数据并一起返回这个mini_batch了。比如我们可以这样获取数据并用于训练:
for iteration, batch in enumerate(gen_train):
if iteration >= epoch_size: # 判断是否到达一个epoch的迭代次数(len(traindata)/batchsize)
break
x_datas, y_labels= batch[0], batch[1] # 获取batch中的数据和标签,用于训练
......
我们就可以使用这批数据进行一次网络的训练了,这么周而复始,直至达到我们设置的epoch。
3、总结
一般情况下,Dataset类与Dataloader类是配合着使用的,Dataset负责整理数据,Dataloader负责从Dataset整理好的数据中按照一定的规则取出batch_size个数据来供网络训练使用。
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)