【Pytorch】scatter函数详解
在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(source)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:target.scatter(dim, index, source)其中各变量及参数的说明如下:target:即目标张量source:即源张量dim:指定轴方向,即填充方式。对于二维...
在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(source)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:
target.scatter(dim, index, src)
其中各变量及参数的说明如下:
target
:即目标张量,将在该张量上进行映射src
:即源张量,将把该张量上的元素逐个映射到目标张量上dim
:指定轴方向,定义了填充方式。对于二维张量,dim=0
表示逐列进行行填充,而dim=1
表示逐列进行行填充index
: 按照轴方向,在target
张量中需要填充的位置
为了保证scatter填充的有效性,需要注意:
(1)target
张量在dim
方向上的长度不小于source
张量,且在其它轴方向的长度与source
张量一般相同。这里的一般是指:scatter操作本身有broadcast机制。
(2)index
张量的shape一般与source
,从而定义了每个source
元素的填充位置。这里的一般是指broadcast机制下的例外情况。
下面以一个实际的案例来观察scatter函数:
import torch
a = torch.arange(10).reshape(2,5).float()
b = torch.zeros(3, 5))
b_= b.scatter(dim=0, index=torch.LongTensor([[1, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
print(b_)
# tensor([[0, 6, 0, 0, 9],
# [0, 0, 2, 8, 0],
# [5, 1, 7, 0, 4]])
整个函数的操作过程见下面的示意图。因为设定了dim=0
,所以会逐列将source
中的元素按照index
中的位置信息,放入target
张量中。
scatter函数的一个典型应用就是在分类问题中,将目标标签转换为one-hot编码形式,如:
labels = torch.LongTensor([1,3])
targets = torch.zeros(2, 5)
targets.scatter(dim=1, index=labels.unsqueeze(-1), src=torch.tensor(1))
# 注意dim=1,即逐样本的进行列填充
# 返回值为 tensor([[0, 1, 0, 0, 0],
# [0, 0, 0, 1, 0]])
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)