【SwinTransformer源码阅读二】Window Attention和Shifted Window Attention部分
swintransformer源码阅读
先放一下SwinTransformer的整体结构,图片源于原论文,可以发现,在Transformer的Block中 W-MSA(Window based multi-head self attention) 和 SW-MSA是关键组成部分。W-MSA出现在某阶段的奇数层,SW-MSA出现在某阶段的偶数层,W-MSA考虑的是单个窗口的信息,SW-MSA考虑的是不同窗口间的信息。
虽然从网络架构图里看,W-MSA和SW-MSA为两个不同的模块,但是在代码层面,两者是同一个代码片段,只是在计算SW-MSA时候,在计算完W-MSA后,然后通过代码进行滑动窗口,即cyclic shift操作,多计算了一个mask的操作。下面将针对代码进行分析。
W-MSA的代码
【注意】注释第一句话:Window based multi-head self attention (W-MSA) module with relative position bias.It supports both of shifted and non-shifted window.
代码注释中的中文,是以配置文件中 swin-tiny 相关的量 来进行注释的。
#窗口注意力
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim#96*(2^layer_index 0,1,2,3...)
self.window_size = window_size # Wh, Ww (7,7)
self.num_heads = num_heads#[3, 6, 12, 24]
head_dim = dim // num_heads#(96//3=32,96*2^1 // 6=32,...)
self.scale = qk_scale or head_dim ** -0.5#default:head_dim ** -0.5
# define a parameter table of relative position bias
#定义相对位置偏置表格
#[(2*7-1)*(2*7-1),3]
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
#得到一对在窗口中的相对位置索引
coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
#让相对坐标从0开始
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
#relative_coords[:, :, 0] * (2*7-1)
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
#为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
#索引值为 (49,49) 值在0-168对应位置偏移表的索引
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
#dim*(dim*3)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#attn_drop=0.0
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
#初始化相对位置偏置值表(截断正态分布)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
#模块的前向传播
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape#输入特征的尺寸
#(3, B_, num_heads, N, C // num_heads)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# q/k/v: [B_, num_heads, N, C // num_heads]
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
# q*head_dim ** -0.5
q = q * self.scale
# attn:B_, num_heads,N,N
attn = (q @ k.transpose(-2, -1))
# 在 随机在relative_position_bias_table中的第一维(169)选择position_index对应的值,共49*49个
#由于relative_position_bias_table第二维为 nHeads所以最终表变为了 49*49*nHead 的随机表
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
#attn每一个批次,加上随机的相对位置偏移 说民attn.shape=B_,num_heads,Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
#mask 在某阶段的奇数层为None 偶数层才存在
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
#进行 dropout
attn = self.attn_drop(attn)
#attn @ v:B_, num_heads, N, C/num_heads
#x: B_, N, C 其中
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
#经过一层全连接
x = self.proj(x)
#进行drop out
x = self.proj_drop(x)
return x
关于W-MSA中的注意力机制的运算,其实就是按照下面这个公式来进行的,在这个公式里,其实 QKV 三者均是又输入经过一个全连接层(nn.Linear())得到的,这个在代码里很好看明白。关键是在W-MSA中,增加了一个位置偏移量 B,这里的B相关计算也是W-MSA中的关键一步,下面进行记录下。
位置偏移量 B 的代码详解
这里关键是理解 relative_position_bias_table 和 relative_position_index ,这两个矩阵的对应关系设计的比较巧妙,即relative_position_index 刚好设计为 relative_position_bias_table 所对应的网格数量
# define a parameter table of relative position bias
#定义相对位置偏置表格
#[(2*7-1)*(2*7-1),3]
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
#得到一对在窗口中的相对位置索引
coords_h = torch.arange(self.window_size[0])#[0,1,2,3,4,5,6]
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
#让相对坐标从0开始
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
#relative_coords[:, :, 0] * (2*7-1)
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
#为位置偏置表中索引值,位置偏移表(13*13,nHeads)索引0-168
#索引值为 (49,49) 值在0-168对应位置偏移表的索引
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
#注册为不可学习变量
self.register_buffer("relative_position_index", relative_position_index)
#dim*(dim*3)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
#attn_drop=0.0
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
#初始化相对位置偏置值表(截断正态分布)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
relative_position_bias_table :设置的一个可学习的 (2 x window_size[0]-1)x(2 x window_size[1]-1) x nHeads 的随机变量(利用截断正态分布赋值),如果以代码中第一个阶段的参数量为例,则 window_size[0]=window_size[1]=7, 在第一个阶段 nHeads=3 。即该表中存储的时候一系列的随机数,用于位置编码,提升模型的性能。
下面展示一个relative_position_bias_table的例子
relative_position_index :相对位置编码表的索引表,即存储的值,用来取得相对位置偏移量表relative_position_bias_table 中某个位置的值,relative_position_index中存的值所取范围为 [0,168],即relative_position_bias_table 的大小为 169(13 x 13)个单元格。通过下面的图片,可以看到 relative_position_index中0 和 168 位置的编码只取一次,其实符合传统transformer中对于位置编码的运用,即开头和结尾的位置编码只用一次。
关注到最后计算出的 attn 需要加上位置偏移量,则这里需要看一下 relative_position_bias的计算策略,即下面的图示
relative_position_bias的计算策略:
最终的 relative_position_bias, 即经过转置后和 attn 的后三维一致,进而就可以进行直接位置相加了。
SW-MSA (Shifted Window based multi-head self attention (SW-MSA) module )
SW-MSA的代码中关键步骤为 attn_mask 和 shift windows的操作,即通过对特征图移位,并给Attention设置mask来间接实现的,在保持原有的window个数下,节省计算。
首先来看 attn_mask
代码如下:
#奇数层没有shift_size 偶数层有 shift_size
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution#(56/(2^layer_index),56/(2^layer_index))
#zero_init:img_mask (1,H,W,1)
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
#h_slices :(slice(0, -7, None), slice(-7, -3, None), slice(-3, None, None))
#>>> c=range(0, 10)
#>>> c[h_slices[0]]
# range(0, 3)
#>>> c[h_slices[1]]
# range(3, 7)
#>>> c[h_slices[2]]
# range(7, 10)
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
#将一个H*W的输入按照切片分为9块
#按照H维进行切片
for h in h_slices:
#按照W维进行切片
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
#将img_mask shape 1,H,W,1-> nW, window_size, window_size, 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
#nW, window_size, window_size
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
#attn_mask:[nW, window_size * window_size, window_size * window_size]
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
#矩阵中为0的置0 不为0的置 -100
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
关于程序中的数组切片,slice 部分,代码注释中有说明,注意 这里 window_size=7, shift_size=3,就不详细说明了,这里先针对 img_mask 来说明下,即下图的步骤,具体完成了哪些内容?在例子中,我让img_mask的 H=W=14
首先 img_mask 为和输入大小一致的张量,经过上面的slice代码的切片后,则形成了下面形状(1,14,14,1)的张量
注意到
如果直接将img_mask转为(14,14)的张量,我们可以看到其形状,相当于将张量根据slice切片,分为了9个部分,其中红色部分,不论img|_mask的 H W 为多少,始终为矩形,且大小为 (H-window_size)*(W-window_size),其余黄色的框基本大小是确定的,和 window_size 和 shift_size有关系。
其中attn_mask代码中的 mask_windows - > attn_mask的变换是关键的一步,这一步主要是让以 7*7 为单元的窗口中,块索引值相同的位置,置0,不同的位置 置为 -100 即直接屏蔽掉。
我们可以用以下代码模拟下,比如,a和b为shape为[2,3]的张量,则可以发现,a.unsqueeze(1) shape 为 [2,1,3], b.unsequeeze(2).shape 为 [2,3,1] ,最后经过 c = a.unsqueeze(1) - b.unsequeeze(2),c 变为了 shape [2,3,3],可以根据图中的计算过程,发现其实是 a [2,1,3] 中 b [2,3,1],就是 a 第一个 [1,3] 和 b 中 第一个[3,1] 中的每个元素进行减法操作,形成一个[3,3]的矩阵,然后再让 a 第二个 [1,3] 和 b 中 第二个[3,1] 中的每个元素进行减法操作,形成二个[3,3]的矩阵,最终形成了 [2,3,3] 的矩阵。
即经过上面的 mask_windows - > attn_mask 的运算,可以将不同窗口中,对应位置的索引相同值置0,不同值为两者的差值。
然后根据下面的代码进行同则赋 0,异则赋 -100。
#矩阵中为0的置0 不为0的置 -100
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
最后将得到的 attn_mask 与 得到的特征图 attn 进行相加
#mask 在某阶段的奇数层为None 偶数层才存在
if mask is not None:
#nW=B*H/7*W/7
#mask.shape:[B*H/7*W/7 , 49, 49]
nW = mask.shape[0]
#mask:torch.Size([1, 4, 1, 49, 49])
#attn.view:[B_ // nW, nW(4), self.num_heads, N, N]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
我们用程序模拟下面这个步骤,即假设 attn.view:[B_ // nW, nW(4), self.num_heads, N, N]=[1,2,2,2,2],而mask.unsqueeze(1).unsqueeze(0).shape=[1,2,1,2,2] , 如下面的图
所以根据代码 展开 attn_mask的计算过程,可以用图示表示:
通过图示可以发现,相当于强行将某些模块的样本用来计算对应mask的注意力值,这个属于对网络的一种约束了。且是强行分了 B_/nw 个模块,每个模块中交替进行计算对应那几个(nw)个mask的注意力。
说完了 attn_mask,再来看看 shift windows的操作,具体来讲,应该是一个特征图循环移位的操作,不过只移动了一次,所以直接用 shift 也可以理解。相关代码如下:
进行移动窗口
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
由于代码中,默认 fused_window_process 为 False,所以进行移动窗口主要代码是:
#这里的 x = x.view(B, H, W, C)
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
为与上面的例子对应,这里我们假设 shift_size = 3,由于 X的 shape为 [B,H,W,C] 所以可以看出,这个移位是在 H 和 W的维度分别移动 3
最后的shift恢复过程,就是上面 roll 和 partition 的反过程,代码中:
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
最后再来看一下 SwinTransformerBlock 的前向传播代码,即如果 shift_size>0 ,整体过程是对输入的整个特征图进行 循环移位 - > 然后进行带mask的注意力机制计算(SW-MSA)->再进行一系列后操作,这里并看不到针对某个窗口进行特征图移位和针对某个窗口进行 mask 均是针对整张特征图进行的相关操作。
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
else:
x_windows = WindowProcess.apply(x, B, H, W, C, -self.shift_size, self.window_size)
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
# reverse cyclic shift
if self.shift_size > 0:
if not self.fused_window_process:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = WindowProcessReverse.apply(attn_windows, B, H, W, C, self.shift_size, self.window_size)
else:
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
x = shifted_x
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
# FFN
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
所以看了代码,不禁怀疑网上有些关于swintransformer的教程,其实有些问题的,具体还是要看代码,如果有问题,欢迎留言指正!但为什么swintransformer仅仅进行了特征图循环移位和限制性的mask注意力机制,就有效果,其实还需要深究,个人感觉是多阶段连续后,其实特征图的循环移位,让越靠后的层,越考虑了全局的特征,这个还需要再看看代码了
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)