Transformer解码器如何使用编码器的输出?

在 Transformer 模型中,解码器(Decoder)通过「编码器 - 解码器注意力机制(Encoder-Decoder Attention)」利用编码器的输出,从而在生成目标序列时动态地关注输入序列的相关部分。以下是具体机制的详细解析:

一、编码器的输出是什么?

编码器的核心作用是将输入序列(如源语言句子)编码为上下文语义向量,具体输出为:

  • 最后一层编码器每个位置的隐藏状态,记为 K(键) 和 V(值)(在 Transformer 中,K和V通常是同一隐藏状态的线性变换,即 K=V=EncoderOutput)。
  • 形状为:序列长度隐藏层维度(批量处理时为(batch_size) 序列长度隐藏层维度)。

二、解码器如何使用编码器的 K 和 V?

解码器的每个层包含两个注意力子层:

  • 自注意力(Self-Attention):处理解码器当前层的输入(已生成的目标序列部分),生成 查询向量(Query, Q)
  • 编码器 - 解码器注意力(Encoder-Decoder Attention):使用自注意力输出的 Q,与编码器的 K 和 V 计算注意力,得到上下文向量,用于指导目标序列的生成。
关键步骤解析:

2. 编码器 - 解码器注意力计算 

输出:上下文向量C 与解码器自注意力的输出结合(通过残差连接和层归一化),输入到前馈神经网络(FFN),生成下一个位置的隐藏状态。

三、为什么解码器需要编码器的输出?

  • 核心作用:让解码器在生成每个目标词时,能够动态聚焦于输入序列中的相关部分(类似传统注意力机制)。例如:
    • 在机器翻译中,解码器生成 “你好” 时,通过编码器 - 解码器注意力关注输入序列中的 “Hello”;生成 “世界” 时,关注 “world”。
  • 本质:编码器的输出是输入序列的全局语义表示,解码器通过注意力机制从中 “提取” 与当前生成位置最相关的信息,实现 “内容指导生成”。

四、对比:解码器的自注意力 vs. 编码器 - 解码器注意力 

自注意力(解码器)的Q、K、V 均来自解码器当前层输入 ,作用是建模目标序列内部的依赖关系(如已生成词的顺序)

编码器 - 解码器注意力Q 来自解码器自注意力,K、V 来自编码器输出,作用是建立目标序列与输入序列的跨模态关联

五、代码示例(简化版 PyTorch)

import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads)  # 自注意力
        self.enc_dec_attn = nn.MultiheadAttention(d_model, n_heads)  # 编码器-解码器注意力
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.ReLU(),
            nn.Linear(4*d_model, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
    
    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        # tgt: 解码器输入,形状 (目标序列长度, batch_size, d_model)
        # memory: 编码器输出,形状 (输入序列长度, batch_size, d_model)
        
        # 自注意力
        tgt2, _ = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = self.norm1(tgt + tgt2)  # 残差连接
        
        # 编码器-解码器注意力(Q来自tgt,K/V来自memory)
        tgt2, _ = self.enc_dec_attn(tgt, memory, memory, attn_mask=memory_mask)
        tgt = self.norm2(tgt + tgt2)  # 残差连接
        
        # 前馈网络
        tgt2 = self.ffn(tgt)
        tgt = self.norm3(tgt + tgt2)
        return tgt
  • 关键参数memory 即编码器的输出,在解码器中作为 enc_dec_attn 的 K 和 V 输入。
  • 多头注意力:通过多头机制(n_heads)将注意力计算拆分为多个子空间,增强模型捕捉不同语义关联的能力。

六、总结

解码器通过编码器 - 解码器注意力机制,将编码器输出的全局语义信息与解码器当前生成的局部信息动态结合,实现了 “根据输入内容指导目标序列生成” 的核心逻辑。这种设计使得 Transformer 在机器翻译、文本生成等任务中能够高效地建模跨序列的依赖关系,是其强大性能的关键原因之一。

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值