import torch
class_num = 3
batch_size = 2
label = torch.LongTensor(batch_size, 1).random_() % class_num
print(label)
#tensor([[6],
# [0],
# [3],
# [2]])
xx=torch.zeros(batch_size, class_num).scatter_(1, label, 1)##原来0的位置 职位1
print(xx)
'''
tensor([[1],#####################代表 01 10位置变为1
[0]])####################代表 01 10位置变为1
tensor([[0., 1., 0.],
[1., 0., 0.]])
tensor([[0],##################代表 00 12
新一代蒸馏算法
最新推荐文章于 2024-10-01 18:52:56 发布