特别感谢实验室王老师和汤老师对本次讨论的大力支持~
在查看self-attention的过程中,我对Q、K矩阵的mask操作不太理解,认为原self-attention的mask操作不完整,因此进行了以下探索。
# 本文使用的self-attention借鉴了TENER模型的代码
# 2019-TENER: Adapting Transformer Encoder for Named Entity Recognition
import torch
torch.random.manual_seed(1)
torch.set_printoptions(profile='full')
mask_1 = torch.tensor([[1, 1, 0],
[1, 1, 1],
[1, 0, 0],
[1, 0, 0]])
mask_2 = torch.tensor([[[1, 1, 0], [1, 1, 0], [0, 0, 0]],
[[1, 1, 1], [1, 1, 1], [1, 1, 1]],
[[1, 0, 0], [0, 0, 0], [0, 0, 0]],
[[1, 0, 0], [0, 0, 0], [0, 0, 0]]])
attn = torch.rand(4, 1, 3, 3) # batch_size, n_head, seq_len, seq_len
V = torch.rand(4, 1, 3, 5) # batch_size, n_head, seq_len, dim
print(V)
print()
# attn.masked_fill_(mask=mask_1[:,None, None].eq(0), value=float('-inf'))
attn.masked_fill_(mask=mask_2[:, None].eq(0), value=float('-inf'))
print(attn)
attn = torch.softmax(attn, dim=-1)
attn = torch.where(torch.isnan(attn), torch.full_like(attn, 0), attn) # softmax会出现nan的情况,因为一行数全部都是-inf,分母变为0
print(attn)
print()
A = torch.matmul(attn, V) # batch_size, n_head, seq_len, dim
print( A )
print()
# 简化的一个FFN
W_1 = torch.rand(5, 4)
W_2 = torch.rand(4, 4)
A = A.transpose(1,2).contiguous().view(4,-1, 1*5)
print( torch.matmul(torch.matmul(A, W_1), W_2) )
'''
mask_1的最后的结果--A
tensor([[[5.2098, 4.7500, 1.7479, 2.4909],
[5.2427, 4.7842, 1.7666, 2.5070],
[5.0799, 4.6148, 1.6737, 2.4275]],
[[4.9464, 4.4922, 1.7391, 2.4124],
[5.1933, 4.7331, 1.8184, 2.5240],
[4.9818, 4.5266, 1.7505, 2.4284]],
[[3.7648, 3.5294, 1.3037, 1.7531],
[3.7648, 3.5294, 1.3037, 1.7531],
[3.7648, 3.5294, 1.3037, 1.7531]],
[[3.8360, 3.5142, 1.3775, 1.8828],
[3.8360, 3.5142, 1.3775, 1.8828],
[3.8360, 3.5142, 1.3775, 1.8828]]])
mask_2的最后的结果--A
tensor([[[5.2098, 4.7500, 1.7479, 2.4909],
[5.2427, 4.7842, 1.7666, 2.5070],
[0.0000, 0.0000, 0.0000, 0.0000]],
[[4.9464, 4.4922, 1.7391, 2.4124],
[5.1933, 4.7331, 1.8184, 2.5240],
[4.9818, 4.5266, 1.7505, 2.4284]],
[[3.7648, 3.5294, 1.3037, 1.7531],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000]],
[[3.8360, 3.5142, 1.3775, 1.8828],
[0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000]]])
'''
疑惑:在上述代码中,mask_1是原self-attention的操作,我的问题是,在mask_1最后得到的A矩阵中,padding位置仍然有数值,理想情况应该是padding位置是0才对,这样才不会对以后的操作,如attn*v,有影响。
解惑:按照我的想法,对mask矩阵进行了修改,padding的位置变为了0,即mask_2的结果。
结论:上述代码中,mask_1,mask_2起到了不同的作用,都是对QK的结果(attn-[batch_size, n_head, seq_len, seq_len])进行mask。
其中,mask_1只对最后一个维度进行mask,也就是最后的那个seq_len;mask_2则对[seq_len, seq_len]进行mask。
从上述结果可以看出,对QK乘积进行mask的结果中,padding位置是否有数值对其他位置是不影响的--其他位置数据没有发生变化
,因此,原self-attention对padding的处理更简单,padding位置即使不是0,对以后的操作也不会产生影响。