Build a Transformer model from scratch with PyTorch (the model's Chinese name, 变形金刚, literally means "Transformers", as in the robots).
We will train it on a fun example: adding two numbers.
The input and output look like this:
Input: "12345+54321"  Output: "66666"
We treat this task as a machine translation problem: the input is a character sequence and the output is also a character sequence (seq-to-seq).
Since this matches the input/output structure of machine translation, a Transformer fits naturally.
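For example (an illustrative tokenization; the <SOS>, <EOS> and <PAD> special tokens are defined in the data-preparation code below), the pair "12345+54321" -> "66666" becomes the source character sequence <SOS> 1 2 3 4 5 + 5 4 3 2 1 <EOS> and the target character sequence <SOS> 6 6 6 6 6 <EOS>, both padded with <PAD> to a fixed length.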
References:
Paper "Attention Is All You Need": https://arxiv.org/pdf/1706.03762.pdf
Harvard NLP's The Annotated Transformer: https://github.com/harvardnlp/annotated-transformer/
一, Preparing the Data
import random
import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
# Define the vocabularies
words_x = '<PAD>,1,2,3,4,5,6,7,8,9,0,<SOS>,<EOS>,+'
vocab_x = {word: i for i, word in enumerate(words_x.split(','))}
vocab_xr = [k for k, v in vocab_x.items()]  # reverse lookup: index -> word
words_y = '<PAD>,1,2,3,4,5,6,7,8,9,0,<SOS>,<EOS>'
vocab_y = {word: i for i, word in enumerate(words_y.split(','))}
vocab_yr = [k for k, v in vocab_y.items()]  # reverse lookup: index -> word
# Dataset for the two-number addition task
def get_data():
    # Digit characters to sample from
    words = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    # Sampling probability of each digit
    p = np.array([7, 5, 5, 7, 6, 5, 7, 6, 5, 7])
    p = p / p.sum()
    # Randomly sample n1 digits as s1
    n1 = random.randint(10, 20)
    s1 = np.random.choice(words, size=n1, replace=True, p=p)
    s1 = s1.tolist()
    # Randomly sample n2 digits as s2
    n2 = random.randint(10, 20)
    s2 = np.random.choice(words, size=n2, replace=True, p=p)
    s2 = s2.tolist()
    # x is the character-level concatenation of s1 and s2
    x = s1 + ['+'] + s2
    # y is the numerical sum of s1 and s2
    y = int(''.join(s1)) + int(''.join(s2))
    y = list(str(y))
    # Add start and end tokens
    x = ['<SOS>'] + x + ['<EOS>']
    y = ['<SOS>'] + y + ['<EOS>']
    # Pad to fixed lengths
    x = x + ['<PAD>'] * 50
    y = y + ['<PAD>'] * 51
    x = x[:50]
    y = y[:51]
    # Encode into token ids
    token_x = [vocab_x[i] for i in x]
    token_y = [vocab_y[i] for i in y]
    # Convert to tensors
    tensor_x = torch.LongTensor(token_x)
    tensor_y = torch.LongTensor(token_y)
    return tensor_x, tensor_y
def show_data(tensor_x, tensor_y) -> "tuple[str, str]":
    words_x = "".join([vocab_xr[i] for i in tensor_x.tolist()])
    words_y = "".join([vocab_yr[i] for i in tensor_y.tolist()])
    return words_x, words_y
x,y = get_data()
print(x,y,"\n")
print(show_data(x,y))
# Define the dataset
class TwoSumDataset(torch.utils.data.Dataset):
    def __init__(self, size=100000):
        super().__init__()
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, i):
        return get_data()

ds_train = TwoSumDataset(size=100000)
ds_val = TwoSumDataset(size=10000)
# Data loaders
dl_train = DataLoader(dataset=ds_train,
                      batch_size=200,
                      drop_last=True,
                      shuffle=True)
dl_val = DataLoader(dataset=ds_val,
                    batch_size=200,
                    drop_last=True,
                    shuffle=False)
for src, tgt in dl_train:
    print(src.shape)
    print(tgt.shape)
    break
torch.Size([200, 50])
torch.Size([200, 51])
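As a quick sanity check (a minimal sketch reusing the show_data helper and dl_train loader defined above; the digits shown will vary per run), we can decode the first sample of a batch back into readable strings:
# Decode the first sample of a batch back into characters
src, tgt = next(iter(dl_train))
print(show_data(src[0], tgt[0]))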
二, Defining the Model
Next, we will build the Transformer model from the bottom up, like building a castle out of toy blocks.
First we build 6 basic components: multi-head attention, the feed-forward network, layer normalization, residual connections, word embeddings, and position encodings. This is like using the most basic blocks to build modules such as walls, roofs, fences, pillars, gates, and windows.
Then we use these 6 basic components to build 3 intermediate pieces: the encoder, the decoder, and the generator. This is like using those modules to build the castle's main hall, towers, and gardens.
Finally we assemble these 3 intermediate pieces into the complete Transformer model, like putting the main hall, towers, and gardens together into a complete, beautiful castle.
1, Multi-head attention: MultiHeadAttention (fuses information across different tokens; used in three places: ① Encoder self-attention, ② Decoder masked self-attention, ③ Encoder-Decoder cross-attention)
2, Feed-forward network: PositionwiseFeedForward (applies a position-wise higher-dimensional mapping to the information fused by multi-head attention; FFN for short)
3, Layer normalization: LayerNorm (stabilizes activations by normalizing the features of each position within each sample; unlike BatchNorm it does not depend on the batch, so it handles NLP's variable-length sequences better)
4, Residual connection: ResConnection (improves gradient flow and makes the network easier to train; LayerNorm can be applied before the Add (pre-norm) or after the Add (post-norm))
5, Word embedding: WordEmbedding (encodes token information; the weights are learned, and the output is multiplied by sqrt(d_model) so that its scale matches the position encoding)
6, Position encoding: PositionEncoding (encodes position information, using sin and cos functions to directly encode absolute positions; see the formulas right after this list)
7, Encoder: TransformerEncoder (encodes the input sequence into a memory sequence of vectors of the same length, built by stacking N TransformerEncoderLayer blocks)
8, Decoder: TransformerDecoder (decodes the encoder's memory vectors into another variable-length sequence of vectors, built by stacking N TransformerDecoderLayer blocks)
9, Generator: Generator (maps each vector produced by the decoder to a word in the output vocabulary; usually a single Linear layer)
10, The Transformer: Transformer (performs Seq2Seq transduction, e.g. machine translation, using an Encoder-Decoder architecture composed of the Encoder, Decoder, and Generator)
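For reference, the sinusoidal position encoding mentioned in item 6 follows the formulas from the paper, where pos is the position and i indexes the feature dimension:
PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))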
import torch
from torch import nn
import torch.nn.functional as F
import copy
import math
import numpy as np
import pandas as pd
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
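As a quick illustration of clones (a toy example not in the original code; the layer size 8 is arbitrary):
# Create 3 independent, deep-copied Linear layers
layers = clones(nn.Linear(8, 8), 3)
print(len(layers))             # 3
print(layers[0] is layers[1])  # False: each copy has its own parameters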
1, Multi-head Attention: MultiHeadAttention
These are best understood step by step: first ScaledDotProductAttention, then MultiHeadAttention, and finally MaskedMultiHeadAttention.
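The computation implemented below is the scaled dot-product attention formula from the paper:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
where d_k is the dimension of the query/key vectors; dividing by sqrt(d_k) keeps the dot products from growing too large before the softmax.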
class ScaledDotProductAttention(nn.Module):
    "Compute 'Scaled Dot Product Attention'"

    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, query, key, value, mask=None, dropout=None):
        d_k = query.size(-1)
        scores = query @ key.transpose(-2, -1) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e20)
        p_attn = F.softmax(scores, dim=-1)
        if dropout is not None:
            p_attn = dropout(p_attn)
        return p_attn @ value, p_attn
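A quick shape check (a minimal sketch; the batch/head/length/dimension numbers are arbitrary):
# (batch, heads, query_len, d_k) attending over (batch, heads, key_len, d_k)
attention = ScaledDotProductAttention()
q = torch.randn(2, 8, 10, 16)
k = torch.randn(2, 8, 12, 16)
v = torch.randn(2, 8, 12, 16)
out, p_attn = attention(q, k, v)
print(out.shape)     # torch.Size([2, 8, 10, 16])
print(p_attn.shape)  # torch.Size([2, 8, 10, 12])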
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # stores the attention matrix for later inspection
        self.dropout = nn.Dropout(p=dropout)
        self.attention = ScaledDotProductAttention()

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Do all the linear projections in batch from d_model => h x d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]

        # 2) Apply attention on all the projected vectors in batch.
        x, self.attn = self.attention(query, key, value, mask=mask,
                                      dropout=self.dropout)

        # 3) "Concat" using a view and apply a final linear.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
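A quick usage check (a minimal sketch; h=8 and d_model=64 are arbitrary choices, and the all-ones mask simply allows attention to every position):
# Self-attention over a dummy batch: query = key = value = x
mha = MultiHeadAttention(h=8, d_model=64)
x = torch.randn(200, 50, 64)          # (batch, seq_len, d_model)
mask = torch.ones(200, 1, 50).bool()  # attend to every position
out = mha(x, x, x, mask=mask)
print(out.shape)  # torch.Size([200, 50, 64])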
# To keep the information flow consistent between training and decoding,
# mask out the later positions of the tgt sequence by setting their attention to 0
def tril_mask(data):
    "Mask out future positions."
    size = data.size(-1)  # size is the sequence length
    full = torch.full((1, size, size), 1, dtype=torch.int, device=data.device)
    mask = torch.tril(full).bool()
    return mask
# Set attention on <PAD> positions to 0
def pad_mask(data, pad=0):
    "Mask out pad positions."
    mask = (data != pad).unsqueeze(-2)
    return mask
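A small demo of the two masks (toy tensors chosen for illustration only):
# Causal mask for a length-4 sequence: (1, 4, 4) lower-triangular booleans
print(tril_mask(torch.zeros(1, 4)))
# Padding mask: False at <PAD> (id 0) positions, shape (1, 1, 4)
print(pad_mask(torch.tensor([[1, 2, 0, 0]])))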
# Compute src_mask and tgt_mask for a batch of data
class MaskedBatch:
    "Object for holding a batch of data with mask during training."

    def __init__(self, src, tgt=None, pad=0):
        self.src = src
        self.src_mask = pad_mask(src, pad)
        if tgt is not None:
            self.tgt = tgt[:, :-1]   # during training each tgt token is fed in to predict the next one, so the last token is never used as input
            self.tgt_y = tgt[:, 1:]  # the first token is always <SOS>, so it never needs to be predicted