transfromer-XL论文详解 -- 潘登同学的NLP笔记
Transformer-XL是对Transformer的改进或变种,主要是解决长序列的问题,其中XL表示extra long,在最近流行的XLNet中就是使用Transformer-XL作为基础模块。在下文中,是将Trm-XL放在类似GPT这样的语言模型框架中来介绍,所以理解的时候要放在整个模型中去理解,而不是一个单独的Trm-XL。
Vanilla Transformer
transformer作为一种特征提取器,在NLP中有广泛的应用。但是Trm需要对输入序列设置一个固定的长度,比如在BERT中,默认长度是512。如果文本序列长度短于固定长度,可以通过填充的方式来解决。如果序列长度超过固定长度,处理起来就比较麻烦。一种处理方式,就是将文本划分为多个segments。训练的时候,对每个segment单独处理,segments之间没有联系,如下图(a)所示。这存在两个问题,1)因为segments之间独立训练,所以不同的token之间,最长的依赖关系,就取决于segment的长度;2)出于效率的考虑,在划分segments的时候,不考虑句子的自然边界,而是根据固定的长度来划分序列,导致分割出来的segments在语义上是不完整的。
在预测的时候,会对固定长度的segment做计算,一般取最后一个位置的隐向量作为输出。为了充分利用上下文关系,在每做完一次预测之后,就对整个序列向右移动一个位置,再做一次计算,如上图(b)所示,这导致计算效率非常低。
Segment-Level Recurrence
为了解决上面提到的问题,在Trm的基础上,Trm-XL提出了一个改进,在对当前segment进行处理的时候,缓存并利用上一个segment中所有layer的隐向量序列,而且上一个segment的所有隐向量序列只参与前向计算,不再进行反向传播,这就是所谓的segment-level Recurrence。
Trm本身是可以设置multi-heads,但是在后文中为了简化描述采用单个head。将两个连续的segments表示为
- $S_{\tau}=[x_{\tau,1},x_{\tau,2},\ldots,x_{\tau,L}]$
- $S_{\tau+1}=[x_{\tau+1,1},x_{\tau+1,2},\ldots,x_{\tau+1,L}]$
L是序列长度假设整个模型中,包含N层Trm,那么每个segment中就有N组长度为L的隐向量序列,将第$\tau$个segment的第n层隐向量序列表示为 $h_{\tau}^{n}\in R^{L\times d}$,d是隐向量维度.那么第$\tau+1$个segment的第n层隐向量序列,可以由下面的一组公式计算得出。 $$ \tilde{h}_{\tau+1}^{n-1} = [SG(h_{\tau}^{n-1}),h_{\tau+1}^{n-1}] \qquad (表示对两个向量的拼接,拼接后为2L\times d) \ \qquad \ q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n = h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T \ \qquad \ {h}_{\tau+1}^{n-1} = Transformer\quad Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n) $$ 注意q的计算方式不变,只使用当前segment中的隐向量,计算得到的q序列长度仍然是L。k和v采用拼接之后的$\tilde{h}$来计算,计算出来的序列长度是2L。之后的计算就是标准的Transformer计算。计算出来的第n层隐向量序列长度仍然是L,而不是2L。Trm的输出隐向量序列长度取决于query的序列长度,而不是key和value。
推导一下:
- $Q[L\times d] \cdot K^T[d\times 2L] = [L\times 2L] \cdot V[2L\times d] = [L\times d]$
训练和预测过程如下图所示。这张图上有一个点需要注意,在当前segment中,第n层的每个隐向量的计算,都是利用下一层中包括当前位置在内的,连续前L个长度的隐向量,这是在上面的公式组中没有体现出来的,也是文中没有明说的。每一个位置的隐向量,除了自己的位置,都跟下一层中前(L-1)个位置的token存在依赖关系,而且每往下走一层,依赖关系长度会增加(L-1),如下图中Evaluation phase所示,所以最长的依赖关系长度是N(L-1),N是模型中layer的数量。N通常要比L小很多,比如在BERT中,N=12或者24,L=512,依赖关系长度可以近似为$O(N\times L)$ 。在对长文本进行计算的时候,可以缓存上一个segment的隐向量的结果,不必重复计算,大幅提高计算效率。
上文中,我们只保存了上一个segment,实际操作的时候,可以保存尽可能多的segments,只要内存或者显存放得下。论文中的试验在训练的时候,只缓存一个segment,在预测的时候,会缓存多个segments。
Relative Position Encodeings
在vanilla Trm中,为了表示序列中token的顺序关系,在模型的输入端,对每个token的输入embedding,加一个位置embedding。位置编码embedding或者采用正弦\余弦函数来生成,或者通过学习得到。在Trm-XL中,这种方法行不通,每个segment都添加相同的位置编码,多个segments之间无法区分位置关系。Trm-XL放弃使用绝对位置编码,而是采用相对位置编码,在计算当前位置隐向量的时候,考虑与之依赖token的相对位置关系。具体操作是,在算attention score的时候,只考虑query向量与key向量的相对位置关系,并且将这种相对位置关系,加入到每一层Trm的attention的计算中。
我们对两种方法做个对比。下面一组公式是vanilla Trm计算attention的方式, $E_x$表示token的输入embedding,U是绝对位置编码embedding,两个W分别是query矩阵和key矩阵。下面的公式是对$(E_{x_i}+U_i)W_q^TW_k(E_{x_j}+U_j)$做了分解。 $$ A_{i,j}^{abs} = E_{x_i}^TW_q^TW_KE_{x_j} + E_{x_i}^TW_q^TW_KU_j + U_{i}^TW_q^TW_KE_{x_j} + U_{i}^TW_q^TW_KU_{j} $$
下面一组公式,是Trm-XL计算attention的方式。首先,将绝对位置编码U,替换成了相对位置编码$R_{i-j}$ 。插一句,因为i只利用之前的序列,所以i-j>=0。并且把$W_k$矩阵分为$W_{k,E}和W_{k,R}$,用于分别生成基于内容的key向量和基于位置的key向量, $$ A_{i,j}^{rel} = E_{x_i}^TW_q^TW_{k,E}E_{x_j} + E_{x_i}^TW_q^TW_{k,R}R_{i-j} + U_{i}^TW_q^TW_{k,E}E_{x_j} + U_{j}^TW_q^TW_{k,R}R_{i-j} $$
相对位置关系用一个位置编码矩阵$R\in R^{L_{max}\times d}$ 来表示,第i行表示相对位置间隔为i的位置向量。论文中强调R采用正弦函数生成,而不是通过学习得到的,好处是预测时,可以使用比训练距离更长的位置向量。
最终总结
最后来看一下Trm-XL的完整计算公式,如下所示,只有前3行与vanilla Trm不同,后3行是一样的。第3行公式中,计算A的时候直接采用query向量,而不再使用 表示。最后需要注意的是,每一层在计算attention的时候,都要包含相对位置编码。而在vanilla Trm中,只有在输入embedding中才包含绝对位置编码,在中间层计算的时候,是不包含位置编码的。
$$ \tilde{h}_{\tau+1}^{n-1} = [SG(h_{\tau}^{n-1}),h_{\tau+1}^{n-1}] \ \qquad \ q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n = h_{\tau+1}^{n-1}W_q^T,\tilde{h}_{\tau+1}^{n-1}W_k^T,\tilde{h}_{\tau+1}^{n-1}W_v^T \ \qquad \ {h}_{\tau+1}^{n-1} = Transformer\quad Layer(q_{\tau+1}^n,k_{\tau+1}^n,v_{\tau+1}^n) \ \qquad \ A_{i,j}^{rel} = E_{x_i}^TW_q^TW_{k,E}E_{x_j} + E_{x_i}^TW_q^TW_{k,R}R_{i-j} + U_{i}^TW_q^TW_{k,E}E_{x_j} + U_{j}^TW_q^TW_{k,R}R_{i-j} \ \qquad \ \alpha_{\tau}^n = Masked\quad Softmax(A_{\tau}^n)V_{\tau}^n \ \qquad \ o_{\tau}^n = LayerNorm(Linear(\alpha_{\tau}^n)+{h}_{\tau+1}^{n-1}) \ \qquad \ h_{\tau}^n = Positionwise\quad Feed\quad Forward(o_{\tau}^n) $$
总结,Trm-XL为了解决长序列的问题,对上一个segment做了缓存,可供当前segment使用,但是也带来了位置关系问题,为了解决位置问题,又打了个补丁,引入了相对位置编码。