attention is all you need to learn
[toc]
学习和使用Transformer很久了, 很多地方的理解都比较混乱,因为学习的资料和使用的模块在很多部分已经不同步了。一直想按照时间顺序串一下Transformer的演化过程,就趁着这个春节的时间,一直推到ViT为止吧。
最传统的,最vanilla的注意力
以下公式来自Attention Is All You Need (2017)
1.输入一个序列,包含n个元素,将这n个word通过线性变换, $W_Q, W_K, W_V$得到我们的主角$Q,K,V$.也大概理解了为什么很多的教程以NLP任务举例, 因为最早的注意力提出来就是为了解决NLP任务的, 像是ViT等是后来才出现的。
对于Q中每一个向量$q_i$,计算与K中所有$k_j$的相似性,对应到注意力的基础公式, query与key计算特征相似度, 这一点在feature match中都是非常常见的操作了。
\[\textit{Attention Score}(Q, K) = Q K^T\]但我们能发现,论文中的注意力公式多一个缩放因子$\sqrt{d}$作为分母
\[\textit{Scaled Attention Score}(Q, K) = \frac{Q K^T}{\sqrt{d}}\]MatMul: $QK^T$
Scale: $\frac{1}{\sqrt{d}}$
可以发现, 公式中并没有Mask这一部分,(opt.)也表明它是可选的。因为在标准的注意力计算中, 所有的QKV都是有效、需要使用的,但是根据任务的不同,需要忽略部分输入。
一般来说, 一个网络中的特征维度d是一个固定值, 因此相当于对Q和K的相似度除以了一个固定值, 这个值是随着网络的特征维度而改变的, 即$\sqrt{d}$是维度相关的缩放因子。它的作用是缩小点积的极端值,使得注意力得分在经过softmax之后输出的结果更加平缓。
举例:
假设输入经过线性变换后: \(q_1 = [1, 0, 1, 0] \\ q_2 = [0, 1, 0, 1] \\ k_1 = [1, 0, 1, 0] \\ k_2 = [0, 1, 0, 1]\)
计算$Q \cdot K^T$
\[Q \cdot K^T = \begin{pmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 \\ q_2 \cdot k_1 & q_2 \cdot k_2 \end{pmatrix} = \begin{pmatrix} 2 & 0 \\ 0 & 2 \end{pmatrix}\] \[softmax(Q \cdot K^T)=\begin{pmatrix} 0.88 & 0.12 \\ 0.12 & 0.88 \end{pmatrix}\]如果添加了缩放因子$\sqrt{d}$
\[softmax(\frac{Q \cdot K^T}{\sqrt{d}}) = softmax\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} =\begin{pmatrix} 0.73 & 0.27 \\ 0.27 & 0.73 \end{pmatrix}\]可以发现, scaled attention score相比于不带缩放因子的版本, 相似度分数的差异更小了。那么$\sqrt{d}$是否可以替换为常量呢?当然可以, 但最好还是跟随维度来进行变化, 因为不同的任务中向量的维度可能差异非常大, 如果缩放因子过大或过小, 都不能够达到缩小极端值的效果。
使用$Q,K$计算相似度并输入softmax,借助其非线性特征进行归一化和平滑后, 作为注意力权重, 计算与$V$的加权,至此就是我们所熟悉的注意力公式:
\[\textit{Attention}(Q, K, V)=softmax(\frac{Q \cdot K^T}{\sqrt{d}})V\]其实缩放因子这个东西本身是为点乘注意力(dot-product attention)而服务的, 是为了克服特征维度$d$的影响,有的公式中也会强调$d$是K的维度$d_k$, 但是就目前看到的大部分实现来说, $Q$和$K$的维度$d_q$和$d_k$通常是相同的,因为便于计算。因此缩放这部分是有点偏经验向的东西。具体可以参考原文中的描述:
Transformer为什么没有如RNN中的bias?
问出这个问题,其实是没有理解Transformer的QKV是为了做什么。
在RNN中, 我们使用如下的公式:
\[h_t = \text{Activation}(W_x x + W_h h_{t-1}+b)\]Activation是一个非线性激活函数, 比如tanh或者ReLU. 之所以管$h_t$叫隐藏状态(Hidden State),是因为它是一个中间结果,既不是网络的输入也不是网络的输出,是网络整个黑盒中的一部分。
之所以RNN会有偏置项$b$, 是为了防止零输入导致的输出相同。
MHA(Multi-Head Attention)
Multi-head是如何出现的?
根据上述的注意力公式, 输入
假设输入序列的长度为$n$, 矩阵$Q, K, V$的shape分别为$\mathbb{R}^{n \times d_q}$, $\mathbb{R}^{n \times d_k}$, $\mathbb{R}^{n \times d_v}$
最终得到的加权和维度也是$\mathbb{R}^{n \times d_v}$
举具体例子,
输入:
长度为$n$的序列, 得到shape分别为$\R^{n \times d_q}$, $\R^{n \times d_k}$, $\R^{n \times d_v}$的矩阵$Q, K, V$
(因为Q和K要进行点积操作, 所以$d_q=d_k$) 最终得到的加权和计算得到的加权和维度也是$\R^{n \times d_v}$ 多头注意力是如何实现的? 分别使用$h$个独立的线性变换$W_i(i=1, 2, …, h)$ 将$Q, K, V$进行线性变换。
若有$h$个头, 每个头的维度是$d_h$ \(d_h = \frac{d}{h}\)
对进行线性变换 \(Q' = QW_h\)
\[K' = KW_k\] \[V' = VW_v\] \[(n, d_q)×(d_q, d) \rarr (n, d_q) × (d_q, hd_h) \rarr (n, hd_h)\]之后有一个阶段, 划分头
$d$称为多头注意力的嵌入维度(representation dimensionality)(多头注意力总维度), $h$是头的数量,
\[Q_h = \text{reshape}(Q', (n, h, d_h))\] \[K_h = \text{reshape}(K', (n, h, d_h))\] \[V_h = \text{reshape}(V', (n, h, d_v))\]划分后每个头有一个独立的key, query, value集合, 在单独的子空间中计算。
\[\text{Attention}_i = softmax(\frac{Q^i_h \cdot {K^i_h}^T}{\sqrt{d_h}})V^i_h, i=1,...,h\]$h$个头, 每个头输出的维度是$(n, d_h)$
将$h$个头的输出concat, 得到$(n, hd_h)$维度的输出。拼接后结果经过一个线性变换得到最终输出。
\[O = \text{concat}(O_1, O_2, ..., O_h)\]为什么要使用多头机制?
因为能够映射特征到更多的子空间,使模型从不同的视角学习特征来增强表达能力,这个所谓的不同视角,是”h个头使用独立的线性变换”实现的, 即 h个头在映射QKV时使用的$W_q, W_k, W_v$是不同的, 以保证其计算的梯度不同$\rarr$优化的方向不同$\rarr$各个头映射到的子空间不同
根据论文中的描述:
首先: 将qkv的值通过$h$次不同的学习到的线性变换得到$d_k, d_k, d_v$维的输入, 比直接使用相同的$d_model$维的输入效果要好。在每一次线性投影后的结果上计算注意力
什么是multi-head?
想想到目前为止Transformer做了什么不同的事情, 它将输入通过线性变换$W_Q, W_K, W_V$映射为$Q, K, V$, 本质是使用了一线性变换对特征进行归纳。但既然是线性变换进行归纳, 那么更改映射方式就能得到新的表示, 举个不太恰当的例子, 比如”粉色-汽车”和”汽车-粉色”都能表示”一辆粉色的汽车”, 但是在匹配时可能发生不同的事情。
创建$d_v$组不同的线性变换$W^Q_i, W^K_i, W^V_i$
最后将不同head的输出concat之后接一个$W^O$变换
multi-head attn与multi-head self-attn有什么不同?
手撕Transformer
Croco中的attention
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Attention(nn.Module):
def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.rope = rope
def forward(self, x, xpos):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
q, k, v = [qkv[:,:,i] for i in range(3)]
# q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
if self.rope is not None:
q = self.rope(q, xpos)
k = self.rope(k, xpos)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
Residual Connection
残差连接是Transformer架构中非常常见的技术。将输入结果通过一个跳跃连接(skip connection, 最早由ResNet提出)连接到子层的输出上。
\[Output = Layer(x) + x\]我们之前看过的不论是self-attn还是cross-attn的Transformer都能看到它的身影。
1
2
3
4
5
6
7
8
9
10
11
12
13
# self-attn + MLP
def forward(self, x, xpos):
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
# self-attn + cross-attn + MLP
def forward(self, x, y, xpos, ypos):
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
y_ = self.norm_y(y)
x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
x = x + self.drop_path(self.mlp(self.norm3(x)))
return x, y
在最早的Transformer架构中我们可以发现, Add(也就是Residual Connection)和Norm是一起使用的。主要出现两部分输出之后:多头注意力(Multi-Head Attention, MHA)和前馈神经网络(Feed-Forward Neural Network), 这两个残差连接会和归一化一起使用(Layer Normalization), 使得Transformer能够更深地堆叠多个encoder和decoder, 而不会出现梯度消失或信息丢失。