代码实现—多头自注意力&多头交叉注意力
多头注意力(Multi-Head Attention)是一种基于自注意力机制(self-attention)的改进方法。自注意力是一种能够计算出输入序列中每个位置的权重,因此可以很好地处理序列中长距离依赖关系的问题。但在应用中,可能存在多个不同的关注点,因此就需要多个自注意力机制来处理不同的关注点。多头注意力就是在一个输入序列上使用多个自注意力机制,得到多组注意力结果,然后将这些结果进行拼接和线性
多头注意力和交叉注意力
多头注意力
多头注意力(Multi-Head Attention)是一种基于自注意力机制(self-attention)的改进方法。
自注意力
是一种能够计算出输入序列中每个位置的权重,因此可以很好地处理序列中长距离依赖关系的问题。但在应用中,可能存在多个不同的关注点,因此就需要多个自注意力机制来处理不同的关注点。多头注意力
就是在一个输入序列上使用多个自注意力机制,得到多组注意力结果,然后将这些结果进行拼接和线性投影得到最终输出。
自注意力和交叉注意力
自注意力和交叉注意力的区别就在于
输入
,在计算注意力时我们需要三个矩阵Q、K、V,这三个矩阵是由输入X经过线性变换得到的。对于自注意力
,输入X只有一个,三个矩阵都是由同一个X得到;对于交叉注意力
,输入X一般有两个 x 1 x_1 x1和 x 2 x_2 x2,Q由 x 1 x_1 x1线性变换得到,K、V由 x 2 x_2 x2线性变换得到。
代码实现
在多头注意力机制中计算注意力矩阵时,将输入张量X拆分成h个子张量,每一个子张量,都计算一次子注意力,得到一个输出张量 O i O_i Oi。最后将h个输出张量拼接在一起,得到最终的输出张量O
具体来说,设X(nxd)为输入张量,Q、K、V(dxd)分别为学习到的d维查询、键和值向量。h为头数,则计算如下:
M
u
l
t
i
H
e
a
d
(
X
)
=
C
o
n
c
a
t
(
h
e
a
d
1
,
.
.
.
h
e
a
d
h
)
W
O
MultiHead(X)=Concat(head_1,...head_h)W^O
MultiHead(X)=Concat(head1,...headh)WO
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
S
o
f
t
m
a
x
(
(
Q
K
T
)
d
k
V
Attention(Q,K,V)=Softmax((QK^T) \sqrt{d_k} V
Attention(Q,K,V)=Softmax((QKT)dkV
h
e
a
d
i
=
A
t
t
e
n
t
i
o
n
(
X
Q
i
,
X
K
i
,
X
V
i
)
,
i
=
1
,
.
.
.
,
h
head_i=Attention(XQ_i,XK_i,XV_i),i=1,...,h
headi=Attention(XQi,XKi,XVi),i=1,...,h
Q
i
=
X
W
i
Q
K
i
=
X
W
i
K
V
i
=
X
W
i
V
.
i
=
1
,
.
.
.
h
Q_i=XW_i^Q \quad K_i=XW_i^K \quad V_i=XW_i^V. \quad i=1,...h
Qi=XWiQKi=XWiKVi=XWiV.i=1,...h
多头注意力
import torch
import torch.nn as nn
import torch.nn.functional as F
class MultiHeadAttention(nn.Module):
def __init__(self, in_dim, k_dim, v_dim, num_heads):
super(NultiHeadAttention, self).__init__()
self.num_heads = numheads
self.k_dim = k_dim
self.v_dim = v_dim
self.proj_q = nn.Linear(in_dim, k_dim * num_heads, bias=False)
self.proj_v = nn.Linear(in_dim, v_dim * num_heads, bias=False)
self.proj_k = nn.Linear(in_dim, k_dim * num_heads, bias=False)
self.proj_o = nn.Linear(v_dim * num_heads, in_dim)
def forward(self, x, mask=None):
# 输入x的维度是(batch_size, seq_len, in_dim)
batch_size. seq_len, in_dim = x.size()
# x经过线性变换得到的向量维度是(batch_size, seq_len, k_dim*num_heads)
# q的维度是(bath_size, self.num_heads, seq_len, k_dim)
q = self.proj_q(x).view(batch_size, seq_len, self.num_heads, self.k_dim).permute(0, 2, 1, 3)
# k的维度是(bath_size, self.num_heads, k_dim, seq_len) 这里就相当于对k进行转置
k = self.proj_k(x).view(batch_size, seq_len, self.num_heads, self.k_dim).premute(0, 2, 3, 1)
# v的维度是(bath_size, self.num_heads, seq_len, v_dim)
v = self.proj_v(x).view(batch_size, seq_len, self.num_heads, slef.v_dim).premute(0, 2, 1, 3)
# 计算attention
# attention的维度是(batch_size, self.num_heads, seq_len, seq_len) 每个字对每个字的注意力分数,得到方阵
attention = torch.matmul(q, k) / self.k_dim ** 0.54
if mask is not None:
attention = attention.masked_fill(mask == 0, -1e9)
attention = F.softmax(attention, dim=-1)
# attention和v相乘后的维度是(batch_size,self.num_heads, seq_len, v_dim)
# 这两个矩阵相乘的意义就是 用每个字的这个句子的每个字的注意力分数乘上每个字第一维特征,再求和
# 因为V矩阵就是输入X线性变换得到的 每一行代表一个字的嵌入向量 每个字向量维度是v_dim
# 而注意力矩阵是个方阵 行列都和句子长度保持一致 每一行都是一个字对整个句子其他字的注意力分数
# 这里最后输出的维度是(batch_size, seq_len, num_heads*v_dim)
# 多头注意力得到多组注意力结果,然后将多组结果拼接
output = torch.matmul(attention, v).premute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, -1)
# 再对结果进行线性变化,和使用一个头得到的注意力矩阵维度是一样的
# 这里的维度是(batch_size, seq_len, in_dim) 最后和输入X维度一致
output = self.project_o(output)
return output
交叉注意力
class CrossAttention(nn.Module):
def __init__(self, in_dim1, in_dim2, k_dim, v_dim, num_heads):
super(CrossAttention, self).__init__()
self.num_heads = num_heads
self.k_dim = k_dim
self.v_dim = v_dim
self.proj_q1 = nn.Linear(in_dim1, k_dim * num_heads, bias=False)
self.proj_k2 = nn.Linear(in_dim2, k_dim * num_heads, bias=False)
slef.proj_v2 = nn.Linear(in_dim2, v_dim * num_heads, bias=False)
self.proj_o = nn.Linear(v_dim * num_heads, in_dim1)
def forward(self, x1, x2, mask=None):
batch_size, seq_len1, in_dim1 = x1.size()
seq_len2 = x2.size()
# q1(batch_size, num_heads, seq_len1, k_dim)
q1 = self.proj_q1(x1).view(batch_size, seq_len1, self.num_heads, self.k_dim).premute(0, 2, 1, 3)
# k2(batch_size, num_heads, k_dim, seq_len2)
k2 = self.proj_k2(x2).view(batch_size, seq_len2, self.num_heads, slef.k_dim).premute(0, 2, 3, 1)
# v2(batch_size, num_heads, seq_len2, v_dim)
v2 = self.proj_v2(x2).view(batch_size, seq_len2, self.num_heads, self.v_dim).premute(0, 2, 1, 3)
# attention(batch_size, num_heads, seq_len1, seq_len2)
attention = torch.matmul(q1, k1) / self.k_dim ** 0.5
if mask is not None:
attention = attention.masked_fill(mask == 0, -1e9)
attention = F.softmax(attention, dim=1)
# output(batch_size, num_heads, seq_len1, v_dim)=>(batch_size, seq_len1, num_heads*v_dim)
output = torch.matmul(attention, v2).premute(0, 2, 1, 3).contiguous().view(batch_size, seq_len1, -1)
# output(batch_size, seq_len1, in_dim1)
output = self.proj_o(output)
return output
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)