FlashAttention 原理之伪代码解释
A t t n ( Q , K , V ) = s o f t m a x ( Q K T ) ⏟ 大小为 N × N × V , \mathrm{Attn}(Q, K, V) \;=\; \underbrace{\mathrm{softmax}\bigl(Q K^T\bigr)}_{\text{大小为 }N\times N} \;\times\; V, Attn(Q,K,V)=大小为 N×N softmax(QKT)×V,
其中 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N\times d} Q,K,V∈RN×d, N N N 常常很大,直接构造和存储 s o f t m a x ( Q K T ) \mathrm{softmax}(QK^T) softmax(QKT) 需要 O ( N 2 ) O(N^2) O(N2) 的空间,并且读写代价很高。
FlashAttention 通过将 Q, K, V 分别切成小的块(Block)并在片上(SRAM)完成大部分计算,从而避免在外部大内存(HBM)中存储或频繁读写大规模中间矩阵。算法中用到了按行块切的 query(大小为 B r × d B_r \times d Br×d)和按列块切的 key/value(大小为 B c × d B_c \times d Bc×d),然后在两重循环中分块累加结果。
输入:
- 矩阵 Q , K , V ∈ R N × d Q,K,V\in \mathbb{R}^{N\times d} Q,K,V∈RN×d 存放于外部大内存(HBM)。
- 片上(on-chip)有大小为 M M M 的高速缓存(SRAM)。
- 还会在HBM/片上分别为中间变量分配空间,如输出 O ∈ R N × d O\in \mathbb{R}^{N\times d} O∈RN×d、部分归一化因子 ℓ ∈ R N \ell\in \mathbb{R}^{N} ℓ∈RN、以及部分最大值 m ∈ R N m\in \mathbb{R}^{N} m∈RN(用于数值稳定)。
切分:
- 将 Q 按“行方向”分块,得到 T r = ⌈ N B r ⌉ T_r = \bigl\lceil \tfrac{N}{B_r}\bigr\rceil Tr=⌈BrN⌉ 个子块 Q 1 , … , Q T r Q_1,\dots,Q_{T_r} Q1,…,QTr,每块大小约为 B r × d B_r\times d Br×d。
- 将 K, V 按“列方向”分块,得到 T c = ⌈ N B c ⌉ T_c = \bigl\lceil \tfrac{N}{B_c}\bigr\rceil Tc=⌈BcN⌉ 个子块 ( K 1 , V 1 ) , … , ( K T c , V T c ) (K_1,V_1),\dots,(K_{T_c},V_{T_c}) (K1,V1),…,(KTc,VTc),每块大小约为 B c × d B_c\times d Bc×d。
- 相应地,也把输出 O、部分归一化因子 ℓ \ell ℓ、和部分最大值 m m m 按同样的行块大小 B r B_r Br 切分成 O 1 , … , O T r , ℓ 1 , … , ℓ T r , m 1 , … , m T r O_1,\dots,O_{T_r}, \ell_1,\dots,\ell_{T_r}, m_1,\dots,m_{T_r} O1,…,OTr,ℓ1,…,ℓTr,m1,…,mTr。
双层循环:
- 外层循环 (第 5 行):遍历 Key/Value 的列块索引 j = 1 … T c j = 1 \dots T_c j=1…Tc。
- 将第 j j j 个块 ( K j , V j ) (K_j, V_j) (Kj,Vj) 从外部内存(HBM)加载到片上SRAM。
- 内层循环 (第 7 行):遍历 Query 的行块索引 i = 1 … T r i = 1 \dots T_r i=1…Tr。
- 将 Q i , O i , ℓ i , m i Q_i, O_i, \ell_i, m_i Qi,Oi,ℓi,mi 从 HBM 加载到片上SRAM(它们都是大小为 B r B_r Br 或 B r × d B_r\times d Br×d 级别)。
- 计算局部注意力得分 S i j = Q i K j T S_{ij} = Q_i K_j^T Sij=QiKjT (大小为 ( B r × d ) × ( B c × d ) → B r × B c (B_r \times d) \times (B_c \times d) → B_r \times B_c (Br×d)×(Bc×d)→Br×Bc)。
- 做数值稳定的行级最大值 m ~ i j = r o w m a x ( S i j ) \tilde{m}{ij} = \mathrm{rowmax}(S{ij}) m~ij=rowmax(Sij),再计算 P ~ i j = exp ( S i j − m ~ i j ) \tilde{P}{ij} = \exp\bigl(S{ij} - \tilde{m}_{ij}\bigr) P~ij=exp(Sij−m~ij)。
- 行和 ℓ i j = r o w s u m ( P ~ i j ) \ell_{ij} = \mathrm{rowsum}\bigl(\tilde{P}_{ij}\bigr) ℓij=rowsum(P~ij)。
- 用 m ~ i j \tilde{m}{ij} m~ij 和 ℓ i j \ell{ij} ℓij 来更新行块级的最大值 m i n e w m_i^{new} minew 和部分归一化因子 ℓ i n e w \ell_i^{new} ℓinew,并将新的输出块 O i O_i Oi 与新块 P ~ i j V j \tilde{P}_{ij}V_j P~ijVj 合并。
计算 O i O_i Oi的结果
m
i
n
e
w
=
max
(
m
i
,
m
~
i
j
)
,
ℓ
i
n
e
w
=
e
m
i
−
m
i
n
e
w
ℓ
i
+
e
m
~
i
j
−
m
i
n
e
w
ℓ
i
j
.
m_i^{new} = \max\bigl(m_i,\;\tilde{m}_{ij}\bigr), \quad \ell_i^{new} = e^{\,m_i - m_i^{new}}\;\ell_i \;+\; e^{\,\tilde{m}_{ij} - m_i^{new}}\;\ell_{ij}.
minew=max(mi,m~ij),ℓinew=emi−minewℓi+em~ij−minewℓij.
若
m
~
i
j
\tilde{m}_{ij}
m~ij 比原先
m
i
m_i
mi 大,就把整条“数值刻度”往上挪到
m
~
i
j
\tilde{m}{ij}
m~ij;否则保留原先的刻度
m
i
m_i
mi。
ℓ
i
\ell_i
ℓi 和
ℓ
i
j
\ell_{ij}
ℓij 都要根据新的刻度
m
i
n
e
w
m_i^{new}
minew 进行相应的指数缩放,然后相加得到新的部分归一化因子。
现在我们得到了新的刻度 m i n e w m_i^{new} minew 和新的部分归一化因子 ℓ i n e w \ell_i^{new} ℓinew。剩下的,就是把老的输出 O i O_i Oi 和当前块新增的输出 P ~ i j V j \tilde{P}_{ij}V_j P~ijVj 合在一起,放到一个统一的“刻度”下。第 12 行的写法大致如下(作一点数学化展开):
O i ( n e w ) = [ d i a g ( ℓ i n e w ) ] − 1 ⏟ 相当于每行除以 ℓ i n e w ( d i a g ( ℓ i ) e m i − m i n e w O i ⏟ "重刻度"后的旧输出 + e m ~ i j − m i n e w P ~ i j V j ⏟ “重刻度”后的新块输出 ) . \begin{aligned} O_i^{(\mathrm{new})} &= \underbrace{\bigl[\mathrm{diag}(\ell_i^{new})\bigr]^{-1}}{ \text{相当于每行除以 }\ell_i^{new} } \;\Bigl( \underbrace{ \mathrm{diag}(\ell_i)\;e^{\,m_i - m_i^{new}} \;O_i }{ \text{"重刻度"后的旧输出} } \;+\; \underbrace{ e^{\,\tilde{m}{ij} - m_i^{new}} \;\tilde{P}{ij}\;V_j }_{ \text{“重刻度”后的新块输出} } \Bigr). \end{aligned} Oi(new)= [diag(ℓinew)]−1相当于每行除以 ℓinew( diag(ℓi)emi−minewOi"重刻度"后的旧输出+“重刻度”后的新块输出 em~ij−minewP~ijVj).
旧输出 O i O_i Oi 的重刻度:
- 先将 O i O_i Oi 乘上原先“局部归一化”用到的因子 ℓ i \ell_i ℓi,并且再乘上 e m i − m i n e w e^{m_i - m_i^{new}} emi−minew,把它转移到新的刻度 ℓ i n e w \ell_i^{new} ℓinew 所在的坐标系。
- 新块的输出
P
~
i
j
V
j
\tilde{P}_{ij}V_j
P~ijVj:
这是在本次块的局部最大值 m ~ i j \tilde{m}{ij} m~ij 上做的 exponent,所以要乘上 e m ~ i j − m i n e w e^{\tilde{m}{ij} - m_i^{new}} em~ij−minew 去对齐到新的刻度。 - 再除以新的
ℓ
i
n
e
w
\ell_i^{new}
ℓinew:
由于我们想让最终的 O i O_i Oi 处在 “输出 = 1 (部分Softmax分母) × ( 加权和 ) ” “输出 = \frac{1}{\text{(部分Softmax分母)}}\times(加权和)” “输出=(部分Softmax分母)1×(加权和)” 的形式,所以最后整体还要除以 ℓ i n e w \ell_i^{new} ℓinew。
代码里用 [ d i a g ( ℓ i n e w ) ] − 1 \bigl[\mathrm{diag}(\ell_i^{new})\bigr]^{-1} [diag(ℓinew)]−1 的写法,是因为在实现中 O i O_i Oi 维度是 B r × d B_r \times d Br×d,而 ℓ i n e w \ell_i^{new} ℓinew 是长度为 B r B_r Br 的向量,需要对每个行分别做除法。
最终把这个新的 O i ( n e w ) O_i^{(\mathrm{new})} Oi(new) 写回HBM,并且更新 ℓ i ← ℓ i n e w \ell_i \leftarrow \ell_i^{new} ℓi←ℓinew、 m i ← m i n e w m_i \leftarrow m_i^{new} mi←minew(见第 13 行),就完成了对第 j j j 块 Key/Value 的处理。
小结: 第 12 行就是做了一次“加权合并”,把“老的注意力累加输出”和“新的块注意力输出”整合到同一个数值基准,并更新归一化,使得当所有块都处理完后,O_i 就等价于 s o f t m a x ( Q i K T ) V \mathrm{softmax}(Q_iK^T)V softmax(QiKT)V 中“i 行块”的那部分结果。
FlashAttention公式推理:https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad