本文主要是作者对小样本类增量学习的一知半解,如果各位读者发现什么不对或者疑惑的地方,欢迎大家评论区展开舌战!!!😍😍😍💕💕💕

1.什么是few-shot?

        在深度学习中,各个任务的实现必不可少的三个条件分别是:网络模型,数据集,训练策略。(请允许笔者稍微拓展一下,考虑到有很多求知若渴的萌新刚入门深度学习,希望能让读者理解的更加透彻✌️✌️✌️)

1.1网络模型

        在学术领域,用的最多的当属何大神提出的残差网络结构了,可谓是开辟了深度学习的新时代!!残差网络比如:Resnet-18, Resnet-50就是使用不同的残差块构成的,残差块的结构如图1所示:

图1  残差块

为什么要叫残差块呢? 

        输入为x,经过网络的卷积,激活等操作后得到F(x),最后和恒等映射的x相加得到残差快的输出,就相当于网络学的参数主要决定F(x),这个就称为残差,表示和输入有多大的偏差!!!

那么,残差网络究竟解决了什么问题呢?

        首先,随着网络层数的增加,网络的特征提取能力越强,从而越具有区分性,同时模型的复杂度会增加,当某个样本有很小的梯度时在梯度的反传过程中就会越来越小,造成梯度消失,而加入了额外的直连层(如图1中的identity),那么梯度在反传的过程中不管求导后梯度多小,梯度会保有直流梯度,从而防止梯度消失。其次,网络越复杂,有些数据不需要这么复杂的网络即可很好的学习你想获得的知识,那么网络深层就会退化,此时残差网络的作用就来了!!!他会让网络深层计算得到的F(x)为0,那么这一层的输出就和输入完全相等,从而解决网络退化问题!!!是不是很神奇呢!!!🤦‍♂️🤦‍♂️🤦‍♂️

        最近大火的GPT,transformer大模型等结构也越来越多的应用于深度学习了,对transformer感兴趣的小伙伴可以阅读这篇博客。史上最小白之Transformer详解

笔者觉得讲的非常透彻,非常适合小白去了解transformer结构。

1.2数据集

        数据集就是你具体想实现的任务的各种数据样本的集合。比如图像,语音等数据集。扯了半天终于要讲我们的主题Few-shot了,在深度学习中,想要学到一个非常好的网络,数据集是非常重要的,直接用原始数据去训练网络,训练效果往往会比较差,需要对数据做一些规范的处理,如归一化,裁剪,旋转等操作来让网络对数据有更好的拟合效果;同时数据集的样本数量需要很多,现在的数据集的样本数量基本上都是上万的,如果是大模型需要训练的数据又会更多!!所以Few-shot就是假设我数据集中的训练样本就是比较少的,比如图像分类中每一类就只有五个样本!!!用小样本数据集训练网络往往会造成网络的过拟合即模型在训练集的拟合效果很好,但在测试集上的拟合效果差,这也是小样本类增量学习要解决的难题之一!!!

1.3训练策略

        最后就是你的训练策略了,包括网络模型中的各个参数的初始化,学习率的大小,每个训练批次训练的样本数量(batchsize),损失函数的选择以及各个超参数的确定等等。这些对于网络的学习还是有很大的影响的。

2.什么是类增量学习?

        类增量学习和终身学习,学会学习有点类似,比如分类任务中,当网络已经对已知类有良好的分类能力后,如果出现新的类别,那么网络也能利用新类的样本去更新模型,从而在不遗忘已学会的类别的前提下,对新的类别也有良好的识别能力。增量学习通俗来说就是网路模型需要不断的去更新,去进化!!!由于新类别的加入,模型只能利用新类的样本去更新参数,从而导致网络在更新的过程中遗忘已学会的类别,造成灾难性遗忘问题,这就是小样本类增量学习要解决的另外一个难题!!!

3.小样本类增量学习

3.1那么小样本类增量学习又是什么呢?        

        首先,网络的学习主要分为两个阶段,基类阶段和增量阶段。

        基类阶段:在这个阶段训练集的样本数量(往往每一类有几百个样本)是比较大的,同时数据集的类别数也比较多(几十类)。因此可以学习到对于这些类别有优异识别能力的网络。

        增量阶段:在这个阶段训练集的样本数量非常有限(每一类5个或者更少),同时包含的类别数也比较少(一般为5类),即5-way5-shot,每个增量阶段包含5类,每类5个样本。小样本也就主要体现在增量阶段。

3.2主要解决的难题

        前文在介绍深度学习任务中需要具备的几个条件时也简单的提了一嘴,主要是灾难性遗忘和网络的过拟合问题。

3.2.1灾难性遗忘

 

图2 灾难性遗忘与过拟合问题

        灾难性遗忘是指模型在更新的阶段,由于只能通过少量的新类样本来更新网络,导致在学习的过程中网络会更加倾向于分类新类,从而对旧类的分类效果越来越差,造成灾难性遗忘如图2(a)所示 ,随着epoch的增加,测试集上对旧类的准确率越来越低。

3.2.2网络的过拟合 

        由于增量阶段是小样本的训练策略,每一类的样本数量很少,网络在训练的过程中能对这些少量的新类样本很好的分类,但在测试集上对这类的分类效果会大打折扣 ,造成模型的过拟合问题,如图2(b)所示,随着训练的进行,模型在训练集上和测试集上的分类准确率相差很大,在训练集上能达到100%的准确率,但在测试集上只有50%左右。

3.3解决方法

        对于小样本类增量学习,各大期刊会议上也发表了很多的论文,各种方法层出不穷,我认为解决方法主要分为两个大的方向:对特征提取器的设计或者是对分类器的设计。也可以说是侧重于基类阶段或者是侧重于增量阶段

        因为一个常见的分类系统通常分为一个特征提取器分类器,特征提取器将你要处理的信号通过深度神经网络提取成特征(embedding),而分类器将特征映射成概率分布(logits),表示该样本属于各个类别的概率,取概率最大的作为最终的判决结果。

        而在小样本类增量学习中,比较流行的一个pipeline(意思就是对于某个任务的整个操作流程,类似于流水线)就是在基类阶段,根据含有大量样本的基类训练一个对这些基类有优秀特征提取能力的特征提取器和优异分类性能的分类器,而在增量阶段,将特征提取器和分类器解耦,冻结特征提取器的参数,利用含有小样本的新类去更新分类器的参数,但是更新的过程中旧类的分类器参数也是冻结的。由于特征提取器是需要提取有特征表达能力的embedding,既然已经在基类阶段训练了一个优秀的特征提取器,那么在增量阶段就没有必要去更新特征提取器,利用小样本去更新特征提取器反而会损害它的特征提取能力。

        接下来介绍一个在学术界得到认可的小样本类增量学习策略,大名鼎鼎的原型网络登场了!!!🙌🙌🙌

        首先要知道什么是原型,前文提到,特征提取器能够将待处理的样本提取成具有特征表达能力的embedding,那么数据集中每一类有多少个样本就能提取出多少个embedding,而对每一类的所有embedding取均值,就得到了每一类的原型prototype,可以理解成每一类的聚类中心,然后利用这个原型去进行分类,让网络去学习这个原型,将测试集的每个样本提取到的embedding和每一类的原型计算一个欧氏距离,最小的值表示这个样本和这一类最近,那么就判别其为之一类别。

        最后介绍CVPR2021发表的一篇论文,也是利用到了原型的概念,在小样本类增量学习中引用次数达到了4万+,感兴趣的可以去读一读。广为人知的CEC!!!

        我也简单介绍以下CEC,首先看一下他的整个网络框架图。

 图3 CEC的网络框架图

 主要分为3个阶段:预训练阶段伪增量学习阶段分类器学习和自适应阶段

 首先是预训练阶段:对含有大量类别基类采样,得到用于预训练的基类训练集和用于伪增量学习阶段的伪增量数据集,这两个数据集中的类别和样本都是不重叠的。利用基类训练集和交叉熵损失函数训练得到一个特征提取器

然后就是伪增量学习极端:基类采样得到的伪增量数据集进行翻转得到伪增量训练集,因为并不是真实的增类样本,所以称为伪增量。在图3中stage2有一个图模型Gθ,称为GAT(Graph Attention Network)。利用这个模型去学习每一类的原型,用于最终的分类。这个模型采用的Transfromer的结构,利用注意力机制来进行参数的更新。

最后就是分类器学习和自适应阶段:在这个阶段就是利用真实的增量数据集去更新这个图模型,整个图模型的工作机制如图4所示:

图中每个节点就表示一个原型,增量阶段会对旧类的原型冻结,只学习当前阶段的原型

以上就是本篇文章的全部内容啦!!!非常感谢您能阅读至此!!!欢迎大家各抒己见!!!

Logo

瓜分20万奖金 获得内推名额 丰厚实物奖励 易参与易上手

更多推荐