文章

attention is all you need to learn

[toc]

学习和使用Transformer很久了, 很多地方的理解都比较混乱,因为学习的资料和使用的模块在很多部分已经不同步了。一直想按照时间顺序串一下Transformer的演化过程,就趁着这个春节的时间,一直推到ViT为止吧。

最传统的,最vanilla的注意力

以下公式来自Attention Is All You Need (2017)

1.输入一个序列,包含n个元素,将这n个word通过线性变换, $W_Q, W_K, W_V$得到我们的主角$Q,K,V$.也大概理解了为什么很多的教程以NLP任务举例, 因为最早的注意力提出来就是为了解决NLP任务的, 像是ViT等是后来才出现的。

对于Q中每一个向量$q_i$,计算与K中所有$k_j$的相似性,对应到注意力的基础公式, query与key计算特征相似度, 这一点在feature match中都是非常常见的操作了。

\[\textit{Attention Score}(Q, K) = Q K^T\]

但我们能发现,论文中的注意力公式多一个缩放因子$\sqrt{d}$作为分母

\[\textit{Scaled Attention Score}(Q, K) = \frac{Q K^T}{\sqrt{d}}\]

alt text

MatMul: $QK^T$

Scale: $\frac{1}{\sqrt{d}}$

可以发现, 公式中并没有Mask这一部分,(opt.)也表明它是可选的。因为在标准的注意力计算中, 所有的QKV都是有效、需要使用的,但是根据任务的不同,需要忽略部分输入。

一般来说, 一个网络中的特征维度d是一个固定值, 因此相当于对Q和K的相似度除以了一个固定值, 这个值是随着网络的特征维度而改变的, 即$\sqrt{d}$是维度相关的缩放因子。它的作用是缩小点积的极端值,使得注意力得分在经过softmax之后输出的结果更加平缓。


举例:

假设输入经过线性变换后: \(q_1 = [1, 0, 1, 0] \\ q_2 = [0, 1, 0, 1] \\ k_1 = [1, 0, 1, 0] \\ k_2 = [0, 1, 0, 1]\)

计算$Q \cdot K^T$

\[Q \cdot K^T = \begin{pmatrix} q_1 \cdot k_1 & q_1 \cdot k_2 \\ q_2 \cdot k_1 & q_2 \cdot k_2 \end{pmatrix} = \begin{pmatrix} 2 & 0 \\ 0 & 2 \end{pmatrix}\] \[softmax(Q \cdot K^T)=\begin{pmatrix} 0.88 & 0.12 \\ 0.12 & 0.88 \end{pmatrix}\]

如果添加了缩放因子$\sqrt{d}$

\[softmax(\frac{Q \cdot K^T}{\sqrt{d}}) = softmax\begin{pmatrix} 1 & 0 \\ 0 & 1 \end{pmatrix} =\begin{pmatrix} 0.73 & 0.27 \\ 0.27 & 0.73 \end{pmatrix}\]

可以发现, scaled attention score相比于不带缩放因子的版本, 相似度分数的差异更小了。那么$\sqrt{d}$是否可以替换为常量呢?当然可以, 但最好还是跟随维度来进行变化, 因为不同的任务中向量的维度可能差异非常大, 如果缩放因子过大或过小, 都不能够达到缩小极端值的效果。


使用$Q,K$计算相似度并输入softmax,借助其非线性特征进行归一化和平滑后, 作为注意力权重, 计算与$V$的加权,至此就是我们所熟悉的注意力公式:

\[\textit{Attention}(Q, K, V)=softmax(\frac{Q \cdot K^T}{\sqrt{d}})V\]

其实缩放因子这个东西本身是为点乘注意力(dot-product attention)而服务的, 是为了克服特征维度$d$的影响,有的公式中也会强调$d$是K的维度$d_k$, 但是就目前看到的大部分实现来说, $Q$和$K$的维度$d_q$和$d_k$通常是相同的,因为便于计算。因此缩放这部分是有点偏经验向的东西。具体可以参考原文中的描述:

Transformer为什么没有如RNN中的bias?

问出这个问题,其实是没有理解Transformer的QKV是为了做什么。

在RNN中, 我们使用如下的公式:

\[h_t = \text{Activation}(W_x x + W_h h_{t-1}+b)\]

Activation是一个非线性激活函数, 比如tanh或者ReLU. 之所以管$h_t$叫隐藏状态(Hidden State),是因为它是一个中间结果,既不是网络的输入也不是网络的输出,是网络整个黑盒中的一部分。

之所以RNN会有偏置项$b$, 是为了防止零输入导致的输出相同。

MHA(Multi-Head Attention)

Multi-head是如何出现的?

根据上述的注意力公式, 输入

假设输入序列的长度为$n$, 矩阵$Q, K, V$的shape分别为$\mathbb{R}^{n \times d_q}$, $\mathbb{R}^{n \times d_k}$, $\mathbb{R}^{n \times d_v}$

最终得到的加权和维度也是$\mathbb{R}^{n \times d_v}$

举具体例子,

输入:

长度为$n$的序列, 得到shape分别为$\R^{n \times d_q}$, $\R^{n \times d_k}$, $\R^{n \times d_v}$的矩阵$Q, K, V$

(因为Q和K要进行点积操作, 所以$d_q=d_k$) 最终得到的加权和计算得到的加权和维度也是$\R^{n \times d_v}$ 多头注意力是如何实现的? 分别使用$h$个独立的线性变换$W_i(i=1, 2, …, h)$ 将$Q, K, V$进行线性变换。

若有$h$个头, 每个头的维度是$d_h$ \(d_h = \frac{d}{h}\)

对进行线性变换 \(Q' = QW_h\)

\[K' = KW_k\] \[V' = VW_v\] \[(n, d_q)×(d_q, d) \rarr (n, d_q) × (d_q, hd_h) \rarr (n, hd_h)\]

之后有一个阶段, 划分头

$d$称为多头注意力的嵌入维度(representation dimensionality)(多头注意力总维度), $h$是头的数量,

\[Q_h = \text{reshape}(Q', (n, h, d_h))\] \[K_h = \text{reshape}(K', (n, h, d_h))\] \[V_h = \text{reshape}(V', (n, h, d_v))\]

划分后每个头有一个独立的key, query, value集合, 在单独的子空间中计算。

\[\text{Attention}_i = softmax(\frac{Q^i_h \cdot {K^i_h}^T}{\sqrt{d_h}})V^i_h, i=1,...,h\]

$h$个头, 每个头输出的维度是$(n, d_h)$

将$h$个头的输出concat, 得到$(n, hd_h)$维度的输出。拼接后结果经过一个线性变换得到最终输出。

\[O = \text{concat}(O_1, O_2, ..., O_h)\]

为什么要使用多头机制?

因为能够映射特征到更多的子空间,使模型从不同的视角学习特征来增强表达能力,这个所谓的不同视角,是”h个头使用独立的线性变换”实现的, 即 h个头在映射QKV时使用的$W_q, W_k, W_v$是不同的, 以保证其计算的梯度不同$\rarr$优化的方向不同$\rarr$各个头映射到的子空间不同


根据论文中的描述:

首先: 将qkv的值通过$h$次不同的学习到的线性变换得到$d_k, d_k, d_v$维的输入, 比直接使用相同的$d_model$维的输入效果要好。在每一次线性投影后的结果上计算注意力


什么是multi-head?

想想到目前为止Transformer做了什么不同的事情, 它将输入通过线性变换$W_Q, W_K, W_V$映射为$Q, K, V$, 本质是使用了一线性变换对特征进行归纳。但既然是线性变换进行归纳, 那么更改映射方式就能得到新的表示, 举个不太恰当的例子, 比如”粉色-汽车”和”汽车-粉色”都能表示”一辆粉色的汽车”, 但是在匹配时可能发生不同的事情。

创建$d_v$组不同的线性变换$W^Q_i, W^K_i, W^V_i$

最后将不同head的输出concat之后接一个$W^O$变换

multi-head attn与multi-head self-attn有什么不同?

手撕Transformer

Croco中的attention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class Attention(nn.Module):
    def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope 

    def forward(self, x, xpos):
        B, N, C = x.shape

        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
        q, k, v = [qkv[:,:,i] for i in range(3)]
        # q,k,v = qkv.unbind(2)  # make torchscript happy (cannot use tensor as tuple)
               
        if self.rope is not None:
            q = self.rope(q, xpos)
            k = self.rope(k, xpos)
               
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

Residual Connection

残差连接是Transformer架构中非常常见的技术。将输入结果通过一个跳跃连接(skip connection, 最早由ResNet提出)连接到子层的输出上。

\[Output = Layer(x) + x\]

我们之前看过的不论是self-attn还是cross-attn的Transformer都能看到它的身影。

1
2
3
4
5
6
7
8
9
10
11
12
13
    # self-attn + MLP
    def forward(self, x, xpos):
        x = x + self.drop_path(self.attn(self.norm1(x), xpos))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

    # self-attn + cross-attn + MLP
    def forward(self, x, y, xpos, ypos):
        x = x + self.drop_path(self.attn(self.norm1(x), xpos))
        y_ = self.norm_y(y)
        x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
        x = x + self.drop_path(self.mlp(self.norm3(x)))
        return x, y

在最早的Transformer架构中我们可以发现, Add(也就是Residual Connection)和Norm是一起使用的。主要出现两部分输出之后:多头注意力(Multi-Head Attention, MHA)和前馈神经网络(Feed-Forward Neural Network), 这两个残差连接会和归一化一起使用(Layer Normalization), 使得Transformer能够更深地堆叠多个encoder和decoder, 而不会出现梯度消失或信息丢失。

本文由作者按照 CC BY 4.0 进行授权