Transformer
译:《你只需要注意力》
论文地址:《Attention Is All You Need》
Motivation
RNN、LSTM及一些其它带门的循环神经网络在目前序列模型和转换问题如语言模型和机器翻译等取得了当前最好的结果。而这些循环模型通常是根据前一时间步的隐藏状态\(h_{t-1}\)和当前时间步(或位置)\(t\)推出当前时间步的隐藏状态\(h_t\),这种固有的顺序性就阻止了训练过程的并行化。虽然最近的工作已经通过因式分解技巧和条件计算在计算效率上取得了改善,且后者还会改善模型性能,但序列模型的基础限制依然存在。
为了解决上述问题,论文提出了Transformer,一种完全依赖注意力机制的模型结构,这种模型可以更好的并行运算,而且也可能在更短的时间内得到更好的结果。
方法简介
绝大多数序列转换模型都是编码器-解码器结构,编码器映射一个输入序列\((x_1,\cdots,x_n)\)到一个连续序列\(Z=(z_1,\cdots,z_n)\),给定\(Z\)、解码器,然后生成输出序列\((y_1,\cdots,y_m)\),在每一步,模型都是自回归,且在生成下一步时使用先前生成的符号作为附加输入
编码器和解码器堆叠
如上图:
编码器是由6个独立的层堆叠,每一层有两个子层,首先是多头自注意力机制,然后是一个简单的全连接前馈网络,在两个子层之间使用残差连接,然后进行层标准化(Layer Normalization),也就是说每个子层输出为\(LayerNorm(x+Sublayer(x))\),并且为了促进残差连接,所有子层和嵌入层输出维度为\(d_{model}=512\)
解码器也是由6个独立层堆叠,除了在每层有两个子层外,解码器还插入了第三个子层,用于对编码器的输出进行多头注意力。这里还修改了自注意力子层,防止位置关注后续的位置。这种掩码与输入嵌入偏移一个位置相结合,确保了对位置\(i\)的预测只能依赖于小于\(i\)的位置的已知输出。
注意力
一个注意力函数可以被描述为一个查询(query)和一个键值对(key-value)集合映射的输出。
缩放点积注意力
如图,论文将其采用的特殊注意力称为缩放点积注意力(Scaled Dot-Product Attention),输入由维度为\(d_k\)的\(queries\)和\(keys\)以及维度为\(d_v\)的\(values\)构成,具体计算如下公式: \[ Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V \] 最常用的两种注意力函数是加法注意力和点积注意力,而本文采用点积注意力,因为点积注意力是通过矩阵乘法实现的,在实际使用中更快,空间利用率也更高。
这里,作者怀疑对于较大的\(d_k\)值,点积的大小会增大,从而将Softmax函数推入其梯度极小的区域。为了抵消这种影响,将点积的结果缩放为\(\frac{1}{\sqrt{d_k}}\)
多头注意力
像图中右边一样,使用多种不同的学习线性投影将\(Q,K,V\)分别线性投影到\(d_k,d_k,d_v\)维度\(h\)次是有好处的,这种情况也就是多头注意力机制,其中\(h\)表示“头”的数量,计算表示如下: \[ MultiHead(Q,K,V)=Concat(head_1,\cdots,head_h)W^O \\ where\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) \] 其中,\(W_i^Q=\mathbb R^{d_{model}\times d_k},W_i^K=\mathbb R^{d_{model}\times d_k},W_i^V=\mathbb R^{d_{model}\times d_v},W_i^O=\mathbb R^{hd_v\times d_{model}}\)表示投影参数矩阵
在这篇论文工作中,作者采用\(h=8\)的多头注意力,且\(d_k=d_v=d_{model}/h=64\)
注意力在本模型中的应用
Transformer通过三种不同的方式使用多头注意力:
- 在编码器-解码器注意力层,\(queries\)来自上一解码器层,\(keys\)和\(values\)来自编码器的输出。
- 编码器包含自注意力层,\(Q,K,V\)均来自解码器上一层的输出
- 解码器中的自注意力层
基于位置的前馈神经网络
和全连接层差不多,由两个线性变换和ReLU激活组成,公式为: \[ FFN(x) = \max(0, xW_1+b_1)W_2+b_2 \]
Embedding和Softmax
和其它序列转换模型类似,论文学习embeddings来将输入tokens和输出tokens转换成\(d_{model}\)维向量,还使用常规线性变换和softmax函数来转换解码器输出用于预测下一个token的概率
位置编码
由于Transformer中没有包含RNN和CNN,为了使用序列的顺序信息,需要注入序列中tokens的位置相关信息,因此在编码器和解码器堆的底部向input embedding添加了位置编码,位置编码有很多种选择,论文采用的是正余弦函数: \[ PE_{(pos, 2i)}=\sin(pos/10000^{2i/d_{model}}) \\ PE_{(pos, 2i+1)}=\cos(pos/10000^{2i/d_{model}}) \] 其中,\(pos\)表示位置,\(i\)表示维度,也就是说,位置编码的每个维度对应一个正弦。
作者也实验过其它方式的位置编码,但结果差不多,而选择正弦版本的原因则是认为它可以允许模型外推到比训练期间遇到的序列长度更长的序列长度。