一、注意力机制基本原理
-
动机
传统的 RNN/CNN 在处理长序列时,容易出现“信息遗忘”或“长依赖难以捕捉”的问题。注意力机制通过“动态加权”——在不同位置分配不同“关注度”(权重)——让模型在解码或特征聚合时,将更多注意力放在最相关的输入上,从而:- 缓解长距离依赖
- 提升对关键信息的捕获能力
- 增强可解释性:权重分布可视化后,可以直观看到模型聚焦的位置
-
核心思想
给定一组“值” V V V和与之对应的“键” K K K,对于每个“查询” Q Q Q,计算它与所有键的相似度分数(Score),再对这些分数做 Softmax,得到权重向量;最后以权重为系数,对值向量做加权求和,得到注意力输出:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K ⊤ d k ) V \mathrm{Attention}(Q,K,V) \;=\; \mathrm{softmax}\Bigl(\tfrac{QK^\top}{\sqrt{d_k}}\Bigr)\,V Attention(Q,K,V)=softmax(dkQK⊤)V
其中, d k \sqrt{d_k} dk 是缩放因子,用以控制内积数值量级,确保权重是单位方差的,保持 Softmax 梯度稳定。
二、Query、Key、Value 定义与工作流程
组件 | 作用 | 形状(Batch, Seq, Dim) |
---|---|---|
Query | 要“查询”的向量集合,一般来自解码器当前步或同一序列 | ( B , L q , d k ) (B, L_q, d_k) (B,Lq,dk) |
Key | 用于“匹配”查询的向量集合,一般来自编码器输出或同一序列 | ( B , L k , d k ) (B, L_k, d_k) (B,Lk,dk) |
Value | 与 Key 一一对应的“值”向量集合,最终以加权和形式输出 | ( B , L v , d v ) (B, L_v, d_v) (B,Lv,dv), 通常 L k = L v L_k=L_v Lk=Lv |
-
投影(Projection)
通常先用三组线性变换,将同一样本输入 X \mathbf{X} X 映射到三个子空间:
Q = X W Q , K = X W K , V = X W V Q = XW^Q,\quad K = XW^K,\quad V = XW^V Q=XWQ,K=XWK,V=XWV
其中 (WQ,WK\in\mathbb{R}^{d_{\text{model}}\times d_k}),(WV\in\mathbb{R}{d_{\text{model}}\times d_v})。 -
计算相似度(Score)
S c o r e ( Q i , K j ) = Q i ⋅ K j ( i = 1 … L q , j = 1 … L k ) \mathrm{Score}(Q_i,K_j) \;=\; Q_i \cdot K_j \quad (i=1\dots L_q,\; j=1\dots L_k) Score(Qi,Kj)=Qi⋅Kj(i=1…Lq,j=1…Lk)
得到一个 ( L q × L k ) (L_q\times L_k) (Lq×Lk) 的打分矩阵。 -
缩放与归一化
S c o r e ~ = S c o r e d k , A = s o f t m a x ( S c o r e ~ ) ( 沿 K 维度 ) \widetilde{\mathrm{Score}} = \frac{\mathrm{Score}}{\sqrt{d_k}},\quad A = \mathrm{softmax}\bigl(\widetilde{\mathrm{Score}}\bigr)\quad (\text{沿 }K\text{ 维度}) Score =dkScore,A=softmax(Score )(沿 K 维度)
使得打分数值不至于过大或过小,保持梯度有效。 -
加权求和输出
A t t e n t i o n ( Q , K , V ) = A V , y i = ∑ j = 1 L k A i j V j \mathrm{Attention}(Q,K,V) = A\;V,\quad y_i = \sum_{j=1}^{L_k} A_{ij}\,V_j Attention(Q,K,V)=AV,yi=j=1∑LkAijVj
最终输出与 Query 个数相同,每个 y i y_i yi 都融合了所有 Value,重点突出了与 Q i Q_i Qi 相关的部分。
三、缩放点乘注意力的数学公式
设单个 Query 向量维度为
d
k
d_k
dk,则整批计算可写为矩阵形式:
A
t
t
e
n
t
i
o
n
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
⊤
d
k
)
V
,
\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\Bigl(\tfrac{QK^\top}{\sqrt{d_k}}\Bigr)\,V,
Attention(Q,K,V)=softmax(dkQK⊤)V,
- Q ∈ R B × L q × d k Q\in\mathbb{R}^{B\times L_q\times d_k} Q∈RB×Lq×dk
- K ∈ R B × L k × d k K\in\mathbb{R}^{B\times L_k\times d_k} K∈RB×Lk×dk
- V ∈ R B × L k × d v V\in\mathbb{R}^{B\times L_k\times d_v} V∈RB×Lk×dv
- 输出为 R B × L q × d v \mathbb{R}^{B\times L_q\times d_v} RB×Lq×dv
四、PyTorch 实现示例
下面给出一个简洁版的「缩放点乘注意力」函数,以及如何在多头注意力(Multi‑Head Attention)中使用它。
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(q, k, v, mask=None):
"""
q, k, v: 张量形状均为 (B, num_heads, L, d_k)
mask: 可选张量,形状 (B, 1, L_q, L_k),用于屏蔽不可关注位置
"""
d_k = q.size(-1)
# 1. 计算未缩放打分
scores = torch.matmul(q, k.transpose(-2, -1)) # -> (B, heads, L_q, L_k)
# 2. 缩放
scores = scores / math.sqrt(d_k)
# 3. 可选屏蔽
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# 4. 归一化
attn = F.softmax(scores, dim=-1) # -> (B, heads, L_q, L_k)
# 5. 加权求和
output = torch.matmul(attn, v) # -> (B, heads, L_q, d_k)
return output, attn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_k = d_model // num_heads
self.num_heads = num_heads
self.w_q = nn.Linear(d_model, d_model)
self.w_k = nn.Linear(d_model, d_model)
self.w_v = nn.Linear(d_model, d_model)
self.w_o = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
B, L, _ = x.size()
# 1. 线性映射并切分多头
q = self.w_q(x).view(B, L, self.num_heads, self.d_k).transpose(1,2)
k = self.w_k(x).view(B, L, self.num_heads, self.d_k).transpose(1,2)
v = self.w_v(x).view(B, L, self.num_heads, self.d_k).transpose(1,2)
# q,k,v: (B, heads, L, d_k)
# 2. 点乘注意力
attn_output, attn_weights = scaled_dot_product_attention(q, k, v, mask)
# attn_output: (B, heads, L, d_k)
# 3. 合并多头
attn_output = attn_output.transpose(1,2).contiguous().view(B, L, -1)
# 4. 最后一层线性
output = self.w_o(attn_output) # (B, L, d_model)
return output, attn_weights