在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(source)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:

target.scatter(dim, index, src)

其中各变量及参数的说明如下:

  • target:即目标张量,将在该张量上进行映射
  • src:即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐列进行行填充
  • index: 按照轴方向,在target张量中需要填充的位置

为了保证scatter填充的有效性,需要注意:
(1)target张量在dim方向上的长度不小于source张量,且在其它轴方向的长度与source张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。
(2)index张量的shape一般与source ,从而定义了每个source元素的填充位置。这里的一般是指broadcast机制下的例外情况。

下面以一个实际的案例来观察scatter函数:

import torch
a = torch.arange(10).reshape(2,5).float()
b = torch.zeros(3, 5))
b_= b.scatter(dim=0, index=torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
print(b_)

# tensor([[0, 6, 0, 0, 9],
#        [0, 0, 2, 8, 0],
#        [5, 1, 7, 0, 4]])

整个函数的操作过程见下面的示意图。因为设定了dim=0,所以会逐列将source中的元素按照index中的位置信息,放入target张量中。
在这里插入图片描述
scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:

labels = torch.LongTensor([1,3])
targets = torch.zeros(2, 5)
targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
# 注意dim=1,即逐样本的进行列填充
# 返回值为 tensor([[0, 1, 0, 0, 0],
#        [0, 0, 0, 1, 0]])
Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐