FlashAttention 原理之 softmax 分块计算

FlashAttention 原理之 softmax 分块计算

在这里插入图片描述

标准 Softmax

Softmax 函数(也称为归一化指数函数)是一个将向量转换成概率分布的函数。对于输入向量 x,softmax 函数将其转换为一个概率分布向量,其中每个元素的值在 (0,1) 之间,且所有元素之和为 1。

s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}} softmax(xi)=jexjexi

归一化 Softmax

m ( x ) : = max ⁡ i x i m(x) := \max_i x_i m(x):=imaxxi

这里, x = [ x 1 , x 2 , … , x B ] x = [x_1, x_2, \dots, x_B] x=[x1,x2,,xB] 表示一个有 B 个分量的向量(例如,模型输出的对各类的打分)。 m ( x ) m(x) m(x) 则是向量 x x x 中所有分量的最大值。

f ( x ) : = [ e   x 1 − m ( x ) ,   e   x 2 − m ( x ) ,   … ,   e   x B − m ( x ) ] f(x) := \bigl[e^{\,x_1 - m(x)},\, e^{\,x_2 - m(x)},\,\dots,\, e^{\,x_B - m(x)}\bigr] f(x):=[ex1m(x),ex2m(x),,exBm(x)]
这里做了一个“减最大值”的操作,即把每个 x i x_i xi 都减去整个向量的最大分量 m ( x ) m(x) m(x),然后取指数。这样做的好处是数值更稳定:当 x i x_i xi 很大时,直接算 e x i e^{x_i} exi 容易导致溢出;但减去最大值以后,指数部分变为 x i − m ( x ) x_i - m(x) xim(x)(一个相对较小的或非正的数),从而避免数值爆炸。

ℓ ( x ) : = ∑ i f ( x ) i \ell(x) := \sum_i f(x)_i (x):=if(x)i
这里把向量 f ( x ) f(x) f(x) 的各个分量加起来得到标量 ℓ ( x ) \ell(x) (x)

softmax ( x ) : = f ( x ) ℓ ( x ) \text{softmax}(x) := \frac{f(x)}{\ell(x)} softmax(x):=(x)f(x)
把每个分量 f ( x ) i f(x)_i f(x)i 除以 ℓ ( x ) \ell(x) (x) 后,就得到了标准的 Softmax 输出向量。

softmax ( x ) i = e   x i − m ( x ) ∑ j e   x j − m ( x )   . \text{softmax}(x)_i = \frac{e^{\,x_i - m(x)}}{\sum_j e^{\,x_j - m(x)}} \,. softmax(x)i=jexjm(x)exim(x).

由于每一项都经过指数函数且被总和归一化,它满足所有分量都非负且所有分量之和为 1,因此是一个有效的概率分布。

分块 Softmax

假设我们有两个同维度向量
x ( 1 ) 和 x ( 2 ) ∈ R B \mathbf{x}^{(1)} 和 \mathbf{x}^{(2)} \in \mathbb{R}^B x(1)x(2)RB,把它们拼接(concatenate)成
x = [ x ( 1 ) ,   x ( 2 ) ] ∈ R 2 B . \mathbf{x} = \bigl[\mathbf{x}^{(1)},\,\mathbf{x}^{(2)}\bigr] \in \mathbb{R}^{2B}. x=[x(1),x(2)]R2B.

下面的公式说明,如何在不重复完整计算的情况下,用“各自的部分计算结果”组合成拼接后向量的 Softmax。先给出它的步骤,再解释其意义和好处:

最大值 m(x) 的分块计算

定义单个向量的最大值

对于 x ( 1 ) ∈ R B \mathbf{x}^{(1)}\in \mathbb{R}^B x(1)RB,我们先定义
m ( x ( 1 ) )    =    max ⁡ i ( x i ( 1 ) ) , m ( x ( 2 ) )    =    max ⁡ i ( x i ( 2 ) ) . m\bigl(\mathbf{x}^{(1)}\bigr) \;=\; \max_i \Bigl(\mathbf{x}^{(1)}_i\Bigr), \quad m\bigl(\mathbf{x}^{(2)}\bigr) \;=\; \max_i \Bigl(\mathbf{x}^{(2)}_i\Bigr). m(x(1))=imax(xi(1)),m(x(2))=imax(xi(2)).

定义拼接向量的最大值

由于 x \mathbf{x} x 是把 x ( 1 ) \mathbf{x}^{(1)} x(1) x ( 2 ) \mathbf{x}^{(2)} x(2) 拼到一起,那么
m ( x )    =    m ( [ x ( 1 ) ,   x ( 2 ) ] )    =    max ⁡ ( m ( x ( 1 ) ) ,    m ( x ( 2 ) ) ) . m(\mathbf{x}) \;=\; m\bigl([\mathbf{x}^{(1)},\, \mathbf{x}^{(2)}]\bigr) \;=\; \max\bigl(m(\mathbf{x}^{(1)}), \;m(\mathbf{x}^{(2)})\bigr). m(x)=m([x(1),x(2)])=max(m(x(1)),m(x(2))).

这样就不需要对拼接后的 x \mathbf{x} x 再扫描一次去找最大值,而是只要比较两个子向量各自的最大值即可。

“指数向量” f(x) 的分块计算

Recall
f ( x )    =    [ e   x 1 − m ( x ) ,    e   x 2 − m ( x ) ,    … ,    e   x 2 B − m ( x ) ] . f(\mathbf{x}) \;=\; \Bigl[e^{\,x_1 - m(\mathbf{x})},\; e^{\,x_2 - m(\mathbf{x})},\;\dots,\; e^{\,x_{2B} - m(\mathbf{x})}\Bigr]. f(x)=[ex1m(x),ex2m(x),,ex2Bm(x)].

由于 x \mathbf{x} x 拆成了两块 x ( 1 ) \mathbf{x}^{(1)} x(1) x ( 2 ) \mathbf{x}^{(2)} x(2),我们分别对每块计算其对应的“指数向量”:
f ( x ( 1 ) ) 和 f ( x ( 2 ) ) . f\bigl(\mathbf{x}^{(1)}\bigr) \quad\text{和}\quad f\bigl(\mathbf{x}^{(2)}\bigr). f(x(1))f(x(2)).

然后拼起来即可。但要记住,每一块真正要减去的“中心化值”是整个 x \mathbf{x} x 的最大值 m ( x ) m(\mathbf{x}) m(x)
,因此它们之间会出现一个额外的“补偿系数”:

f ( x )    =    [ e   m ( x ( 1 ) ) − m ( x )   ⋅ f ( x ( 1 ) ) ,      e   m ( x ( 2 ) ) − m ( x )   ⋅ f ( x ( 2 ) ) ] . f(\mathbf{x}) \;=\; \Bigl[ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\cdot f\bigl(\mathbf{x}^{(1)}\bigr), \;\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\cdot f\bigl(\mathbf{x}^{(2)}\bigr) \Bigr]. f(x)=[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))].

直观上看,如果某一块(比如 x ( 1 ) \mathbf{x}^{(1)} x(1))的最大元素是整个拼接向量的最大元素,那么它带来的指数因子就会是 e   m ( x ( 1 ) ) − m ( x ) = e 0 = 1 e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})} = e^0 = 1 em(x(1))m(x)=e0=1。而另一块若不是最大的,就会额外乘上一个小于 1 的因子。

归一化项 ℓ ( x ) \ell(x) (x) 的分块计算

Softmax 要把向量的指数项归一化到和为 1,所以我们需要计算
ℓ ( x )    =    ∑ i = 1 2 B e   x i − m ( x ) . \ell(\mathbf{x}) \;=\; \sum_{i=1}^{2B} e^{\,x_i - m(\mathbf{x})}. (x)=i=12Bexim(x).

利用分块思想,可以分成两段求和,再用与上一步相同的补偿系数连接起来:

ℓ ( x )    =    e   m ( x ( 1 ) ) − m ( x )   ℓ ( x ( 1 ) ) ⏟ 第1块贡献    +    e   m ( x ( 2 ) ) − m ( x )   ℓ ( x ( 2 ) ) ⏟ 第2块贡献 , \ell(\mathbf{x}) \;=\; \underbrace{e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(1)}\bigr)}{\text{第1块贡献}} \;+\; \underbrace{e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(2)}\bigr)}{\text{第2块贡献}}, (x)= em(x(1))m(x)(x(1))1块贡献+ em(x(2))m(x)(x(2))2块贡献,

同理,也只需要各块自己内部的和,再用一个相对的尺度因子即可。

最大值 s o f t m a x ( x ) \mathrm{softmax}(x) softmax(x) 的分块形式

把上面得到的 f ( x ) f(\mathbf{x}) f(x) ℓ ( x ) \ell(\mathbf{x}) (x) 带入到
s o f t m a x ( x )    =    f ( x ) ℓ ( x ) , \mathrm{softmax}(\mathbf{x}) \;=\; \frac{f(\mathbf{x})}{\ell(\mathbf{x})}, softmax(x)=(x)f(x),

就得到在分块后的 Softmax 形式:

s o f t m a x ( x )    =    [ e   m ( x ( 1 ) ) − m ( x )   f ( x ( 1 ) ) ,      e   m ( x ( 2 ) ) − m ( x )   f ( x ( 2 ) ) ] e   m ( x ( 1 ) ) − m ( x )   ℓ ( x ( 1 ) )    +    e   m ( x ( 2 ) ) − m ( x )   ℓ ( x ( 2 ) ) . \mathrm{softmax}(\mathbf{x}) \;=\; \frac{ \Bigl[ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})} \,f\bigl(\mathbf{x}^{(1)}\bigr), \;\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})} \,f\bigl(\mathbf{x}^{(2)}\bigr) \Bigr] }{ e^{\,m(\mathbf{x}^{(1)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(1)}\bigr) \;+\; e^{\,m(\mathbf{x}^{(2)}) - m(\mathbf{x})}\,\ell\bigl(\mathbf{x}^{(2)}\bigr) }. softmax(x)=em(x(1))m(x)(x(1))+em(x(2))m(x)(x(2))[em(x(1))m(x)f(x(1)),em(x(2))m(x)f(x(2))].

为什么这样做?

  1. 数值稳定性
    跟单向量计算 Softmax 类似,这里也要减去整段向量的最大值 m ( x ) m(\mathbf{x}) m(x),以避免 e z e^z ez 里的 z z z 太大或太小导致溢出/下溢。
  2. 减少重复计算
    如果我们已经知道各块各自的 max ⁡ \max max 值和求和 ℓ ( x ( k ) ) \ell(\mathbf{x}^{(k)}) (x(k)),那就无需把 x \mathbf{x} x 整体重新扫描、求最大值、求指数和,总结出公式即可快速拼成拼接后向量的软最大值。
  3. 方便分布式或分块处理
    在实际系统里, x ( 1 ) \mathbf{x}^{(1)} x(1) x ( 2 ) \mathbf{x}^{(2)} x(2) 可能来自不同子网络或不同设备。这种分块计算可以让每一块先在本地完成自己的 Softmax 部分计算,最后再做一次简短的组合归一化即可。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值