transformer的self-attention中,mask是怎么起作用的

本文探讨了在Transformer模型中self-attention的mask操作,通过代码示例展示了不同mask方式对Q、K矩阵的影响。作者指出,原始的mask操作可能未完全排除padding的影响,但实验结果显示,即使padding位置非零,也不会对后续计算造成影响。文中还提出了修改后的mask矩阵,确保padding位置为0,以避免潜在问题。最后,作者总结了两种mask在处理QK乘积时的不同作用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

特别感谢实验室王老师和汤老师对本次讨论的大力支持~

在查看self-attention的过程中,我对Q、K矩阵的mask操作不太理解,认为原self-attention的mask操作不完整,因此进行了以下探索。

# 本文使用的self-attention借鉴了TENER模型的代码
# 2019-TENER: Adapting Transformer Encoder for Named Entity Recognition
import torch

torch.random.manual_seed(1)

torch.set_printoptions(profile='full')

mask_1 = torch.tensor([[1, 1, 0],
                       [1, 1, 1],
                       [1, 0, 0],
                       [1, 0, 0]])

mask_2 = torch.tensor([[[1, 1, 0], [1, 1, 0], [0, 0, 0]],
                       [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
                       [[1, 0, 0], [0, 0, 0], [0, 0, 0]],
                       [[1, 0, 0], [0, 0, 0], [0, 0, 0]]])

attn = torch.rand(4, 1, 3, 3)   # batch_size, n_head, seq_len, seq_len

V = torch.rand(4, 1, 3, 5)      # batch_size, n_head, seq_len, dim
print(V)
print()

# attn.masked_fill_(mask=mask_1[:,None, None].eq(0), value=float('-inf'))
attn.masked_fill_(mask=mask_2[:, None].eq(0), value=float('-inf'))

print(attn)
attn = torch.softmax(attn, dim=-1)
attn = torch.where(torch.isnan(attn), torch.full_like(attn, 0), attn)   # softmax会出现nan的情况,因为一行数全部都是-inf,分母变为0

print(attn)
print()
A = torch.matmul(attn, V)		# batch_size, n_head, seq_len, dim
print( A )
print()

# 简化的一个FFN
W_1 = torch.rand(5, 4)
W_2 = torch.rand(4, 4)
A = A.transpose(1,2).contiguous().view(4,-1, 1*5)
print( torch.matmul(torch.matmul(A, W_1), W_2) )

'''
mask_1的最后的结果--A
tensor([[[5.2098, 4.7500, 1.7479, 2.4909],
         [5.2427, 4.7842, 1.7666, 2.5070],
         [5.0799, 4.6148, 1.6737, 2.4275]],

        [[4.9464, 4.4922, 1.7391, 2.4124],
         [5.1933, 4.7331, 1.8184, 2.5240],
         [4.9818, 4.5266, 1.7505, 2.4284]],

        [[3.7648, 3.5294, 1.3037, 1.7531],
         [3.7648, 3.5294, 1.3037, 1.7531],
         [3.7648, 3.5294, 1.3037, 1.7531]],

        [[3.8360, 3.5142, 1.3775, 1.8828],
         [3.8360, 3.5142, 1.3775, 1.8828],
         [3.8360, 3.5142, 1.3775, 1.8828]]])

mask_2的最后的结果--A
tensor([[[5.2098, 4.7500, 1.7479, 2.4909],
         [5.2427, 4.7842, 1.7666, 2.5070],
         [0.0000, 0.0000, 0.0000, 0.0000]],

        [[4.9464, 4.4922, 1.7391, 2.4124],
         [5.1933, 4.7331, 1.8184, 2.5240],
         [4.9818, 4.5266, 1.7505, 2.4284]],

        [[3.7648, 3.5294, 1.3037, 1.7531],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]],

        [[3.8360, 3.5142, 1.3775, 1.8828],
         [0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000]]])
'''

疑惑:在上述代码中,mask_1是原self-attention的操作,我的问题是,在mask_1最后得到的A矩阵中,padding位置仍然有数值,理想情况应该是padding位置是0才对,这样才不会对以后的操作,如attn*v,有影响。

解惑:按照我的想法,对mask矩阵进行了修改,padding的位置变为了0,即mask_2的结果。

结论:上述代码中,mask_1,mask_2起到了不同的作用,都是对QK的结果(attn-[batch_size, n_head, seq_len, seq_len])进行mask。
其中,mask_1只对最后一个维度进行mask,也就是最后的那个seq_len;mask_2则对[seq_len, seq_len]进行mask。
从上述结果可以看出,对QK乘积进行mask的结果中,padding位置是否有数值对其他位置是不影响的--其他位置数据没有发生变化,因此,原self-attention对padding的处理更简单,padding位置即使不是0,对以后的操作也不会产生影响。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值