目录
1.代码实现
#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)
#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])
embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法
#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
"""
embed_v:表示对中心词进行embedding层
embed_u:对上下文词进行embedding层
"""
v = embed_v(center) #中心词的词向量表达
u = embed_u(contexts_and_negatives) #上下文词的词向量表达
#用中心词来预测上下文词
#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法
pred = torch.bmm(v, u.permute(0, 2, 1)) #矩阵乘法