Pairwise-ranking loss代码实现对比
Pairwise-ranking loss代码在Pairwise-ranking loss中我们希望正标记的得分都比负标记的得分高,所以采用以下的形式作为损失函数。其中c+c_+c+是正标记,c−c_{-}c−是负标记。J=∑i=1n∑j=1c+∑k=1c−max(0,1−fj(xi)+fk(xi))J=\sum_{i=1}^{n} \sum_{j=1}^{c_{+}} \sum_{k...
Multi-label classification中Pairwise-ranking loss代码
定义
在多标签分类任务中,Pairwise-ranking loss
中我们希望正标记的得分都比负标记的得分高,所以采用以下的形式作为损失函数。其中
c
+
c_+
c+是正标记,
c
−
c_{-}
c−是负标记。
引用了Mining multi-label data1中Ranking loss的介绍,令正标记的得分都高于负标记的得分。
根据上述的定义,我们对Pairwise-ranking loss修改为以下的形式:
J = ∑ i = 1 n ∑ j = 1 c + ∑ k = 1 c − max ( 0 , 1 − f j ( x i ) + f k ( x i ) ) J=\sum_{i=1}^{n} \sum_{j=1}^{c_{+}} \sum_{k=1}^{c_{-}} \max \left(0,1-f_{j}\left(\boldsymbol{x}_{i}\right)+f_{k}\left(\boldsymbol{x}_{i}\right)\right) J=i=1∑nj=1∑c+k=1∑c−max(0,1−fj(xi)+fk(xi))
代码优化
我写了两版代码,使用三层for循环的版本,以及使用一层for循环+矩阵运算的版本。当样本大小为1w时,代码二只需要1.3s,代码一需要13s。
注意:代码基于pytorch,其中的函数会进行解释
环境:i7-6700HQ, python3.6.5, pytorch 1.4.0+cpu
先准备测试数据
import torch
torch.manual_seed(1)
from time import time
batch = 10000
clsnum = 10
y_true = torch.randint(0,2,(batch,clsnum)) # 生成(10000,10)的0-1构成的向量作为真实标记
y_pred = torch.rand((batch,clsnum)) # 生成(10000,10)的0~1构成的向量作为预测标记
代码一(for循环版本)
st = time()
sum_one = 0
for i in range(y_pred.size(0)):
true_index = (y_true[i] == 1.0).nonzero().flatten()
false_index = (y_true[i] == 0.0).nonzero().flatten()
for j in true_index:
for k in false_index:
val = 1 - y_pred[i, j] + y_pred[i, k]
if val > 0:
sum_one += val
else:
sum_one += val * 0
end = time()
print(end-st) # 14.33322787284851
print(sum_one)# tensor(223758.5938)
代码二(矩阵运算版本)
st = time()
sum_one = 0
for i in range(y_pred.size(0)):
true_index = (y_true[i] == 1.0).nonzero().flatten()
false_index = (y_true[i] == 0.0).nonzero().flatten()
# 若正标记有n个,负标记为m个。生成n*m矩阵,用来进行正标记与负标记的运算。
ot = 1 - y_pred[i, true_index].view(-1,1).repeat((1,false_index.size(0))) + y_pred[i, false_index].view(1,-1).repeat((true_index.size(0),1))
sum_one += torch.clamp(ot,0.0).sum() # 将小于0的元素变为0。
end = time()
print(end-st) # 2.101431131362915
print(sum_one)# tensor(223755.8438)
疑问
sum_one
的值两次计算的结果是不同的,很奇怪。
补充
检索中的Pairwise-ranking loss
在评论区指出了我所关注的是多标签分类,而面向检索中的Pairwise-ranking loss的公式是不一致的。
引用一下刘铁岩老师的Learning to rank for information retrieval2中关于检索任务的Pairwise-ranking loss的定义,对于文档库中的两个文档
x
u
x_u
xu和
x
v
x_v
xv。 pairwise方法不注重预测每个文档相关性的准确程度,而是注重两个文档之间的相对顺序。
度量学习中的Pairwise-ranking loss
度量学习:学习衡量样本之间的距离/相似性。我认为检索中的ranking loss是度量学习的特例,而度量学习中的Pairwise-ranking loss的目标是,给定一个样本
r
a
r_a
ra和另外两个正负样本
r
p
,
r
n
r_p, r_n
rp,rn,使得
r
a
r_a
ra和
r
p
r_p
rp的距离更近,
r
a
r_a
ra和
r
n
r_n
rn的距离更远。引用《Understanding Ranking Loss, Contrastive Loss, Margin Loss, Triplet Loss, Hinge Loss and all those confusing names》3 的图片。
引用
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)