ALiBi位置嵌入、稀疏注意力、FlashAttention、多查询注意力、条件计算和80G A100 GPU。
要点和技巧
要点
- Self-Attention层的时间复杂度是O(n²d)或O(nd²),其中:
- Q, K, V:(n, k) x (k, d),假定k=d有复杂度为O(nd²)
- QK^T:(n, d)和(d, n)运算,得到(n,n)矩阵,复杂度为O(n²d)
- softmax计算:对每行做softmax,复杂度为O(n),则n行的复杂度为O(n²)
- 加权和:(n, n) 与 (n, d)运算,得到(n, d)矩阵,复杂度为O(n²d)
- Transformer结构中,可学习的矩阵权重的形状与输入token的数量n无关。因此,经过训练的2K context长度的Transformer可以处理100K长度的输入。但是如果没有在100K上训练,推理期间无法在100K token上产生有意义的结果。
- 由于Self-Attention的O(n²d)复杂度,在巨大的语料库上训练vanilla Transformer是不可行的。根据估算,2K长度的LLaMA训练成本为3M美元,因此100K长度的LLaMA将花费150M美元。
- 一种选择是在2K token上训练模型,然后在更长的上下文(如65K)上微调模型。但由于Positional Embedding是正弦编码,他不能和原始的Transformer一起工作。
技巧
- 解决这个问题的首要技巧是删除正弦编码并使用Alibi,这种方式不会影响精度,并且允许模型在2K上训练,在100K上微调。
- 不需要计算所有token之间的attention分数。有些token比其他token更重要,因此可以使用稀疏注意力来加快训练和推理的速度。
- 在GPU上可以使用Flash Attention实现高效的注意力。它通过平铺来避免不适合GPU SRAM的中间矩阵计算,这有助于加快训练和推理的速度。
- 使用Multi-Query Attention替代Multi-Head Attention。这意味着在线性投影K和V时,可以在所有head上共享权重,这显著加快了增量推理的速度。
- 使用条件计算避免将所有模型参数应用于输入序列中的所有token。CoLT5仅将繁重的计算应用于最重要的token,而使用轻量版本的层处理其他token。这能够加快训练和推理的速度。
- 为了适应更大的Context,需要大量的GPU RAM,因此使用80G A100 GPU。
综上,训练和推理的速度越快,可以使用的上下文长度就越大。
原始Transformer和Context Length
在Transformer结构中,可学习矩阵权重的形状不取决于输入token的数量n。训练长context Transformer的解决方案是分两阶段训练它:
- 在2K token上下文长度上训练基本模型
- 在更长的Context (65K或100K)上继续训练(微调)
回顾Multi-head Attention
符号定义
- Q, K, V:与论文中信息检索有关的符号,将Query插入系统并搜索最接近的Key
- n:输入的token数
- d:文本embedding维度
- h:attention head的数量
- k:Q和K的线性投影尺寸
- v:V的线性投影尺寸
Multi-head Attention
- 首先是look-up Embedding层,对于一个给定的token,返回一个大小为(1,d)的向量。因此对于n个token的序列,可以得到(n,d)大小的文本Embedding矩阵X。随后我们给它加上位置正弦嵌入。
- Multi-head Attention层的目的是计算该token序列的新嵌入,但(1)按照token之间相对于上下文的重要性加权,(2)按token的相对位置加权。
- 使用h个attention head并行地处理这个(n,d)矩阵X。要获得所有attention head的Q,K和V,需要将X线性投影到k, k和v维度。这一步可以通过将X乘以h个形状为(d, k), (d, k)和(d, v)的矩阵来实现。
- Attention head返回h个大小为(n,v)的注意力分数矩阵。然后,将所有head中的片段串联起来 - 即(n, h*v) - 并对其进行线性投影用于下一步。
Scaled Dot-product Attention
- Q, K, V时大小为(n, k), (n, k)和(n, v)的3个线性投影,通过乘以每个头部单独的可学习权重获得。
- 通过计算Q和K转置之间的距离(点积)来获得注意力分数。将(n,k)乘以(k,n)得到(n,n)。然后,将其乘以mask矩阵,将一些token归零(例如在decoding阶段)。然后对其应用softmax缩放到0至1。这样就得到了形状为(n, n)的矩阵n_ij,表示第i个和第j个令牌之间的相对注意力分数。该分数显示了这些令牌在长度为n的特定上下文中的“接近”程度。
- 然后,将这个Attention分数矩阵(n,n)乘以大小为(n,d)的值V,以获得由这些相对注意力分数加权的文本嵌入。
Transformer的复杂度和Context长度
矩阵乘法中(a,b) * (b,c)的复杂度为O(a * b * c)。为了简化计算,假定k * h = O(d)。
注意力层的复杂度由两部分组成
- Q, K, V的线性投影:大小为(n,d)的embedding矩阵乘以k个可学习矩阵(d, k)、(d, k)和(d, v)。因此,复杂度~O(nd²);
- Q和K的乘法转换,然后乘以V:其中(n, k) * (k, n) = (n, n)和(n, n) * (n, v) = (n, v)。复杂度~O(n²d)
因此,有
- 当d>n时,项O(nd²)更重要(例如在LLaMA中,n=2K, d=4K)
- 当n>d时,项O(n²d)更重要(例如用n=65K和d=4K训练MosaicML)
增加上下文长度的优化技术
Trick 1:更好的位置编码ALiBi 【加速训练】
两阶段训练方案不能用在传统的Transformer架构是因为位置正弦编码没有“外推”能力。带线性偏差的注意力(Attention with Linear Biases, AliBi)位置嵌入应用于注意力头(而不是网络底部),它以与其距离成正比的惩罚来偏置query-key注意力分数(在softmax之前)。
这个技巧加速了训练。
Trick 2:稀疏Attention 【加速训练/推理】
不是全部的100K个上下文中的所有token都相互有关。减少计算次数的一种方法是在计算注意力分数时只考虑一些token。添加稀疏性的目的是使得计算线性于n,而不是二次于n。有一些方法来解决如何选择token之间的连接:
滑动窗口Attention在每个token周围使用固定大小的窗口注意力。在这一模式中,给定一个固定的窗口大小w,每个token在每一侧都会处理w/2个token。这种方式的复杂度为O(n*w),它与输入序列长度n呈线性缩放。为了使其高效,w应该相比于n更小。
BigBird则结合了全局、局部和随机机制。
Trick 3:FlashAttention 【加速训练/推理】
Attention层中有一些计算是被多次重复的:
- S = Q * K
- P = softmax(S)
- O = P * V
FlashAttention实现了注意力层算法以利用GPU内存,并计算精确的注意力。
当GPU进行操作时,输入数据必须存在于名为SRAM的“快速”存储器中。数据从“慢”的HBM存储器复制到SRAM,在计算结束后返回HBM。SRAM内存比HBM快得多,但是大小要小得多(40GB A100 GPU中的20MB)。
因此,访问HBM是一项昂贵的操作。GPU内存利用率——注意力层中的主要问题是P、S、O等中间乘法结果的规模很大(n,n)。需要把他们保存在HBM,并且在Attention操作期间再次读取它们。将P、S、O在HBM和SRAM之间来回移动是瓶颈,因此FlashAttention的主要**是将Q、K、V矩阵拆分成块,将这些块从HBM加载到SRAM,然后计算这些块的Attention输出。这一过程称为平铺(Tiling)。
这里的Matrix Multiplication操作针对GPU进行优化。FlashAttention将几种乘法和softmax与平铺操作进行融合,并优化了对HBM的访问。Pytorch 2.0版本内置了FlashAttention。
Trick 4:Multi-Query注意力 【加速推理】
原始的Multi-Head Attention中,每个Head都有一个单独的线性层,用于K和V矩阵。在推理过程中,解码器先前token的key和value被缓存以防止重新计算他们,因此GPU内存的使用量随着token的生成而增长。
多查询注意力(Multi-Query Attention, MQA)是一种优化,建议在线性投影K和V时在所有Attention Head共享权重,因此只需要保留两个大小为(n, k)和(n, v)的矩阵。一个大模型可以有多达96个头(如GPT-3),这意味着使用MQA可以节省96倍的key/value解码器缓存消耗。
这种优化在生成长文本时特别有益。其优点是显著加快了增量注意力score的计算,而训练的速度基本保持不变。
Trick 5:条件计算 【加速训练/推理】
当d>n时,模型的瓶颈不是注意力层,而是FFN和projection层。减少FLOPs(是floating point operations的缩写,意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度)的常见方法是采用某种形式的条件计算,避免将所有模型参数应用于输入序列中的所有token。
在Trick 2中,有些token比其他token更重要。基于类似的原因,在CoLT5论文中,作者将所有的FFN和注意力计算分为两个分支:重和轻。轻层应用于所有的token,而重层只应用于重要token。
“轻前馈分支和重前馈分支仅在hidden维度上不同,轻分支的hidden维度比标准T5前馈层小,而重分支的hidden维度更大。”
对于高达64K输入token的超长序列,这种方法被证明优于现有的LongT5模型的速度和准确性。
Reference
[1] The Secret Sauce behind 100K context window in LLMs: all tricks in one place https://blog.gopenai.com/how-to-speed-up-llms-and-use-100k-context-window-all-tricks-in-one-place-ffd40577b4c
[2] Sliding Window Attention https://paperswithcode.com/method/sliding-window-attention
[3] FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness https://arxiv.org/abs/2205.14135
[4] Fast Transformer Decoding: One Write-Head is All You Need https://arxiv.org/abs/1911.02150
[5] Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation https://arxiv.org/abs/2108.12409
[6] Big Bird: Transformers for Longer Sequences https://arxiv.org/abs/2007.14062
[7] PaLM: Scaling Language Modeling with Pathways https://arxiv.org/abs/2204.02311
[8] CoLT5: Faster Long-Range Transformers with Conditional Computation https://arxiv.org/abs/2303.09752