【coding】手写多头注意力

本文介绍了如何使用PyTorch实现自注意力机制的编码模板,包括SelfAttention函数和MultiHeadAttention模块,重点讲解了softmax、mask处理和dropout的运用。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

公式比较简单,softmax(q*k/sqrt(d_k))*v

记录一个coding模版,记得加上mask和dropout 

多头的话,就是把d_model拆分成多个头,然后交换sequence_length和n_head进行自注意力计算,得到的张量再还原回去,最后过一层线性层作为输出。

def SelfAttention(q,k,v,mask,dropout):
    d_k = q.size(-1)
    scores = torch.matmul(q,k.transpose(-1,-2))
    if mask:
        scores.mask_fill(mask==0,1e-9)
    scores = F.softmax(scores,dim=-1)
    if dropout:
        scores = dropout(scores)
    return torch.matmul(scores,v)

class MultiHeadAttention(nn.Moudle):
    def __init__(self,):
        super().__init__()
    
    def forward(self, n_head, d_model, q, k, v, mask, dropout):
        assert (d_model%head)==0
        n_batch = q.size(0)
        d_k = q.size(-1)

        # 四个线性变换矩阵,维度和输入维度一致
        w_q = nn.Linear(d_model,d_model)
        w_k = nn.Linear(d_model,d_model)
        w_v = nn.Linear(d_model,d_model)
        w_o = nn.Linear(d_model,d_model)

        # 将输入的q、k、v通过线性变换,这里使用transpose交换1、2维度的目的是我们后续计算self-attention的时候是分别在每个头上
        # 对sequence_length*d_k矩阵计算的,计算完self-attention再还原回去
        q = w_q(q).view(n_batch, -1, n_head, d_k).transpose(1,2)
        k = w_k(k).view(n_batch, -1, n_head, d_k).transpose(1,2)
        v = w_v(v).view(n_batch, -1, n_head, d_k).transpose(1,2)

        # 如果有mask,需要对第1个纬度升维,因为q、k、v已经从三维变成了四维,这里的mask要对应上
        if mask:
            mask.unsqueeze(1)

        atten_scores = SelfAttention(q,k,v,mask,nn.Dropout(p=dropout))

        # 通过self-attention后的张量是4维的,且进行self-attention前的q、k、v是交换了sequence_length和n_head的,所以现在得先还原维度
        atten_scores = atten_scores.transpose(1,2).contiguous().view(n_batch, -1, d_k*n_head)

        reutrn w_o(atten_scores)

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小小的香辛料

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值