关系网络 Relation Network
文章目录比较孪生网络、原型网络和关系网络关系网络 Relation Network实现过程网络结构损失函数训练策略算法推广 —— 推广到 zero-shot创新点算法评价比较孪生网络、原型网络和关系网络孪生网络需要计算任意两两样本的匹配程度,而原型网络则进一步改进,提出对样本进行适当的 embedding,然后计算样本每一类的样本中心,称为原型 prototype,通过模型学习出 prototyp
比较孪生网络、原型网络和关系网络
孪生网络需要计算任意两两样本的匹配程度,而原型网络则进一步改进,提出对样本进行适当的 embedding,然后计算样本每一类的样本中心,称为原型 prototype,通过模型学习出 prototype 的位置,对测试样本计算到每个原型的距离,从而进行分类。
不论是孪生网络还是原型网络,在分析两个样本的时候都是通过 embedding 后的特征向量距离(比如欧式距离)来反应,而关系网络则是通过构建神经网络来计算两个样本之间的距离从而分析匹配程度,和孪生网络、原型网络相比,关系网络可以看成提供了一个可学习的非线性分类器用于判断关系,而孪生网络、原型网络的距离只是一种线性的关系分类器。
关系网络 Relation Network
这篇论文进一步学习一种可迁移的深度度量方式能够比较图像之间的关系。整个网络分为两个阶段,第一阶段是一个 embedding 模块(用于提取特征信息),第二阶段是一个相关性模块(用于输出两幅图之间的相似程度得分,从而判断两幅图像是否来自同一类别)
和Matching Network一样,训练集分为支持集(Support Set)和查询集(Queury Set),将支持集的图像和查询集的图像分别输入嵌入式模块
f
φ
f_φ
fφ,提取得到特征信息。然后将查询集图像对应的特征信息分别与支持集中各个图像对应的特征信息级联起来(也可以是其他的连接方式,两两连接),然后进入相关模块
g
ϕ
g_ϕ
gϕ 计算得到相关性得分,最后输出一个 one-hot 向量,表示查询集中图像个支持集图像相似程度最高的那一类。
在计算相关性得分的时候,如果是 One-Shot,那么如果是 C-way,则会计算出C个得分,如果是K-shot,那么支持集所有同类样本的 embedding 先求和,在计算相关性得分,最后还是计算出 C 个得分
r
i
,
j
=
g
ϕ
(
C
(
f
φ
1
(
v
c
)
,
f
φ
2
(
x
j
)
)
)
,
i
=
1
,
2
,
…
,
C
r_{i, j}=g_{\phi}\left(\mathcal{C}\left(f_{\varphi_{1}}\left(v_{c}\right), f_{\varphi_{2}}\left(x_{j}\right)\right)\right), \quad i=1,2, \ldots, C
ri,j=gϕ(C(fφ1(vc),fφ2(xj))),i=1,2,…,C
其中:
- f φ f_φ fφ:嵌入函数
- C ( ) C() C():连接函数
- g ϕ g_ϕ gϕ:相关性计算函数
实现过程
网络结构
嵌入式模块仍然是采用四个卷积块构成,相关性模块现有两个卷积块,再经过两个全连接层,最后利用Sigmoid函数得到相似程度得分。特别的是,在处理同一类别包含多幅图像的数据集的时候(如:5-shot),本文采用将支持集中同一类别的图像得到的特征向量采用逐像素相加的方式得到对应类别的特征向量,再与查询集图像进行级联和计算得分的操作。
损失函数
与常见分类任务采用交叉熵损失函数不同,本文采用均方差对相似程度得分进行监督,优化目标函数如下
φ
,
ϕ
←
argmin
φ
,
ϕ
∑
i
=
1
m
∑
j
=
1
n
(
r
i
,
j
−
1
(
y
i
=
=
y
j
)
)
2
\varphi, \phi \leftarrow \underset{\varphi, \phi}{\operatorname{argmin}} \sum_{i=1}^{m} \sum_{j=1}^{n}\left(r_{i, j}-\mathbf{1}\left(y_{i}==y_{j}\right)\right)^{2}
φ,ϕ←φ,ϕargmini=1∑mj=1∑n(ri,j−1(yi==yj))2
训练策略
与Matching Network等基本相同,分成多个Episode,包含支持集和查询集。
算法推广 —— 推广到 zero-shot
本文提出的模型稍加改造也可以用于zero-shot学习任务,所谓zero-shot学习就是训练集中不包含图像,只有图像对应的一个语义特征向量或者描述向量。
如图所示,本文对网络进行相应的改进,对于描述向量经过两个全连接层(带有L2正则化,实现权重衰减)得到对应的特征向量,对于查询集中图像则是经过一个深层卷积神经网络(Inception or ResNet)得到对应的特征向量。然后将二者级联起来计算相似程度得分。
创新点
- 提出一种可学习的非线性相似性度量方式,用于实现小样本甚至one-shot学习任务
算法评价
在学习了Matching Network等一系列文章后,本文的思路是非常好理解的,就是改进了相似性度量的方式,由预先定义的固定的相似性度量函数(Matching Network——余弦距离,Prototypical Network——平方欧氏距离)或者Siamese Network中线性度量方式,升级为利用神经网络训练得到一个可学习的非线性相似性度量函数。实验结果表示在多个数据集上都取得了不错的成绩,但本文对于5-shot问题采用将特征图逐元素相加的方式来获取每个类别对应的特征信息的方式,我是存在异议的,这种做法是否过于简单粗暴?特征图直接相加是否会导致特征信息遭到破坏?这可能也是本文在5-shot任务中表现普遍较差的原因吧。
参考资料:
Sung F, Yang Y, Zhang L, et al. Learning to compare: Relation network for few-shot learning[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. CVPR 2018: 1199-1208.
论文阅读笔记《Learning to Compare: Relation Network for Few-Shot Learning》
元学习系列(三):Relation Network(关系网络)
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)