什么是 logits?
logits
是模型输出的未归一化预测值,通常是全连接层的输出。在分类任务中,logits 的形状通常为 (batch_size, num_labels)
,其中 batch_size
是一个批次中的样本数,num_labels
是分类任务中的类别数。
logits
是模型的输出。假设logits
的形状为(batch_size, num_labels)
,例如(32, 3)
,表示每个批次有32
个样本,每个样本有3
个类别的预测值。
什么是交叉熵损失函数?
交叉熵损失函数(Cross-Entropy Loss
)是一种常用于分类任务的损失函数。它衡量的是预测分布与真实分布之间的差异。具体而言,它会计算每个样本的预测类别与真实类别之间的距离,然后取平均值。
在 PyTorch 中,交叉熵损失函数可以通过 torch.nn.CrossEntropyLoss
来实现。该函数结合了 LogSoftmax
和 NLLLoss
两个操作,适用于未归一化的 logits
。
具体示例
假设有一个分类任务,模型的输出和标签如下:
logits = torch.tensor([[2.0, 0.5, 0.3], [0.2, 2.0, 0.5]])
labels = torch.tensor([0, 1])
解释如下:
logits
的形状是 (2, 3),表示有 2 个样本,每个样本有 3 个类别的预测值。labels
的形状是 (2,),表示有 2 个样本的真实类别标签。
利用 labels 和 logits 计算损失:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, labels) # 计算交叉熵损失
完整计算代码:
import torch
import torch.nn as nn
logits = torch.tensor([[2.0, 0.5, 0.3], [0.2, 2.0, 0.5]])
labels = torch.tensor([0, 1])
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)
print(loss.item())