扫码打开虎嗅APP

搜索历史
删除
完成
全部删除
热搜词
FlashAttention通过优化GPU内存访问而非计算效率,将Transformer训练速度提升3倍、内存消耗降低20倍,首次实现64K长序列的有效处理。 ## 1. Transformer的瓶颈本质 - **问题根源在IO而非计算**:传统注意力机制因频繁读写HBM显存导致性能瓶颈,而非计算能力不足。 - **标准实现缺陷**:三步计算流程(QK^T→Softmax→PV)导致O(N²d)次HBM访问,是输入数据量的N²倍。 ## 2. FlashAttention的核心创新 - **分块计算(Tiling)**:将K/V分块处理,通过逐步合并softmax统计量避免存储N×N矩阵,HBM访问量降至Θ(N²d/M)。 - **重计算(Recomputation)**:反向传播时动态重建中间矩阵,牺牲FLOP换取9倍HBM访问量减少。 - **理论最优性证明**:算法达到精确注意力下HBM访问次数的理论下界。 ## 3. 性能提升与实验结果 - **训练加速**:GPT-2训练比HuggingFace快3.5倍(9.5天→2.7天),BERT-large在MLPerf快15%。 - **内存效率**:内存占用比PyTorch标准实现低20倍,支持64K长序列(Path-256任务准确率63.1%)。 - **硬件普适性**:A100加速2-4倍,RTX 3090加速2.5-4.5倍,T4受限于SRAM大小加速有限。 ## 4. 扩展与模型质量提升 - **稀疏注意力扩展**:块稀疏版本在LRA基准上提速2.8倍,保持精度。 - **长上下文优势**:GPT-2上下文从1K扩至4K时perplexity降低0.7;医疗/法律文本分类任务中,16K序列比512序列提升4.3-8.5分。 ## 5. 行业启示 - **IO-Awareness优先**:算法设计需平衡计算与内存访问,FLOP数≠实际效率。 - **技术复用价值**:经典优化方法(分块/重计算)在深度学习时代仍能产生突破性影响。
2026-05-01 18:05

FlashAttention 论文精读:一个IO 感知的注意力算法,如何改变了大模型的训练速度

本文来自微信公众号: 歪睿老哥 ,作者:歪睿老哥,原文标题:《FlashAttention 论文精读:一个 IO 感知的注意力算法,如何改变了大模型的训练速度》


故事是这样的。


2022夏天,斯坦福的一帮人发了篇论文,标题叫《FlashAttention:Fast and Memory-Efficient Exact Attention with IO-Awareness》。


这名字听着挺无聊对吧,一堆缩写,一个技术名词,跟每天几百篇论文里随便哪篇长得都一样。


但我看了之后,觉得这件事挺有意思的。


因为它解决了一个几乎所有做Transformer的人都踩过坑的问题:大模型推理慢,到底慢在哪里。


很多人第一反应是"算不动",觉得算力不够,换更大的GPU就行。


但这篇文章的作者Tri Dao一帮人说了句不一样的话:


不是算不动,是数据搬不动。


他们管这个叫IO-Awareness——在写注意力算法的时候,不要只盯着FLOP数看,要把GPU显存层级之间的读写次数也算进去。


听起来像废话对吧,但就是这个"废话",让Transformer的训练速度提升了3倍,内存消耗降低了20倍,还让Transformer第一次在64K长度的序列上跑出了好结果。


64K。


之前所有方法,要么跑不了,要么跑出来跟随机猜差不多。


这篇文章,我想把这个论文的核心内容,用人话讲清楚。


1.背景:Transformer的瓶颈到底在哪里


问题:注意力机制的平方复杂度


先看最基本的注意力公式:


O=softmax(QK^T/sqrt(d))V


三个矩阵相乘,再加一个softmax。看起来很简单对吧。


但问题在于,Q和K的矩阵乘法会生成一个N×N的注意力矩阵,其中N是序列长度。


这意味着时间和空间复杂度都是O(N²)。


序列长度翻倍,计算量翻四倍。


这个平方复杂度是Transformer的天然缺陷,从2017年那篇"Attention Is All You Need"出来就没变过。


现有方案的局限


过去几年,一堆人想解决这个问题。


方案主要分两类:


第一类,近似注意力。用稀疏化、低秩分解、核函数近似等等手段,把注意力矩阵从N×N压缩到接近N×1。


这类方法理论上把计算复杂度降到了线性或近线性,但实际跑起来,墙钟时间(wall-clock time)并没有明显加速。


为什么?


因为很多方案只关注减少FLOP,忽略了内存访问的开销。


第二类,稀疏注意力。让每个token只关注有限的其他token,直接剪掉大量注意力连接。


这种方法确实减少了计算量,但稀疏模式本身也有内存访问的overhead,而且效果往往不如dense attention。


作者的判断


这篇文章的作者认为,现有方案没效果的根本原因不是算法不行,而是没有考虑GPU内存层级的IO特性。


现代GPU的内存层级是这样的:



DRAM(系统内存):容量最大(几十GB到几百GB),速度最慢(大约12.8 GB/s)


HBM(高带宽内存):GPU显存,容量中等(40-80 GB),速度中等(大约1.5-2.0 TB/s)


SRAM(片上缓存):容量最小(A100每个SM大约192 KB),速度最快(大约19 TB/s)


从HBM到SRAM,带宽差了一个数量级。


但现代GPU的计算速度已经超过内存速度了。操作越来越被内存访问(IO)而不是计算本身瓶颈住。


所以,关键问题不是FLOP多不多,而是有多少数据在HBM和SRAM之间来回搬。


这就是"IO-Awareness"的核心思想。


2.标准注意力实现的问题


标准算法:三步走


标准的注意力实现,通常是三步:


第一步:计算QK^T


把Q和K从HBM读到SRAM,在芯片上算QK^T,结果写回HBM。


这一步产生了一个N×N的注意力分数矩阵S。


第二步:Softmax


把S从HBM读出来,逐行做softmax,得到P矩阵。P再写回HBM。


第三步:PV


把P和V从HBM读出来,在芯片上算PV,结果写回HBM。


这三步看起来很自然对吧,但每步都在做一件事:把中间结果从HBM写出去,再从HBM读进来。


HBM访问次数的分析


让我们算一下HBM的访问总量。


前向传播:


第一步:读Q、K,写S→O(N²d+N²)次HBM访问


第二步:读S,写P→O(N²)次HBM访问


第三步:读P、V,写O→O(N²d+N²)次HBM访问


前向传播总共:O(N²d+N²)次HBM访问,是序列长度的平方级。


反向传播:


反向传播需要用到前向计算的S和P矩阵来计算梯度,所以同样的,要读S、P写回dQ、dK、dV。


反向传播也大约O(N²d+N²)次HBM访问。


核心矛盾


整个前向+反向传播,HBM访问总量大约是O(N²d)次。


但输入Q、K、V本身的总大小只有O(Nd),输出O也只有O(Nd)。


数据量是O(Nd)的东西,为什么要做O(N²d)的HBM访问?


多出来的O(N²)次访问,全是用在那个大得离谱的N×N注意力矩阵上。


这个矩阵太大,放不下SRAM,只能在HBM和SRAM之间反复搬。


这就是标准实现的根本问题。


3.FlashAttention算法:核心思路


两个关键技术


FlashAttention的思路很简单:用两个经典技术,避免把N×N注意力矩阵写到HBM上。


这两个技术是:


1.Tiling(分块计算)


2.Recomputation(重计算)


Tiling:分块做Softmax


标准softmax需要对整行做归一化,看起来必须把整行读进来才能算。


但math上有个技巧:softmax可以分块计算。


具体来说,如果我把一个向量x拆成两段x¹和x²,那么整个向量x的softmax结果,可以用x¹和x²各自的softmax统计量(最大值m和归一化因子ℓ)来逐步合并。



公式大概是:



这样,每次处理一个块,只需要记录两个小值(m和ℓ),就能把结果正确合并起来。


所以FlashAttention的做法是:


把K、V分成多个块


每次只把一个块加载到SRAM


对Q的每个块,和K的这个块算QK^T


在SRAM里算softmax,更新m和ℓ


逐步累积输出,最后写回HBM


关键:整个过程中,N×N注意力矩阵从来没有完整地出现在HBM上。


Recomputation:反向传播时不再读矩阵


那反向传播怎么办?


反向传播需要用到前向的S和P矩阵。标准做法是把前向的S和P存在HBM上,反向时直接读。


但FlashAttention说:不存了,反向时重新算。


它只存前向的输出O和softmax的统计量m、ℓ,这两个东西很小,O(Nd)的大小。


反向传播时,从HBM读Q、K、V,重新在SRAM里算S和P,然后再算梯度。


虽然多算了一些FLOP,但因为避免了从HBM读N×N矩阵的开销,实际运行时间反而更快。


这在学术上叫selective gradient checkpointing——梯度检查点的一种选择。


4.IO复杂度分析


理论保证


文章给出了一个严格分析。


标准注意力的HBM访问次数是Θ(N²d+N²)。


FlashAttention的HBM访问次数是Θ(N²d/M),其中M是SRAM的大小。


为什么是N²d/M?


因为SRAM能放下大小为Θ(M)的K、V块,每次能处理Θ(M/d)个K行。


对于N行的Q,需要N/(M/d)=Nd/M次扫描。每次扫描加载O(Nd)数据,所以总共O(N²d/M)次HBM访问。


实际差距


拿A100来算:


d=64(head维度)


M≈100KB(每个SM的SRAM)


标准注意力:O(N²×64)次HBM访问


FlashAttention:O(N²×64/100000)≈O(N²×0.00064)次HBM访问


HBM访问量减少了大约100倍。


虽然实际不可能完全达到理论极限(因为SRAM利用率、块大小选择等因素),但文章实验显示前向传播减少了约8倍,反向传播减少了约7倍,合计约9倍的HBM访问量降低。


下界证明


文章还证明了一个有意思的结论:


对于任何精确注意力算法,在所有可能的SRAM大小范围内,不可能渐近地优于O(N²d/M)的HBM访问下界。


换句话说,FlashAttention在这个意义上是最优的。


5.Block-Sparse FlashAttention


扩展:稀疏注意力


FlashAttention不只是精确注意力,还可以扩展到稀疏注意力。


思路很简单:如果注意力矩阵是块稀疏的(比如某些块全是零),那么在Tiling循环中直接跳过这些块就行。


算法跟FlashAttention几乎一样,只是加了一个if判断:如果当前块M_ij=0,跳过计算。


文章证明了Block-Sparse FlashAttention的HBM访问次数是Θ(N²d·s/M),其中s是非零块的比例。


s越小,加速越多。


实验显示,在LRA benchmark上,Block-Sparse FlashAttention相对于标准FlashAttention有2.8倍的加速,同时精度相当。



6.实验结果


训练速度


BERT-large:


在MLPerf 1.1上,FlashAttention比Nvidia记录快了15%(从20.0分钟降到17.4分钟)。


GPT-2 small:


比HuggingFace实现快3.5倍(从9.5天降到2.7天)


比Megatron-LM快2.0倍(从4.7天降到2.7天)


GPT-2 medium:


比HuggingFace实现快3.0倍(从21.0天降到6.9天)


比Megatron-LM快1.7倍(从11.5天降到6.9天)


Long-Range Arena:


平均加速2.4倍


模型质量提升


FlashAttention不只是更快,还能训练出更好的模型。


GPT-2长上下文:


用FlashAttention训练GPT-2 small,上下文长度从1K提升到4K,仍然比Megatron的1K版本快30%,且perplexity低了0.7。


长文档分类:


在MIMIC-III(医疗文本分类)和ECtHR(法律判决分类)上,增加序列长度带来显著提升:


MIMIC-III:16K序列比512序列提升4.3分


ECtHR:8K序列比512序列提升8.5分


PathFinder挑战:


Path-X(16K序列):FlashAttention的Transformer达到61.4%准确率,是第一个在这个任务上超过随机猜测的Transformer


Path-256(64K序列):Block-Sparse FlashAttention达到63.1%准确率


基准测试


在不同序列长度下:


序列长度128-512:FlashAttention比PyTorch标准实现快2-3倍


序列长度1024-2048:FlashAttention比所有近似注意力方法都快


内存占用:FlashAttention比PyTorch标准实现低20倍,比Linformer低2倍


不同硬件上的表现


A100:2-4倍加速


RTX 3090:2.5-4.5倍加速(HBM带宽更低,加速效果更明显)


T4:加速较少(SRAM更小,块大小需要更小)


7.总结


这篇文章的核心贡献,可以用一句话概括:


写注意力算法的时候,要把GPU内存层级的读写开销也算进去。


这个"IO-Awareness"的思想听起来简单,但在深度学习这个领域里,很少有人真正认真对待过。


大家习惯了看FLOP数,看理论复杂度,看benchmark上的accuracy。


但FLOP不等于wall-clock time,不等于内存使用量,不等于实际训练出来的模型质量。


FlashAttention用两个经典技术——Tiling和Recomputation——把注意力机制的HBM访问量从O(N²)降到了O(N²/M),在保持精确计算的同时实现了3倍的加速和20倍的内存节省。


而且它不只是更快,还让Transformer第一次真正具备了建模64K长度上下文的能力。


这就是IO-Awareness的力量。

本内容来源于网络 原文链接,观点仅代表作者本人,不代表虎嗅立场。
如涉及版权问题请联系 hezuo@huxiu.com,我们将及时核实并处理。

大 家 都 在 搜