RuntimeError: Expected tensor for argument #1 ‘indices’ to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)
RuntimeError: Expected tensor for argument #1 'indices' to have scalar type Long; but got torch.IntTensor instead (while checking arguments for embedding)
题目限制,错误信息不完整。
定位到了这一行:
out = model.forward(batch.src, batch.trg,
batch.src_mask, batch.trg_mask)
打印出来看看类型:
print(batch.src.type())
print(batch.src_mask.type())
print(batch.trg.type())
print(batch.trg_mask.type())
把torch.IntTensor 对应的变量转化为Long,理论上batch.src.long()
应该可以解决,但是又提示让转化为int64,总之就是打印变量类型,根据提示去转化吧。
print大法好啊。
out = model.forward(batch.src.to(torch.int64), batch.trg.to(torch.int64),
batch.src_mask, batch.trg_mask)