FlashAttention 原理之 softmax 分块计算
标准 Softmax
Softmax 函数(也称为归一化指数函数)是一个将向量转换成概率分布的函数。对于输入向量 x,softmax 函数将其转换为一个概率分布向量,其中每个元素的值在 (0,1) 之间,且所有元素之和为 1。
s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(xi)=∑jexjexi
归一化 Softmax
m ( x ) : = max i x i m(x) := \max_i x_i m(x):=imaxxi
这里, x = [ x 1 , x 2 , … , x B ] x = [x_1, x_2, \dots, x_B] x=[x1,x2,…,xB] 表示一个有 B 个分量的向量(例如,模型输出的对各类的打分)。 m ( x ) m(x) m(x) 则是向量 x x x 中所有分量的最大值。
f
(
x
)
:
=
[
e
x
1
−
m
(
x
)
,
e
x
2
−
m
(
x
)
,
…
,
e
x
B
−
m
(
x
)
]
f(x) := \bigl[e^{\,x_1 - m(x)},\, e^{\,x_2 - m(x)},\,\dots,\, e^{\,x_B - m(x)}\bigr]
f(x):=[ex1−m(x),ex2−m(x),…,exB−m(x)]
这里做了一个“减最大值”的操作,即把每个
x
i
x_i
xi 都减去整个向量的最大分量
m
(
x
)
m(x)
m(x),然后取指数。这样做的好处是数值更稳定:当
x
i
x_i
xi 很大时,直接算
e
x
i
e^{x_i}
exi 容易导致溢出;但减去最大值以后,指数部分变为
x
i
−
m
(
x
)
x_i - m(x)
xi−m(x)(一个相对较小的或非正的数),从而避免数值爆炸。
ℓ
(
x
)
:
=
∑
i
f
(
x
)
i
\ell(x) := \sum_i f(x)_i
ℓ(x):=i∑f(x)i
这里把向量
f
(
x
)
f(x)
f(x) 的各个分量加起来得到标量
ℓ
(
x
)
\ell(x)
ℓ(x)。
softmax
(
x
)
:
=
f
(
x
)
ℓ
(
x
)
\text{softmax}(x) := \frac{f(x)}{\ell(x)}
softmax(x):=ℓ(x)f(x)
把每个分量
f
(
x
)
i
f(x)_i
f(x)i 除以
ℓ
(
x
)
\ell(x)
ℓ(x) 后,就得到了标准的 Softmax 输出向量。
softmax ( x ) i = e x i − m ( x ) ∑ j e x j − m ( x ) . \text{softmax}(x)_i = \frac{e^{\,x_i - m(x)}}{\sum_j e^{\,x_j - m(x)}} \,. softmax(x)i=∑jexj−m(x)exi−m(x).
由于每一项都经过指数函数且被总和归一化,它满足所有分量都非负且所有分量之和为 1,因此是一个有效的概率分布。
分块 Softmax
假设我们有两个同维度向量
x
(
1
)
和
x
(
2
)
∈
R
B
\mathbf{x}^{(1)} 和 \mathbf{x}^{(2)} \in \mathbb{R}^B
x(1)和x(2)∈RB,把它们拼接(concatenate)成
x
=
[
x
(
1
)
,
x
(
2
)
]
∈
R
2
B
.
\mathbf{x} = \bigl[\mathbf{x}^{(1)},\,\mathbf{x}^{(2)}\bigr] \in \mathbb{R}^{2B}.
x=[x(1),x(2)]∈R2B.
下面的公式说明,如何在不重复完整计算的情况下,用“各自的部分计算结果”组合成拼接后向量的 Softmax。先给出它的步骤,再解释其意义和好处:
最大值 m(x) 的分块计算
定义单个向量的最大值
对于
x
(
1
)
∈
R
B
\mathbf{x}^{(1)}\in \mathbb{R}^B
x(1)∈RB,我们先定义
m
(
x
(
1
)
)
=
max
i
(
x
i
(
1
)
)
,
m
(
x
(
2
)
)
=
max
i
(
x
i
(
2
)
)
.
m\bigl(\mathbf{x}^{(1)}\bigr) \;=\; \max_i \Bigl(\mathbf{x}^{(1)}_i\Bigr), \quad m\bigl(\mathbf{x}^{(2)}\bigr) \;=\; \max_i \Bigl(\mathbf{x}^{(2)}_i\Bigr).
m(x(1))=imax(xi(1)),m(x(2))=imax(xi(2)).
定义拼接向量的最大值
由于
x
\mathbf{x}
x 是把
x
(
1
)
\mathbf{x}^{(1)}
x(1) 和
x
(
2
)
\mathbf{x}^{(2)}
x(2) 拼到一起,那么
m
(
x
)
=
m
(
[
x
(
1
)
,
x
(
2
)
]
)
=
max
(
m
(
x
(
1
)
)
,
m
(
x
(
2
)
)
)
.
m(\mathbf{x}) \;=\; m\bigl([\mathbf{x}^{(1)},\, \mathbf{x}^{(2)}]\bigr) \;=\; \max\bigl(m(\mathbf{x}^{(1)}), \;m(\mathbf{x}^{(2)})\bigr).
m(x)=m([x(1),x(2)])=max(m(x(1)),m(x(2))).
这样就不需要对拼接后的 x \mathbf{x} x 再扫描一次去找最大值,而是只要比较两个子向量各自的最大值即可。
“指数向量” f(x) 的分块计算
Recall
f
(
x
)
=
[
e
x
1
−
m
(
x
)
,
e
x
2
−
m
(
x
)
,
…
,
e
x
2
B
−
m
(
x
)
]
.
f(\mathbf{x}) \;=\; \Bigl[e^{\,x_1 - m(\mathbf{x})},\; e^{\,x_2 - m(\mathbf{x})},\;\dots,\; e^{\,x_{2B} - m(\mathbf{x})}\Bigr].
f(x)=[ex1−m(x),ex2−m(x),…,ex2B−m(x)].
由于
x
\mathbf{x}
x 拆成了两块
x
(
1
)
\mathbf{x}^{(1)}
x(1)、
x
(
2
)
\mathbf{x}^{(2)}
x(2),我们分别对每块计算其对应的“指数向量”:
f
(
x
(
1
)
)
和
f
(
x
(
2
)
)
.
f\bigl(\mathbf{x}^{(1)}\bigr) \quad\text{和}\quad f\bigl(\mathbf{x}^{(2)}\bigr).
f(x(1))和f(x(2)).
然后拼起来即可。但要记住,每一块真正要减去的“中心化值”是整个
x
\mathbf{x}
x 的最大值
m
(
x
)
m(\mathbf{x})
m(x)
,因此它们之间会出现一个额外的“补偿系数”:
f ( x ) = [ e m ( x ( 1 ) ) − m ( x ) ⋅ f ( x ( 1 ) ) , e m ( x ( 2 ) ) − m ( x ) ⋅ f ( x ( 2 ) ) ] . f(\mathbf{x}) \;=\; \Bigl[ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\cdot f\bigl(\mathbf{x}^{(1)}\bigr), \;\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\cdot f\bigl(\mathbf{x}^{(2)}\bigr) \Bigr]. f(x)=[em(x(1))−m(x)⋅f(x(1)),em(x(2))−m(x)⋅f(x(2))].
直观上看,如果某一块(比如 x ( 1 ) \mathbf{x}^{(1)} x(1))的最大元素是整个拼接向量的最大元素,那么它带来的指数因子就会是 e m ( x ( 1 ) ) − m ( x ) = e 0 = 1 e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})} = e^0 = 1 em(x(1))−m(x)=e0=1。而另一块若不是最大的,就会额外乘上一个小于 1 的因子。
归一化项 ℓ ( x ) \ell(x) ℓ(x) 的分块计算
Softmax 要把向量的指数项归一化到和为 1,所以我们需要计算
ℓ
(
x
)
=
∑
i
=
1
2
B
e
x
i
−
m
(
x
)
.
\ell(\mathbf{x}) \;=\; \sum_{i=1}^{2B} e^{\,x_i - m(\mathbf{x})}.
ℓ(x)=i=1∑2Bexi−m(x).
利用分块思想,可以分成两段求和,再用与上一步相同的补偿系数连接起来:
ℓ ( x ) = e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) ⏟ 第1块贡献 + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) ⏟ 第2块贡献 , \ell(\mathbf{x}) \;=\; \underbrace{e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(1)}\bigr)}{\text{第1块贡献}} \;+\; \underbrace{e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(2)}\bigr)}{\text{第2块贡献}}, ℓ(x)= em(x(1))−m(x)ℓ(x(1))第1块贡献+ em(x(2))−m(x)ℓ(x(2))第2块贡献,
同理,也只需要各块自己内部的和,再用一个相对的尺度因子即可。
最大值 s o f t m a x ( x ) \mathrm{softmax}(x) softmax(x) 的分块形式
把上面得到的
f
(
x
)
f(\mathbf{x})
f(x) 和
ℓ
(
x
)
\ell(\mathbf{x})
ℓ(x) 带入到
s
o
f
t
m
a
x
(
x
)
=
f
(
x
)
ℓ
(
x
)
,
\mathrm{softmax}(\mathbf{x}) \;=\; \frac{f(\mathbf{x})}{\ell(\mathbf{x})},
softmax(x)=ℓ(x)f(x),
就得到在分块后的 Softmax 形式:
s o f t m a x ( x ) = [ e m ( x ( 1 ) ) − m ( x ) f ( x ( 1 ) ) , e m ( x ( 2 ) ) − m ( x ) f ( x ( 2 ) ) ] e m ( x ( 1 ) ) − m ( x ) ℓ ( x ( 1 ) ) + e m ( x ( 2 ) ) − m ( x ) ℓ ( x ( 2 ) ) . \mathrm{softmax}(\mathbf{x}) \;=\; \frac{ \Bigl[ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})} \,f\bigl(\mathbf{x}^{(1)}\bigr), \;\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})} \,f\bigl(\mathbf{x}^{(2)}\bigr) \Bigr] }{ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(1)}\bigr) \;+\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(2)}\bigr) }. softmax(x)=em(x(1))−m(x)ℓ(x(1))+em(x(2))−m(x)ℓ(x(2))[em(x(1))−m(x)f(x(1)),em(x(2))−m(x)f(x(2))].
为什么这样做?
- 数值稳定性
跟单向量计算 Softmax 类似,这里也要减去整段向量的最大值 m ( x ) m(\mathbf{x}) m(x),以避免 e z e^z ez 里的 z z z 太大或太小导致溢出/下溢。 - 减少重复计算
如果我们已经知道各块各自的 max \max max 值和求和 ℓ ( x ( k ) ) \ell(\mathbf{x}^{(k)}) ℓ(x(k)),那就无需把 x \mathbf{x} x 整体重新扫描、求最大值、求指数和,总结出公式即可快速拼成拼接后向量的软最大值。 - 方便分布式或分块处理
在实际系统里, x ( 1 ) \mathbf{x}^{(1)} x(1) 和 x ( 2 ) \mathbf{x}^{(2)} x(2) 可能来自不同子网络或不同设备。这种分块计算可以让每一块先在本地完成自己的 Softmax 部分计算,最后再做一次简短的组合归一化即可。