FlashAttention 原理之伪代码解释

FlashAttention 原理之伪代码解释

在这里插入图片描述

A t t n ( Q , K , V )    =    s o f t m a x ( Q K T ) ⏟ 大小为  N × N    ×    V , \mathrm{Attn}(Q, K, V) \;=\; \underbrace{\mathrm{softmax}\bigl(Q K^T\bigr)}_{\text{大小为 }N\times N} \;\times\; V, Attn(Q,K,V)=大小为 N×N softmax(QKT)×V,

其中 Q , K , V ∈ R N × d Q, K, V \in \mathbb{R}^{N\times d} Q,K,VRN×d N N N 常常很大,直接构造和存储 s o f t m a x ( Q K T ) \mathrm{softmax}(QK^T) softmax(QKT) 需要 O ( N 2 ) O(N^2) O(N2) 的空间,并且读写代价很高。

FlashAttention 通过将 Q, K, V 分别切成小的块(Block)并在片上(SRAM)完成大部分计算,从而避免在外部大内存(HBM)中存储或频繁读写大规模中间矩阵。算法中用到了按行块切的 query(大小为 B r × d B_r \times d Br×d)和按列块切的 key/value(大小为 B c × d B_c \times d Bc×d),然后在两重循环中分块累加结果。

在这里插入图片描述

输入:

  • 矩阵 Q , K , V ∈ R N × d Q,K,V\in \mathbb{R}^{N\times d} Q,K,VRN×d 存放于外部大内存(HBM)。
  • 片上(on-chip)有大小为 M M M 的高速缓存(SRAM)。
  • 还会在HBM/片上分别为中间变量分配空间,如输出 O ∈ R N × d O\in \mathbb{R}^{N\times d} ORN×d、部分归一化因子 ℓ ∈ R N \ell\in \mathbb{R}^{N} RN、以及部分最大值 m ∈ R N m\in \mathbb{R}^{N} mRN(用于数值稳定)。

切分:

  1. 将 Q 按“行方向”分块,得到 T r = ⌈ N B r ⌉ T_r = \bigl\lceil \tfrac{N}{B_r}\bigr\rceil Tr=BrN 个子块 Q 1 , … , Q T r Q_1,\dots,Q_{T_r} Q1,,QTr,每块大小约为 B r × d B_r\times d Br×d
  2. 将 K, V 按“列方向”分块,得到 T c = ⌈ N B c ⌉ T_c = \bigl\lceil \tfrac{N}{B_c}\bigr\rceil Tc=BcN 个子块 ( K 1 , V 1 ) , … , ( K T c , V T c ) (K_1,V_1),\dots,(K_{T_c},V_{T_c}) (K1,V1),,(KTc,VTc),每块大小约为 B c × d B_c\times d Bc×d
  3. 相应地,也把输出 O、部分归一化因子 ℓ \ell 、和部分最大值 m m m 按同样的行块大小 B r B_r Br 切分成 O 1 , … , O T r , ℓ 1 , … , ℓ T r , m 1 , … , m T r O_1,\dots,O_{T_r}, \ell_1,\dots,\ell_{T_r}, m_1,\dots,m_{T_r} O1,,OTr,1,,Tr,m1,,mTr

双层循环:

  • 外层循环 (第 5 行):遍历 Key/Value 的列块索引 j = 1 … T c j = 1 \dots T_c j=1Tc
  • 将第 j j j 个块 ( K j , V j ) (K_j, V_j) (Kj,Vj) 从外部内存(HBM)加载到片上SRAM。
  • 内层循环 (第 7 行):遍历 Query 的行块索引 i = 1 … T r i = 1 \dots T_r i=1Tr
  • Q i , O i , ℓ i , m i Q_i, O_i, \ell_i, m_i Qi,Oi,i,mi 从 HBM 加载到片上SRAM(它们都是大小为 B r B_r Br B r × d B_r\times d Br×d 级别)。
  • 计算局部注意力得分 S i j = Q i K j T S_{ij} = Q_i K_j^T Sij=QiKjT (大小为 ( B r × d ) × ( B c × d ) → B r × B c (B_r \times d) \times (B_c \times d) → B_r \times B_c (Br×d)×(Bc×d)Br×Bc)。
  • 做数值稳定的行级最大值 m ~ i j = r o w m a x ( S i j ) \tilde{m}{ij} = \mathrm{rowmax}(S{ij}) m~ij=rowmax(Sij),再计算 P ~ i j = exp ⁡ ( S i j − m ~ i j ) \tilde{P}{ij} = \exp\bigl(S{ij} - \tilde{m}_{ij}\bigr) P~ij=exp(Sijm~ij)
  • 行和 ℓ i j = r o w s u m ( P ~ i j ) \ell_{ij} = \mathrm{rowsum}\bigl(\tilde{P}_{ij}\bigr) ij=rowsum(P~ij)
  • m ~ i j \tilde{m}{ij} m~ij ℓ i j \ell{ij} ij 来更新行块级的最大值 m i n e w m_i^{new} minew 和部分归一化因子 ℓ i n e w \ell_i^{new} inew,并将新的输出块 O i O_i Oi 与新块 P ~ i j V j \tilde{P}_{ij}V_j P~ijVj 合并。

计算 O i O_i Oi的结果

m i n e w = max ⁡ ( m i ,    m ~ i j ) , ℓ i n e w = e   m i − m i n e w    ℓ i    +    e   m ~ i j − m i n e w    ℓ i j . m_i^{new} = \max\bigl(m_i,\;\tilde{m}_{ij}\bigr), \quad \ell_i^{new} = e^{\,m_i - m_i^{new}}\;\ell_i \;+\; e^{\,\tilde{m}_{ij} - m_i^{new}}\;\ell_{ij}. minew=max(mi,m~ij),inew=emiminewi+em~ijminewij.
m ~ i j \tilde{m}_{ij} m~ij 比原先 m i m_i mi 大,就把整条“数值刻度”往上挪到 m ~ i j \tilde{m}{ij} m~ij;否则保留原先的刻度 m i m_i mi
ℓ i \ell_i i ℓ i j \ell_{ij} ij 都要根据新的刻度 m i n e w m_i^{new} minew 进行相应的指数缩放,然后相加得到新的部分归一化因子。

现在我们得到了新的刻度 m i n e w m_i^{new} minew 和新的部分归一化因子 ℓ i n e w \ell_i^{new} inew。剩下的,就是把老的输出 O i O_i Oi 和当前块新增的输出 P ~ i j V j \tilde{P}_{ij}V_j P~ijVj 合在一起,放到一个统一的“刻度”下。第 12 行的写法大致如下(作一点数学化展开):

O i ( n e w ) = [ d i a g ( ℓ i n e w ) ] − 1 ⏟ 相当于每行除以  ℓ i n e w    ( d i a g ( ℓ i )    e   m i − m i n e w    O i ⏟ "重刻度"后的旧输出    +    e   m ~ i j − m i n e w    P ~ i j    V j ⏟ “重刻度”后的新块输出 ) . \begin{aligned} O_i^{(\mathrm{new})} &= \underbrace{\bigl[\mathrm{diag}(\ell_i^{new})\bigr]^{-1}}{ \text{相当于每行除以 }\ell_i^{new} } \;\Bigl( \underbrace{ \mathrm{diag}(\ell_i)\;e^{\,m_i - m_i^{new}} \;O_i }{ \text{"重刻度"后的旧输出} } \;+\; \underbrace{ e^{\,\tilde{m}{ij} - m_i^{new}} \;\tilde{P}{ij}\;V_j }_{ \text{“重刻度”后的新块输出} } \Bigr). \end{aligned} Oi(new)= [diag(inew)]1相当于每行除以 inew( diag(i)emiminewOi"重刻度"后的旧输出+重刻度后的新块输出 em~ijminewP~ijVj).

旧输出 O i O_i Oi 的重刻度:

  1. 先将 O i O_i Oi 乘上原先“局部归一化”用到的因子 ℓ i \ell_i i,并且再乘上 e m i − m i n e w e^{m_i - m_i^{new}} emiminew,把它转移到新的刻度 ℓ i n e w \ell_i^{new} inew 所在的坐标系。
  2. 新块的输出 P ~ i j V j \tilde{P}_{ij}V_j P~ijVj
    这是在本次块的局部最大值 m ~ i j \tilde{m}{ij} m~ij 上做的 exponent,所以要乘上 e m ~ i j − m i n e w e^{\tilde{m}{ij} - m_i^{new}} em~ijminew 去对齐到新的刻度。
  3. 再除以新的 ℓ i n e w \ell_i^{new} inew
    由于我们想让最终的 O i O_i Oi 处在 “输出 = 1 (部分Softmax分母) × ( 加权和 ) ” “输出 = \frac{1}{\text{(部分Softmax分母)}}\times(加权和)” 输出=(部分Softmax分母)1×(加权和) 的形式,所以最后整体还要除以 ℓ i n e w \ell_i^{new} inew
    代码里用 [ d i a g ( ℓ i n e w ) ] − 1 \bigl[\mathrm{diag}(\ell_i^{new})\bigr]^{-1} [diag(inew)]1 的写法,是因为在实现中 O i O_i Oi 维度是 B r × d B_r \times d Br×d,而 ℓ i n e w \ell_i^{new} inew 是长度为 B r B_r Br 的向量,需要对每个行分别做除法。

最终把这个新的 O i ( n e w ) O_i^{(\mathrm{new})} Oi(new) 写回HBM,并且更新 ℓ i ← ℓ i n e w \ell_i \leftarrow \ell_i^{new} iinew m i ← m i n e w m_i \leftarrow m_i^{new} miminew(见第 13 行),就完成了对第 j j j 块 Key/Value 的处理。

小结: 第 12 行就是做了一次“加权合并”,把“老的注意力累加输出”和“新的块注意力输出”整合到同一个数值基准,并更新归一化,使得当所有块都处理完后,O_i 就等价于 s o f t m a x ( Q i K T ) V \mathrm{softmax}(Q_iK^T)V softmax(QiKT)V 中“i 行块”的那部分结果。

FlashAttention公式推理:https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad

### Faster-Block 技术文档和实现方式 #### 背景介绍 Faster-Block 是一种旨在加速神经网络推理过程的技术,特别适用于卷积神经网络(CNN)。该技术通过优化计算图中的某些操作来提高效率。具体来说,在 FlashAttention-2 中提到的工作分配策略可以类比理解为更快更有效的资源调度方法[^1]。 #### 实现原理 为了提升性能,Faster-Block 主要采用了以下几个方面的改进措施: - **并行化处理**:类似于FlashAttention-2中所采用的方法,对于适合并行执行的任务进行了重新划分,使得多个线程能够同时工作而不需要频繁地等待其他线程完成其任务。 - **减少同步开销**:通过对算法逻辑进行调整以及合理安排数据访问模式,降低了不同硬件单元间相互协调所需的时间成本. - **缓存友好型设计**:考虑到现代处理器架构的特点,对常用的数据结构进行了针对性改造,使其更加适应CPU缓存机制,进而加快了整体运算速度. #### Python代码示例 下面给出一段简单的Python伪代码用于展示如何构建一个基于上述原则的快速模块: ```python import torch.nn as nn class FastConvBlock(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): super(FastConvBlock, self).__init__() # 使用分组卷积代替标准卷积层以增加并行度 self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, groups=max(1,in_channels//4), # 动态设置分组数 kernel_size=kernel_size, stride=stride, padding=padding) def forward(self, x): return self.conv(x) def create_faster_block(): block = FastConvBlock(in_channels=64, out_channels=128) return block ``` 此段代码定义了一个名为 `FastConvBlock` 的类,它继承自 PyTorch 的 `nn.Module` 类,并实现了带有更高并行性的卷积操作。注意这里引入了分组卷积的概念,这有助于进一步分解计算任务,促进多核或多GPU环境下的高效利用。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值