0. Summary
reformer主要提出的Locality-sensitive hashing attention,根据attention的稀疏和softmax的最大元素支配性质只关心与query最近的K,通过Locality-sensitive hashing实现由query找key,但是受到Q=K的限制。还用到Reversible residual layer降低中间层的内存、以及feed forward层进行Chunking进一步降低显存。
1. Problem Statement
Reformer提出了Transformer中的三个问题。
- problem1: 注意力机制的计算需要\(O(L^2)\)的时间和空间复杂度。
- problem2: transformer的层数较多,而N层模型的内存消耗是单层模型的N倍,因为需要存储每一层中的激活以进行反向传播。链式法则\([g(f(x))]'=g'(f(x))*f'(x)\)。
- problem3: 前馈层的维度通常比注意激活的维度大得多。\(d_{ff}>d_{model}\)。在一些模型中\(d_{ff}=4K\)甚至更多,需要消耗大量显存。
针对以上三个问题分别提出了三个解决方法:Locality-sensitive hashing attention、Reversible residual layer、Chunking。
2. Methods
2.1 Locality-sensitive hashing attention
hasing attention:计算和存储全矩阵\(QK^T\)是没有必要的,因为我们只对\(softmax(QK^T)\)感兴趣,而softmax有最大元素支配的性质。所以对于\(q_i\)我们只关心\(K\)中与之最接近的前几个\(k\)。 \[ \operatorname{Attention}(q_i, K, V)=\operatorname{softmax}\left(\frac{q_i K^{T}}{\sqrt{d_{k}}}\right) V \] Locality-sensitive hashing:期望距离近的向量以较高的概率获得相同的散列。哈希大小为b,生成随机的矩阵\(R^{[d_k,b/2]}\)。散列函数为\(h(x)=argmax([xR;-xR])\)
LSH attention:对于LSH attention,Q=K,只需要计算Q和K矩阵的LSH散列,然后仅计算同一哈希桶中的k和q向量的标准关注度。
LSH attention 流程:
- 按照桶号对查询进行排序,桶内按照序列位置排序。
- hash桶大小不相同,一个桶中的key和query的数量可能不一样,跨桶批处理困难。为了解决分配不均的问题,令\(k=\frac{q}{|q|}\),这样\(h(k)=h(q)\)。论文中对常规的Transformer做了\(K = Q\)的实验,证明不影响效果。\(K = Q\)带来另一个问题就是通常会更注意自身,可以加一个mask屏蔽掉。
- 分块计算:令块的大小\(m=\frac{2l}{b}\),\(l\)是序列长度,\(b\)是桶的数量。在当前块与前一个块的并中计算权重。
2.2 Reversible residual layer
reformer中:\(Y_1=X_1+Attention(X_2);Y_2=X_2+FeedForward(Y_1)\),使用可逆残差层而不是标准残差可以在训练过程中仅将激活存储一次,而不是N次。在反向传播时只使用模型参数就可以从下一层的激活结果中恢复任何给定层的激活结果,从而不用保存中间层的激活结果。
2.3 Chunking
比较厚的层仍会占用大量内存,前馈层的计算在序列中是完全独立的,所以可以分块处理,分chunk分开进行运算。 \[ Y_{2}=\left[Y_{2}^{(1)} ; \ldots ; Y_{2}^{(c)}\right]=\left[X_{2}^{(1)}+\text { FeedForward }\left(Y_{1}^{(1)}\right) ; \ldots ; X_{2}^{(c)}+\text { FeedForward }\left(Y_{1}^{(c)}\right)\right] \]
3. Evaluation
Shared-QK效果:从下图实验结果可以看出共享QK机制并没有比标准注意力机制效果差。
可逆层的效果:这里还是用标准Transformer跟可逆网络层对比,二者所使用的参数基本一样,学习曲线图如上:二者曲线基本一致,这说明可逆网络结构在节省内存的前提下,并没有损伤精度。
LSH attention in Transformer:相比全注意力机制,LSH注意力是一个近似的方法,从下面的实验图可以看出随着hash函数的增加,精确度也越来越高。在nrounds = 8的时候,精确度已经跟全注意力机制相匹敌了;但是hash函数越多,计算代价就越高,所以这个超参数可以根据实际计算资源进行调整。
不同注意力机制的速度:可以看出,随着序列长度的不断增加,标准注意力机制变得越来越慢,而LSH注意力机制基本变化不大,提速效果非常明显。
4. Conclusion
Reformer 针对 Transformer 中的三个问题提出了三个解决方法Locality-sensitive hashing attention、Reversible residual layer、Chunking,在与 Transformer 模型的性能相当的情况下,降低了在长序列任务下的时间与空间复杂度。
5. Notes
Trax:实现了reformer的过程可以学习code
Transformers也对reformer进行了实现code。
Reference
[1] Kitaev N, Kaiser Ł, Levskaya A. Reformer: The efficient transformer[J]. arXiv preprint arXiv:2001.04451, 2020.
[2] Vaswani A, Shazeer N, Parmar N, et al. Attention is all you need[C]//Advances in neural information processing systems. 2017: 5998-6008.
[3] Beltagy I, Peters M E, Cohan A. Longformer: The long-document transformer[J]. arXiv preprint arXiv:2004.05150, 2020.