Label Smoothing Cross Entropy介绍
Label Smoothing 是一种正则化技术,用于改进分类任务中的交叉熵损失函数。传统的交叉熵损失函数假设目标标签是硬性(hard)的,即每个样本只有一个正确的类别标签,并且该类别的概率为 1,其他类别的概率为 0。然而,这种硬性标签可能会导致模型过拟合训练数据,尤其是在训练数据有限或标签可能存在噪声的情况下。
Label Smoothing 的基本思想是对目标标签进行“平滑”处理,将原本硬性的标签分布替换为一个更柔和的分布。这样可以减少模型对单一类别的过度自信,从而提高模型的泛化能力。
- 提出论文: Rethinking the Inception Architecture for Computer Vision
- 会议: CVPR 2016
Label Smoothing(标签平滑)的工作原理
1. 传统交叉熵损失
在分类问题中,交叉熵损失函数定义如下:
Cross Entropy Loss=−i=1∑Cyilog(pi)
其中:
- C 是类别总数。
- $y_i $是目标标签的 one-hot 编码(硬性标签),即正确类别的值为 1,其他类别的值为 0。
- pi 是模型预测的第 i 类的概率。
在硬性标签的情况下,模型会努力最大化正确类别的概率 pi,而完全忽略其他类别的概率。
2. Label Smoothing 的引入
Label Smoothing 将目标标签从硬性分布转换为软性分布,软标签具体公式如下:
yi′={1−ϵ,C−1ϵ,if i=true classotherwise
其中:
- yi′ 是平滑后的目标标签分布。
- ϵ 是平滑参数,通常取值在 [0, 1] 范围内(例如 0.1)。
- 正确类别的概率被降低为 1−ϵ ,而其他类别的概率被提升为 C−1ϵ 。
3. 平滑后的交叉熵损失
使用平滑后的标签分布yi′,交叉熵损失变为:
Smoothed Cross EntropyLoss=−i=1∑Cyi′log(pi)
展开后可以写为:
Smoothed Cross Entropy Loss=−(1−ϵ)log(ptrue)−i=true∑C−1ϵlog(pi)
其中:
- ptrue 是模型对正确类别的预测概率。
- ∑i=trueC−1ϵlog(pi)是对其他类别的惩罚项。
通过这种方式,模型不仅需要最大化正确类别的概率,还需要关注其他类别的预测结果,从而避免对单一类别的过度自信。
Label Smoothing 的优点
- 减少过拟合
Label Smoothing 防止模型对训练数据中的硬性标签过于依赖,从而提高了模型的泛化能力。
- 改善模型的校准
使用 Label Smoothing 后,模型的预测概率通常更加接近真实分布,而不是过度集中在某个类别上。
- 缓解标签噪声的影响
如果训练数据中的标签存在噪声,Label Smoothing 可以通过平滑标签分布来降低噪声对模型的影响。
- 增强模型的鲁棒性
在对抗攻击等场景下,Label Smoothing 可以使模型对输入扰动更加鲁棒。
Label Smoothing 的实现
以下是一个基于 PyTorch 的 Label Smoothing 实现示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
| import torch import torch.nn as nn import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module): def __init__(self, label_smoothing=0.1, class_weights=None, reduction='mean'): super(LabelSmoothingCrossEntropy, self).__init__() self.label_smoothing = label_smoothing self.reduction = reduction self.class_weights = class_weights
def forward(self, preds, targets): n_classes = preds.size(-1) log_preds = F.log_softmax(preds, dim=-1)
smooth_labels = torch.full_like(preds, self.label_smoothing / (n_classes - 1)) smooth_labels.scatter_(1, targets.unsqueeze(1), 1 - self.label_smoothing)
if self.class_weights is not None: loss = -(smooth_labels * log_preds * self.class_weights.unsqueeze(0)).sum(dim=-1) else: loss = -(smooth_labels * log_preds).sum(dim=-1)
if self.reduction == 'mean': return loss.mean() elif self.reduction == 'sum': return loss.sum() else: return loss
preds = torch.tensor([[2.0, 1.0, 0.1], [1.5, 2.5, 0.3]]) targets = torch.tensor([0, 1])
criterion = LabelSmoothingCrossEntropy(label_smoothing=0.1, class_weights= torch.as_tensor([1, 2, 1])) loss = criterion(preds, targets) print("Label Smoothing Cross Entropy Loss:", loss.item())
|
- PyTorch 1.10之后Cross Entropy Loss 已经原生支持标签平滑功能
1 2 3 4 5 6 7 8 9 10 11 12 13
| import torch from torch.nn import CrossEntropyLoss
preds = torch.tensor([[2.0, 1.0, 0.1], [1.5, 2.5, 0.3]]) targets = torch.tensor([0, 1])
criterion = CrossEntropyLoss(weight=None, label_smoothing=0.1)
loss = criterion(preds, targets) print("Label Smoothing Cross Entropy Loss:", loss.item())
|
Label Smoothing 的注意事项
- 选择合适的平滑参数 ϵ 的值通常在 0.1 左右。如果 ϵ 过大,可能会导致模型对正确类别的学习不足;如果 ϵ 过小,则效果可能不明显。
- 适用于大规模分类任务
Label Smoothing 在类别数量较多的任务中效果更显著,因为平滑后的分布能够更好地反映类别间的关联性。
- 与知识蒸馏结合
Label Smoothing 常与知识蒸馏(Knowledge Distillation)结合使用。通过使用教师模型生成软标签,学生模型可以学习到更加丰富的类别间关系。