三维张量 × 三维张量（torch.matmul 批量矩阵乘法）
import torch
# Batched matrix multiplication of two 3-D tensors with torch.matmul.
# For 3-D inputs, torch.matmul treats dim 0 as a batch dimension and
# multiplies the trailing 2-D matrices pairwise.

# Shape (3, 3, 2): a batch of 3 matrices, each 3x2 with every row [1, 2].
tensors = torch.tensor([
    [[1, 2], [1, 2], [1, 2]],
    [[1, 2], [1, 2], [1, 2]],
    [[1, 2], [1, 2], [1, 2]],
])
# Print the tensor itself rather than the legacy `.data` attribute;
# the output is identical for a tensor that does not require grad.
print(tensors)
print(tensors.shape)

# Swap the last two dims -> shape (3, 2, 3), so each batch item is
# (3x2) @ (2x3). Compute the transpose once instead of repeating the call.
transposed = tensors.transpose(1, 2)
mul_result = torch.matmul(tensors, transposed)
print(transposed)
print(transposed.shape)

# Every entry is the dot product [1, 2] . [1, 2] = 1*1 + 2*2 = 5,
# so the result is a (3, 3, 3) tensor filled with 5.
print(mul_result)
print(mul_result.shape)
实验结果：
tensor([[[1, 2],
         [1, 2],
         [1, 2]],

        [[1, 2],
         [1, 2],
         [1, 2]],

        [[1, 2],
         [1, 2],
         [1, 2]]])
torch.Size([3, 3, 2])
tensor([[[1, 1, 1],
         [2, 2, 2]],

        [[1, 1, 1],
         [2, 2, 2]],

        [[1, 1, 1],
         [2, 2, 2]]])
torch.Size([3, 2, 3])
tensor([[[5, 5, 5],
         [5, 5, 5],
         [5, 5, 5]],

        [[5, 5, 5],
         [5, 5, 5],
         [5, 5, 5]],

        [[5, 5, 5],
         [5, 5, 5],
         [5, 5, 5]]])
torch.Size([3, 3, 3])
<